├── .github ├── ISSUE_TEMPLATE │ ├── ✍️-other-issue.md │ ├── 🐛-bug-report-and-issues.md │ ├── 📑-new-task-dataset-proposals.md │ └── 🚀-feature-request.md └── workflows │ └── python-app.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── assets ├── GoLLIE.png ├── datasets.png ├── snippets │ ├── space.png │ ├── space_guidelines.png │ ├── space_result.png │ ├── space_text.png │ ├── space_transparent.png │ └── space_two_terminals.png └── zero_shot_results.png ├── bash_scripts ├── Baseline-7B_CodeLLaMA.sh ├── GoLLIE-13B_CodeLLaMA.sh ├── GoLLIE-34B_CodeLLaMA.sh ├── GoLLIE-7B_CodeLLaMA.sh ├── GoLLIE-7B_CodeLLaMA_abl_dropout.sh ├── GoLLIE-7B_CodeLLaMA_abl_masking.sh ├── GoLLIE-7B_CodeLLaMA_ablation_candidates.sh ├── GoLLIE-7B_CodeLLaMA_train_full_model.sh ├── eval │ ├── Baseline-7B_CodeLLaMA.sh │ ├── CoLLIE-7B_CodeLLaMA.sh │ ├── CoLLIE-7B_CodeLLaMA_ablation_candidates.sh │ ├── GoLLIE-13B_CodeLLaMA.sh │ ├── GoLLIE-34B_CodeLLaMA.sh │ ├── GoLLIE-7B_CodeLLaMA_abl_dropout.sh │ ├── GoLLIE-7B_CodeLLaMA_abl_masking.sh │ ├── eval_all.sh │ └── run_all_metrics.sh ├── generate_ablation_data.sh ├── generate_data.sh ├── generate_data_modified.sh ├── preprocess_ace.sh ├── reformat_code.sh └── run_paraphrase.sh ├── configs ├── data_configs │ ├── ace_config.json │ ├── bc5cdr_config.json │ ├── broadtwitter_config.json │ ├── casie_config.json │ ├── conll03_config.json │ ├── crossner_ai_config.json │ ├── crossner_ai_wo_misc_config.json │ ├── crossner_literature_config.json │ ├── crossner_literature_wo_misc_config.json │ ├── crossner_music_config.json │ ├── crossner_music_wo_misc_config.json │ ├── crossner_politics_config.json │ ├── crossner_politics_wo_misc_config.json │ ├── crossner_science_config.json │ ├── crossner_science_wo_misc_config.json │ ├── diann_config.json │ ├── e3c_config.json │ ├── europarl_config.json │ ├── fabner_config.json │ ├── harveyner_config.json │ ├── mitmovie_config.json │ ├── mitrestaurant_config.json │ ├── multinerd_config.json │ ├── ncbidisease_config.json 
│ ├── ontonotes_config.json │ ├── rams_config.json │ ├── tacred_config.json │ ├── wikievents_config.json │ └── wnut17_config.json ├── deepspeed_configs │ ├── deepspeed_zero2.json │ ├── deepspeed_zero2_offload.json │ ├── deepspeed_zero3.json │ └── deepspeed_zero3_offload.json ├── model_configs │ ├── Baseline-7B_CodeLLaMA.yaml │ ├── GoLLIE-13B_CodeLLaMA.yaml │ ├── GoLLIE-34B_CodeLLaMA.yaml │ ├── GoLLIE-7B_CodeLLaMA.yaml │ ├── GoLLIE-7B_CodeLLaMA_ablation_candiates.yaml │ ├── GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml │ ├── GoLLIE-7B_CodeLLaMA_ablation_masking.yaml │ ├── GoLLIE-7B_CodeLLaMA_train_full_model.yaml │ └── eval │ │ ├── Baseline-7B_CodeLLaMA.yaml │ │ ├── CoLLIE-7B_CodeLLaMA.yaml │ │ ├── GoLLIE-13B_CodeLLaMA.yaml │ │ ├── GoLLIE-34B_CodeLLaMA.yaml │ │ ├── GoLLIE-7B_CodeLLaMA_ablation_candidates.yaml │ │ ├── GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml │ │ └── GoLLIE-7B_CodeLLaMA_ablation_masking.yaml └── pharapharse_config │ ├── LlaMA2-Chat.yaml │ ├── Vicunav1.3-33B.yaml │ ├── generation_config.json │ └── gpt2.yaml ├── docs ├── _layouts │ └── default.html ├── assets │ ├── openai.svg │ └── user.svg └── index.md ├── notebooks ├── Create Custom Task.ipynb ├── Event Extraction.ipynb ├── Named Entity Recognition.ipynb └── Relation Extraction.ipynb ├── pyproject.toml ├── src ├── __init__.py ├── config.py ├── dataset │ ├── __init__.py │ └── dataset.py ├── evaluate.py ├── generate_data.py ├── model │ ├── __init__.py │ ├── load_model.py │ ├── model_utils.py │ └── patch_models │ │ ├── README.md │ │ ├── __init__.py │ │ ├── modeling_flash_llama.py │ │ ├── patching.py │ │ ├── patching_llama.py │ │ ├── patching_neox.py │ │ └── patching_utils.py ├── paraphrase │ ├── __init__.py │ ├── config.py │ ├── dataset.py │ ├── run_paraphrasing.py │ └── utils.py ├── run.py ├── scripts │ ├── __init__.py │ ├── compare_class_scores.py │ ├── get_examples.py │ ├── get_result_table.py │ ├── plot_f1_curves.py │ ├── plot_results.py │ ├── test_context_batch_size.py │ └── visualize_example.py ├── 
tasks │ ├── __init__.py │ ├── ace │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── preprocess_ace.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── bc5cdr │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── broadtwitter │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── casie │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── preprocess_casie.py │ │ ├── prompts_eae.py │ │ ├── prompts_ed.py │ │ └── scorer.py │ ├── conll03 │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── crossner │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts_ai.py │ │ ├── prompts_literature.py │ │ ├── prompts_music.py │ │ ├── prompts_natural_science.py │ │ ├── prompts_politics.py │ │ └── scorer.py │ ├── diann │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── e3c │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── fabner │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── harveyner │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── label_encoding.py │ ├── mitmovie │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── mitrestaurant │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── multiconer2 │ │ └── 
prompts.py │ ├── multinerd │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── ncbidisease │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── ner_tasks.csv │ ├── ontonotes │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── rams │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── tacred │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py │ ├── utils_data.py │ ├── utils_scorer.py │ ├── utils_typing.py │ ├── wikievents │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── preprocess_wikievents.py │ │ ├── prompts.py │ │ └── scorer.py │ └── wnut │ │ ├── __init__.py │ │ ├── data_loader.py │ │ ├── guidelines.py │ │ ├── guidelines_gold.py │ │ ├── prompts.py │ │ └── scorer.py ├── tests │ ├── __init__.py │ ├── test_dataloaders.py │ ├── test_dataset.py │ ├── test_evaluate.py │ ├── test_label_encoding.py │ ├── test_prompts.py │ ├── test_scorers.py │ └── test_trainer.py └── trainer.py └── templates ├── prompt.txt ├── prompt_ace_eae.txt ├── prompt_ace_rc.txt ├── prompt_ace_re.txt ├── prompt_eae.txt └── prompt_tacred.txt /.github/ISSUE_TEMPLATE/✍️-other-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "✍️ Other Issue" 3 | about: For issues not covered by the other templates. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Please provide a concise description of your issue. Provide as much information as possible, including, if required, your system hardware, software version... and any other info that could be useful to find a solution.
11 | 12 | Please, if you want to report a bug, or any error you found while running the code, use the "Bug report and Issues" Template. 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/🐛-bug-report-and-issues.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug report and Issues" 3 | about: Submit a bug report to help us improve GoLLIE 4 | title: "[BUG] Bug or issue report" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the task** 11 | 1. Model: Which GoLLIE model are you attempting to run? 12 | 2. Task: Which task are you attempting to run (training, evaluation, generate the dataset,...)? 13 | 14 | **Describe the bug** 15 | A clear and concise description of what the bug is. You can add the error traceback or screenshots here. 16 | 17 | **To Reproduce** 18 | Steps to reproduce the behavior: 19 | 1. Load model X 20 | ```Python 21 | model, tokenizer = load_model( 22 | inference=True, 23 | model_weights_name_or_path="HiTZ/GoLLIE-7B", 24 | quantization=None, 25 | use_lora=False, 26 | force_auto_device_map=True, 27 | use_flash_attention=True, 28 | torch_dtype="bfloat16" 29 | ) 30 | ``` 31 | 2. Run X function 32 | ```Python 33 | model_output = model.generate( 34 | **model_input.to(model.device), 35 | max_new_tokens=128, 36 | do_sample=False, 37 | min_new_tokens=0, 38 | num_beams=1, 39 | num_return_sequences=1, 40 | ) 41 | ``` 42 | 3. Any other step required to reproduce the behaviour 43 | 44 | **Expected behavior** 45 | A clear and concise description of what you expected to happen. 46 | 47 | **System Info** 48 | 1. GPU: (i.e Nvidia A100) 49 | 2. Pytorch version: 50 | 3. Transformers version: 51 | 4. Model configuration: Are you using 4 / 8 bits quantization? Are you using mGPU? etc.. 52 | 5. Any other relevant information: 53 | 54 | 55 | **Additional context** 56 | Add any other context about the problem here.
57 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/📑-new-task-dataset-proposals.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4D1 New Task/Dataset Proposals" 3 | about: Submit a proposal/request for supporting a new task in GoLLIE 4 | title: "[TASK] I want task X to be supported by GoLLIE" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the task you'd like to implement** 11 | A clear and concise description of the task/dataset. If you want to propose a new dataset, please explain if the task is already supported by GoLLIE (i.e Named Entity Recognition, Event Extraction, Relation Extraction...). If it is not, explain how you would implement it. 12 | 13 | **Data** 14 | Is the data for the task publicly available? If it is, provide a link to it. If it is not, explain how the data can be obtained. 15 | 16 | **Guidelines** 17 | Are guidelines available? If guidelines are available, provide a link to them. If they are not, describe how you would generate them. 18 | 19 | **Your contribution** 20 | Is there any way that you could help, e.g. by submitting a PR? 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/🚀-feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature request" 3 | about: Submit a proposal/request for a new GoLLIE feature 4 | title: "[FEATURE] I want feature X supported by GoLLIE" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Feature request** 11 | A clear and concise description of the feature proposal. 12 | 13 | **Motivation** 14 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
15 | 16 | 17 | **Your contribution** 18 | Is there any way that you could help, e.g. by submitting a PR? 19 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.10" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip pytest 29 | pip install black numpy torch transformers datasets tqdm scikit-learn rich psutil ruff bitsandbytes peft wandb 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with black 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | make style 35 | - name: Test with unittest 36 | run: | 37 | python -m unittest discover -v -s ./src/tests -p test_*.py 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .slurm/* 2 | .vscode/* 3 | .ignore/ 4 | data/* 5 | wandb/* 6 | assets/plots/* 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | 
parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | .idea/ 139 | 140 | pretrained_models/ 141 | collie/ 142 | configs/model_configs/legacy/debug.yaml 143 | 144 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := src 2 | 3 | style: 4 | black $(check_dirs) 5 | ruff check $(check_dirs) --fix 6 | -------------------------------------------------------------------------------- /assets/GoLLIE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/GoLLIE.png -------------------------------------------------------------------------------- /assets/datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/datasets.png -------------------------------------------------------------------------------- /assets/snippets/space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/snippets/space.png -------------------------------------------------------------------------------- 
/assets/snippets/space_guidelines.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/snippets/space_guidelines.png -------------------------------------------------------------------------------- /assets/snippets/space_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/snippets/space_result.png -------------------------------------------------------------------------------- /assets/snippets/space_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/snippets/space_text.png -------------------------------------------------------------------------------- /assets/snippets/space_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/snippets/space_transparent.png -------------------------------------------------------------------------------- /assets/snippets/space_two_terminals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/snippets/space_two_terminals.png -------------------------------------------------------------------------------- /assets/zero_shot_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/assets/zero_shot_results.png -------------------------------------------------------------------------------- 
/bash_scripts/Baseline-7B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=Baseline-7B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/Baseline-7B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/Baseline-7B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/Baseline-7B_CodeLLaMA.sh 27 | 28 | 29 | python3 -m src.run ${CONFIGS_FOLDER}/Baseline-7B_CodeLLaMA.yaml 30 | -------------------------------------------------------------------------------- /bash_scripts/GoLLIE-13B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-13B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-13B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-13B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs" 24 | 25 | 26 | # Call this script from root directory as: sbatch 
bash_scripts/GoLLIE-13B_CodeLLaMA.sh 27 | 28 | 29 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-13B_CodeLLaMA.yaml 30 | 31 | -------------------------------------------------------------------------------- /bash_scripts/GoLLIE-34B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-34B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:2 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-34B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-34B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | export PYTHONPATH="$PYTHONPATH:$PWD" 24 | CONFIGS_FOLDER="configs/model_configs" 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-34B_CodeLLaMA.sh 27 | 28 | torchrun --standalone --master_port 37229 --nproc_per_node=2 src/run.py ${CONFIGS_FOLDER}/GoLLIE-34B_CodeLLaMA.yaml 29 | 30 | -------------------------------------------------------------------------------- /bash_scripts/GoLLIE-7B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export 
TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA.sh 27 | 28 | 29 | # python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA.yaml 30 | 31 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA.yaml 32 | -------------------------------------------------------------------------------- /bash_scripts/GoLLIE-7B_CodeLLaMA_abl_dropout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA_abl_dropout 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA_abl_dropout.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA_abl_dropout.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA_abl_dropout.sh 27 | 28 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml -------------------------------------------------------------------------------- /bash_scripts/GoLLIE-7B_CodeLLaMA_abl_masking.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA_abl_masking 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | 
#SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA_abl_masking.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA_abl_masking.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA_abl_masking.sh 27 | 28 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_masking.yaml -------------------------------------------------------------------------------- /bash_scripts/GoLLIE-7B_CodeLLaMA_ablation_candidates.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA_ablation_candidates.sh 27 | 28 | 29 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_candidates.yaml 30 | -------------------------------------------------------------------------------- 
/bash_scripts/GoLLIE-7B_CodeLLaMA_train_full_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA_FULL_MODEL 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:4 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA_FULL_MODEL.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA_FULL_MODEL.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | export PYTHONPATH="$PYTHONPATH:$PWD" 24 | CONFIGS_FOLDER="configs/model_configs" 25 | 26 | 27 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA_train_full_model.sh 28 | 29 | 30 | torchrun --standalone --master_port 37223 --nproc_per_node=4 src/run.py ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_train_full_model.yaml -------------------------------------------------------------------------------- /bash_scripts/eval/Baseline-7B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=Baseline-7B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/Baseline-7B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/Baseline-7B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo 
CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/Baseline-7B_CodeLLaMA.sh 27 | 28 | 29 | python3 -m src.run ${CONFIGS_FOLDER}/Baseline-7B_CodeLLaMA.yaml 30 | -------------------------------------------------------------------------------- /bash_scripts/eval/CoLLIE-7B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA.sh 27 | 28 | 29 | # python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA.yaml 30 | 31 | python3 -m src.run ${CONFIGS_FOLDER}/eval/GoLLIE-7B_CodeLLaMA.yaml -------------------------------------------------------------------------------- /bash_scripts/eval/CoLLIE-7B_CodeLLaMA_ablation_candidates.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 
12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-7B_CodeLLaMA_ablation_candidates.sh 27 | 28 | 29 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_candidates.yaml 30 | -------------------------------------------------------------------------------- /bash_scripts/eval/GoLLIE-13B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-13B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-13B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-13B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/GoLLIE-13B_CodeLLaMA.sh 27 | 28 | 29 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-13B_CodeLLaMA.yaml 30 | 31 | -------------------------------------------------------------------------------- /bash_scripts/eval/GoLLIE-34B_CodeLLaMA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-34B_CodeLLaMA 3 | #SBATCH --cpus-per-task=16 
4 | #SBATCH --gres=gpu:2 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-34B_CodeLLaMA.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-34B_CodeLLaMA.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | export PYTHONPATH="$PYTHONPATH:$PWD" 24 | CONFIGS_FOLDER="configs/model_configs/eval" 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/eval/GoLLIE-34B_CodeLLaMA.sh 27 | 28 | torchrun --standalone --master_port 37228 --nproc_per_node=2 src/run.py ${CONFIGS_FOLDER}/GoLLIE-34B_CodeLLaMA.yaml 29 | 30 | -------------------------------------------------------------------------------- /bash_scripts/eval/GoLLIE-7B_CodeLLaMA_abl_dropout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA_abl_dropout 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA_abl_dropout.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA_abl_dropout.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/eval/GoLLIE-7B_CodeLLaMA_abl_dropout.sh 27 | 28 | python3 -m src.run
${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml -------------------------------------------------------------------------------- /bash_scripts/eval/GoLLIE-7B_CodeLLaMA_abl_masking.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE-7B_CodeLLaMA_abl_masking 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE-7B_CodeLLaMA_abl_masking.out.txt 7 | #SBATCH --error=.slurm/GoLLIE-7B_CodeLLaMA_abl_masking.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | 26 | # Call this script from root directory as: sbatch bash_scripts/eval/GoLLIE-7B_CodeLLaMA_abl_masking.sh 27 | 28 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_masking.yaml -------------------------------------------------------------------------------- /bash_scripts/eval/eval_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE_Eval_ALL 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE_Eval_ALL.out.txt 7 | #SBATCH --error=.slurm/GoLLIE_Eval_ALL.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo
CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | # Call this script from root directory as: sbatch bash_scripts/eval/eval_all.sh 26 | 27 | 28 | python3 -m src.run ${CONFIGS_FOLDER}/Baseline-7B_CodeLLaMA.yaml 29 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA.yaml 30 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_candidates.yaml 31 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml 32 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_masking.yaml 33 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-13B_CodeLLaMA.yaml 34 | python3 -m src.run ${CONFIGS_FOLDER}/GoLLIE-34B_CodeLLaMA.yaml 35 | 36 | 37 | -------------------------------------------------------------------------------- /bash_scripts/eval/run_all_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GoLLIE_Eval_ALL 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/GoLLIE_Eval_ALL.out.txt 7 | #SBATCH --error=.slurm/GoLLIE_Eval_ALL.err.txt 8 | 9 | 10 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 11 | 12 | 13 | export LC_ALL=en_US.UTF-8 14 | export LANG=en_US.UTF-8 15 | export LANGUAGE=en_US.UTF-8 16 | export TOKENIZERS_PARALLELISM=true 17 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 18 | export WANDB_ENTITY=hitz-GoLLIE 19 | export WANDB_PROJECT=GoLLIEv1.0 20 | 21 | echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}" 22 | 23 | CONFIGS_FOLDER="configs/model_configs/eval" 24 | 25 | # Call this script from root directory as: sbatch bash_scripts/eval/run_all_metrics.sh 26 | 27 | python3 -m src.evaluate ${CONFIGS_FOLDER}/Baseline-7B_CodeLLaMA.yaml 28 | python3 -m src.evaluate ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA.yaml 29 | python3 -m src.evaluate ${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml 30 | python3 -m src.evaluate
${CONFIGS_FOLDER}/GoLLIE-7B_CodeLLaMA_ablation_masking.yaml 31 | python3 -m src.evaluate ${CONFIGS_FOLDER}/GoLLIE-13B_CodeLLaMA.yaml 32 | python3 -m src.evaluate ${CONFIGS_FOLDER}/GoLLIE-34B_CodeLLaMA.yaml 33 | 34 | 35 | -------------------------------------------------------------------------------- /bash_scripts/generate_ablation_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /gscratch4/users/osainz006/GoLLIE/venv/GoLLIE/bin/activate 4 | 5 | CONFIG_DIR="configs/data_configs" 6 | 7 | # Ablation 1: regenerate every dataset with dropout removed (--remove_dropout). 8 | OUTPUT_DIR="data/processed_w_examples_abl_dropout" 9 | 10 | python -m src.generate_data \ 11 | --configs \ 12 | ${CONFIG_DIR}/ace_config.json \ 13 | ${CONFIG_DIR}/bc5cdr_config.json \ 14 | ${CONFIG_DIR}/broadtwitter_config.json \ 15 | ${CONFIG_DIR}/casie_config.json \ 16 | ${CONFIG_DIR}/conll03_config.json \ 17 | ${CONFIG_DIR}/crossner_ai_config.json \ 18 | ${CONFIG_DIR}/crossner_literature_config.json \ 19 | ${CONFIG_DIR}/crossner_music_config.json \ 20 | ${CONFIG_DIR}/crossner_politics_config.json \ 21 | ${CONFIG_DIR}/crossner_science_config.json \ 22 | ${CONFIG_DIR}/diann_config.json \ 23 | ${CONFIG_DIR}/e3c_config.json \ 24 | ${CONFIG_DIR}/europarl_config.json \ 25 | ${CONFIG_DIR}/fabner_config.json \ 26 | ${CONFIG_DIR}/harveyner_config.json \ 27 | ${CONFIG_DIR}/mitmovie_config.json \ 28 | ${CONFIG_DIR}/mitrestaurant_config.json \ 29 | ${CONFIG_DIR}/multinerd_config.json \ 30 | ${CONFIG_DIR}/ncbidisease_config.json \ 31 | ${CONFIG_DIR}/ontonotes_config.json \ 32 | ${CONFIG_DIR}/rams_config.json \ 33 | ${CONFIG_DIR}/tacred_config.json \ 34 | ${CONFIG_DIR}/wikievents_config.json \ 35 | ${CONFIG_DIR}/wnut17_config.json \ 36 | --output ${OUTPUT_DIR} \ 37 | --overwrite_output_dir \ 38 | --include_examples \ 39 | --remove_dropout 40 | 41 | # Ablation 2: regenerate every dataset with masking removed (--remove_masking). 42 | OUTPUT_DIR="data/processed_w_examples_abl_masking" 43 | 44 | python -m src.generate_data \ 45 | --configs \ 46 |
${CONFIG_DIR}/ace_config.json \ 47 | ${CONFIG_DIR}/bc5cdr_config.json \ 48 | ${CONFIG_DIR}/broadtwitter_config.json \ 49 | ${CONFIG_DIR}/casie_config.json \ 50 | ${CONFIG_DIR}/conll03_config.json \ 51 | ${CONFIG_DIR}/crossner_ai_config.json \ 52 | ${CONFIG_DIR}/crossner_literature_config.json \ 53 | ${CONFIG_DIR}/crossner_music_config.json \ 54 | ${CONFIG_DIR}/crossner_politics_config.json \ 55 | ${CONFIG_DIR}/crossner_science_config.json \ 56 | ${CONFIG_DIR}/diann_config.json \ 57 | ${CONFIG_DIR}/e3c_config.json \ 58 | ${CONFIG_DIR}/europarl_config.json \ 59 | ${CONFIG_DIR}/fabner_config.json \ 60 | ${CONFIG_DIR}/harveyner_config.json \ 61 | ${CONFIG_DIR}/mitmovie_config.json \ 62 | ${CONFIG_DIR}/mitrestaurant_config.json \ 63 | ${CONFIG_DIR}/multinerd_config.json \ 64 | ${CONFIG_DIR}/ncbidisease_config.json \ 65 | ${CONFIG_DIR}/ontonotes_config.json \ 66 | ${CONFIG_DIR}/rams_config.json \ 67 | ${CONFIG_DIR}/tacred_config.json \ 68 | ${CONFIG_DIR}/wikievents_config.json \ 69 | ${CONFIG_DIR}/wnut17_config.json \ 70 | --output ${OUTPUT_DIR} \ 71 | --overwrite_output_dir \ 72 | --include_examples \ 73 | --remove_masking -------------------------------------------------------------------------------- /bash_scripts/generate_data_modified.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /gscratch4/users/osainz006/GoLLIE/venv/GoLLIE/bin/activate 4 | 5 | CONFIG_DIR="configs/data_configs" 6 | 7 | OUTPUT_DIR="data/processed_w_examples" 8 | 9 | python -m src.generate_data \ 10 | --configs \ 11 | ${CONFIG_DIR}/crossner_ai_config.json \ 12 | ${CONFIG_DIR}/crossner_literature_config.json \ 13 | ${CONFIG_DIR}/crossner_music_config.json \ 14 | ${CONFIG_DIR}/crossner_politics_config.json \ 15 | ${CONFIG_DIR}/crossner_science_config.json \ 16 | ${CONFIG_DIR}/crossner_ai_wo_misc_config.json \ 17 | ${CONFIG_DIR}/crossner_literature_wo_misc_config.json \ 18 |
${CONFIG_DIR}/crossner_music_wo_misc_config.json \ 19 | ${CONFIG_DIR}/crossner_politics_wo_misc_config.json \ 20 | ${CONFIG_DIR}/crossner_science_wo_misc_config.json \ 21 | ${CONFIG_DIR}/mitmovie_config.json \ 22 | ${CONFIG_DIR}/mitrestaurant_config.json \ 23 | --output ${OUTPUT_DIR} \ 24 | --overwrite_output_dir \ 25 | --include_examples 26 | 27 | OUTPUT_DIR="data/processed" 28 | 29 | 30 | python -m src.generate_data \ 31 | --configs \ 32 | ${CONFIG_DIR}/crossner_ai_config.json \ 33 | ${CONFIG_DIR}/crossner_literature_config.json \ 34 | ${CONFIG_DIR}/crossner_music_config.json \ 35 | ${CONFIG_DIR}/crossner_politics_config.json \ 36 | ${CONFIG_DIR}/crossner_science_config.json \ 37 | ${CONFIG_DIR}/crossner_ai_wo_misc_config.json \ 38 | ${CONFIG_DIR}/crossner_literature_wo_misc_config.json \ 39 | ${CONFIG_DIR}/crossner_music_wo_misc_config.json \ 40 | ${CONFIG_DIR}/crossner_politics_wo_misc_config.json \ 41 | ${CONFIG_DIR}/crossner_science_wo_misc_config.json \ 42 | ${CONFIG_DIR}/mitmovie_config.json \ 43 | ${CONFIG_DIR}/mitrestaurant_config.json \ 44 | --output ${OUTPUT_DIR} \ 45 | --overwrite_output_dir 46 | 47 | # Generate baseline data 48 | OUTPUT_DIR="data/baseline" 49 | 50 | python -m src.generate_data \ 51 | --configs \ 52 | ${CONFIG_DIR}/crossner_ai_config.json \ 53 | ${CONFIG_DIR}/crossner_literature_config.json \ 54 | ${CONFIG_DIR}/crossner_music_config.json \ 55 | ${CONFIG_DIR}/crossner_politics_config.json \ 56 | ${CONFIG_DIR}/crossner_science_config.json \ 57 | ${CONFIG_DIR}/crossner_ai_wo_misc_config.json \ 58 | ${CONFIG_DIR}/crossner_literature_wo_misc_config.json \ 59 | ${CONFIG_DIR}/crossner_music_wo_misc_config.json \ 60 | ${CONFIG_DIR}/crossner_politics_wo_misc_config.json \ 61 | ${CONFIG_DIR}/crossner_science_wo_misc_config.json \ 62 | ${CONFIG_DIR}/mitmovie_config.json \ 63 | ${CONFIG_DIR}/mitrestaurant_config.json \ 64 | --output ${OUTPUT_DIR} \ 65 | --overwrite_output_dir \ 66 | --baseline 67 | 
-------------------------------------------------------------------------------- /bash_scripts/preprocess_ace.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ACE_DATA_PATH="/proiektuak/clever/data/ace_2005_td_v7/data" 4 | ACE_DATA_SPLITS="data/ace05/splits" 5 | LANGUAGE="english" # "chinese" 6 | OUTPUT_PATH="data/ace05" 7 | 8 | mkdir ${OUTPUT_PATH} 9 | 10 | python src/tasks/ace/preprocess_ace.py \ 11 | -i ${ACE_DATA_PATH} \ 12 | -o ${OUTPUT_PATH} \ 13 | -s ${ACE_DATA_SPLITS} \ 14 | -l ${LANGUAGE} \ 15 | --time_and_val 16 | -------------------------------------------------------------------------------- /bash_scripts/reformat_code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m black . -------------------------------------------------------------------------------- /bash_scripts/run_paraphrase.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=paraphrase 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=128G 6 | #SBATCH --output=.slurm/paraphrase.out.txt 7 | #SBATCH --error=.slurm/paraphrase.err.txt 8 | 9 | source /ikerlariak/osainz006/venvs/GoLLIE/bin/activate 10 | 11 | export LC_ALL=en_US.UTF-8 12 | export LANG=en_US.UTF-8 13 | export LANGUAGE=en_US.UTF-8 14 | export TOKENIZERS_PARALLELISM=true 15 | export TRANSFORMERS_NO_ADVISORY_WARNINGS=true 16 | export WANDB_ENTITY=hitz-GoLLIE 17 | export WANDB_PROJECT=GoLLIE 18 | 19 | 20 | CONFIGS_FOLDER="configs/pharapharse_config" 21 | 22 | 23 | 24 | python3 -m src.paraphrase.run_paraphrasing ${CONFIGS_FOLDER}/LlaMA2-Chat.yaml 25 | 26 | -------------------------------------------------------------------------------- /configs/data_configs/ace_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "ACE05", 3 | "tasks": [ 4 | "NER", 5 | 
"VER", 6 | "RE", 7 | "RC", 8 | "EE", 9 | "EAE" 10 | ], 11 | "dataloader_cls": "src.tasks.ace.data_loader.ACEDatasetLoader", 12 | "sampler_cls": "src.tasks.ace.data_loader.ACESampler", 13 | "train_file": "data/ace05/train.sentence.json", 14 | "dev_file": "data/ace05/dev.sentence.json", 15 | "test_file": "data/ace05/test.sentence.json", 16 | "prompt_template": "templates/prompt.txt", 17 | "seed": [0, 24, 42], 18 | "label_noise_prob": [0.15, 0.50, 0.75], 19 | "task_configuration": { 20 | "NER": { 21 | "group_by": "sentence", 22 | "parallel_instances": 1, 23 | "max_guidelines": -1, 24 | "guideline_dropout": 0.15, 25 | "scorer": "src.tasks.ace.scorer.ACEEntityScorer" 26 | }, 27 | "VER": { 28 | "group_by": "sentence", 29 | "parallel_instances": 1, 30 | "max_guidelines": -1, 31 | "guideline_dropout": 0.15, 32 | "scorer": "src.tasks.ace.scorer.ACEValueScorer" 33 | }, 34 | "RE": { 35 | "group_by": "sentence", 36 | "parallel_instances": 1, 37 | "max_guidelines": -1, 38 | "guideline_dropout": 0.15, 39 | "scorer": "src.tasks.ace.scorer.ACECoarseRelationScorer" 40 | }, 41 | "RC": { 42 | "group_by": "sentence", 43 | "parallel_instances": 1, 44 | "max_guidelines": -1, 45 | "guideline_dropout": 0.15, 46 | "scorer": "src.tasks.ace.scorer.ACERelationScorer", 47 | "ensure_positives_on_train": true 48 | }, 49 | "EE": { 50 | "group_by": "sentence", 51 | "parallel_instances": 1, 52 | "max_guidelines": -1, 53 | "guideline_dropout": 0.15, 54 | "scorer": "src.tasks.ace.scorer.ACEEventScorer" 55 | }, 56 | "EAE": { 57 | "group_by": "sentence", 58 | "parallel_instances": 1, 59 | "max_guidelines": -1, 60 | "sample_total_guidelines": 5, 61 | "guideline_dropout": 0.15, 62 | "scorer": "src.tasks.ace.scorer.ACEEventArgumentScorer", 63 | "ensure_positives_on_train": true 64 | } 65 | } 66 | } -------------------------------------------------------------------------------- /configs/data_configs/bc5cdr_config.json: -------------------------------------------------------------------------------- 1 | { 
2 | "dataset_name": "BC5CDR", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.bc5cdr.data_loader.Bc5cdrDatasetLoader", 5 | "sampler_cls": "src.tasks.bc5cdr.data_loader.Bc5cdrSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.bc5cdr.scorer.Bc5cdrEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/broadtwitter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "BroadTwitter", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.broadtwitter.data_loader.BroadTwitterDatasetLoader", 5 | "sampler_cls": "src.tasks.broadtwitter.data_loader.BroadTwitterSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.broadtwitter.scorer.BroadTwitterEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/casie_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CASIE", 3 | "tasks": [ 4 | "EE", 5 | "EAE" 6 | ], 7 | "dataloader_cls": "src.tasks.casie.data_loader.CASIEDatasetLoader", 8 | "sampler_cls": 
"src.tasks.casie.data_loader.CASIESampler", 9 | "dev_file": "data/casie/data.dev.jsonl", 10 | "test_file": "data/casie/data.test.jsonl", 11 | "prompt_template": "templates/prompt.txt", 12 | "seed": 0, 13 | "label_noise_prob": 0.0, 14 | "task_configuration": { 15 | "EE": { 16 | "parallel_instances": 1, 17 | "max_guidelines": -1, 18 | "guideline_dropout": 0.0, 19 | "scorer": "src.tasks.casie.scorer.CASIEEventScorer" 20 | }, 21 | "EAE": { 22 | "parallel_instances": 1, 23 | "max_guidelines": -1, 24 | "sample_total_guidelines": -1, 25 | "guideline_dropout": 0.0, 26 | "scorer": "src.tasks.casie.scorer.CASIEEventArgumentScorer", 27 | "ensure_positives_on_train": true, 28 | "sample_only_gold_guidelines": true 29 | } 30 | } 31 | } -------------------------------------------------------------------------------- /configs/data_configs/conll03_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CoNLL03", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.conll03.data_loader.CoNLLDatasetLoader", 5 | "sampler_cls": "src.tasks.conll03.data_loader.CoNLL03Sampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": true, 13 | "task_configuration": { 14 | "NER": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.15, 18 | "scorer": "src.tasks.conll03.scorer.CoNLL03EntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_ai_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER", 3 | "tasks": ["CrossNER_AI"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | 
"sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/ai/train.txt", 7 | "dev_file": "data/CrossNer/ai/dev.txt", 8 | "test_file": "data/CrossNer/ai/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": true, 13 | "task_configuration": { 14 | "CrossNER_AI": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERAIEntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_ai_wo_misc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER_woMISC", 3 | "tasks": ["CrossNER_AI"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/ai/train.txt", 7 | "dev_file": "data/CrossNer/ai/dev.txt", 8 | "test_file": "data/CrossNer/ai/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": false, 13 | "task_configuration": { 14 | "CrossNER_AI": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERAIEntityScorer_woMISC", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_literature_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER", 3 | "tasks": ["CrossNER_LITERATURE"], 4 | "dataloader_cls": 
"src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/literature/train.txt", 7 | "dev_file": "data/CrossNer/literature/dev.txt", 8 | "test_file": "data/CrossNer/literature/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": true, 13 | "task_configuration": { 14 | "CrossNER_LITERATURE": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERLiteratureEntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_literature_wo_misc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER_woMISC", 3 | "tasks": ["CrossNER_LITERATURE"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/literature/train.txt", 7 | "dev_file": "data/CrossNer/literature/dev.txt", 8 | "test_file": "data/CrossNer/literature/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": false, 13 | "task_configuration": { 14 | "CrossNER_LITERATURE": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERLiteratureEntityScorer_woMISC", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_music_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER", 3 | "tasks": ["CrossNER_MUSIC"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/music/train.txt", 7 | "dev_file": "data/CrossNer/music/dev.txt", 8 | "test_file": "data/CrossNer/music/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": true, 13 | "task_configuration": { 14 | "CrossNER_MUSIC": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERMusicEntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_music_wo_misc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER_woMISC", 3 | "tasks": ["CrossNER_MUSIC"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/music/train.txt", 7 | "dev_file": "data/CrossNer/music/dev.txt", 8 | "test_file": "data/CrossNer/music/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": false, 13 | "task_configuration": { 14 | "CrossNER_MUSIC": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERMusicEntityScorer_woMISC", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | 
-------------------------------------------------------------------------------- /configs/data_configs/crossner_politics_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER", 3 | "tasks": ["CrossNER_POLITICS"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/politics/train.txt", 7 | "dev_file": "data/CrossNer/politics/dev.txt", 8 | "test_file": "data/CrossNer/politics/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": true, 13 | "task_configuration": { 14 | "CrossNER_POLITICS": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERPoliticsEntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_politics_wo_misc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER_woMISC", 3 | "tasks": ["CrossNER_POLITICS"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/politics/train.txt", 7 | "dev_file": "data/CrossNer/politics/dev.txt", 8 | "test_file": "data/CrossNer/politics/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": false, 13 | "task_configuration": { 14 | "CrossNER_POLITICS": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": 
"src.tasks.crossner.scorer.CrossNERPoliticsEntityScorer_woMISC", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_science_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER", 3 | "tasks": ["CrossNER_NATURAL_SCIENCE"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/science/train.txt", 7 | "dev_file": "data/CrossNer/science/dev.txt", 8 | "test_file": "data/CrossNer/science/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": true, 13 | "task_configuration": { 14 | "CrossNER_NATURAL_SCIENCE": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERNaturalScienceEntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/crossner_science_wo_misc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "CrossNER_woMISC", 3 | "tasks": ["CrossNER_NATURAL_SCIENCE"], 4 | "dataloader_cls": "src.tasks.crossner.data_loader.CrossNERDatasetLoader", 5 | "sampler_cls": "src.tasks.crossner.data_loader.CrossNERSampler", 6 | "train_file": "data/CrossNer/science/train.txt", 7 | "dev_file": "data/CrossNer/science/dev.txt", 8 | "test_file": "data/CrossNer/science/test.txt", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "include_misc": false, 13 | "task_configuration": { 14 | 
"CrossNER_NATURAL_SCIENCE": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.0, 18 | "scorer": "src.tasks.crossner.scorer.CrossNERNaturalScienceEntityScorer_woMISC", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/diann_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "DIANN", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.diann.data_loader.DiannDatasetLoader", 5 | "sampler_cls": "src.tasks.diann.data_loader.DiannSampler", 6 | "train_file": "data/diann/en-diann-train.tsv", 7 | "dev_file": "data/diann/en-diann-dev.tsv", 8 | "test_file": "data/diann/en-diann-test.tsv", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.0, 17 | "scorer": "src.tasks.diann.scorer.DiannDiseaseEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/e3c_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "E3C", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.e3c.data_loader.E3CDatasetLoader", 5 | "sampler_cls": "src.tasks.e3c.data_loader.E3CSampler", 6 | "train_file": "data/e3c/en-e3c-train.tsv", 7 | "dev_file": "data/e3c/en-e3c-dev.tsv", 8 | "test_file": "data/e3c/en-e3c-test.tsv", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.0, 17 | 
"scorer": "src.tasks.e3c.scorer.E3CEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/europarl_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "Europarl", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.conll03.data_loader.CoNLLDatasetLoader", 5 | "sampler_cls": "src.tasks.conll03.data_loader.CoNLL03Sampler", 6 | "test_file": "data/europarl/en.europarl.test.tsv", 7 | "prompt_template": "templates/prompt.txt", 8 | "seed": 0, 9 | "include_misc": true, 10 | "task_configuration": { 11 | "NER": { 12 | "parallel_instances": 1, 13 | "max_guidelines": -1, 14 | "guideline_dropout": 0.15, 15 | "scorer": "src.tasks.conll03.scorer.CoNLL03EntityScorer" 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /configs/data_configs/fabner_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "FabNER", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.fabner.data_loader.FabNerDatasetLoader", 5 | "sampler_cls": "src.tasks.fabner.data_loader.FabNerSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.fabner.scorer.FabNerEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/harveyner_config.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"dataset_name": "HarveyNER", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.harveyner.data_loader.HarveyNerDatasetLoader", 5 | "sampler_cls": "src.tasks.harveyner.data_loader.HarveyNerSampler", 6 | "train_file": "data/HarveyNER/tweets.train.bio", 7 | "dev_file": "data/HarveyNER/tweets.dev.bio", 8 | "test_file": "data/HarveyNER/tweets.test.bio", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.harveyner.scorer.HarveyNEREntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/mitmovie_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "MITMovie", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.mitmovie.data_loader.MitMovieDatasetLoader", 5 | "sampler_cls": "src.tasks.mitmovie.data_loader.MitMovieSampler", 6 | "train_file": "data/MITmovie/engtrain.bio", 7 | "dev_file": "data/MITmovie/engdev.bio", 8 | "test_file": "data/MITmovie/engtest.bio", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.mitmovie.scorer.MitMovieEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/mitrestaurant_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "MITRestaurant", 3 | "tasks": ["NER"], 4 | "dataloader_cls": 
"src.tasks.mitrestaurant.data_loader.MitRestaurantDatasetLoader", 5 | "sampler_cls": "src.tasks.mitrestaurant.data_loader.MitRestaurantSampler", 6 | "train_file": "data/MITrestaurant/restauranttrain.bio", 7 | "dev_file": "data/MITrestaurant/restaurantdev.bio", 8 | "test_file": "data/MITrestaurant/restauranttest.bio", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.mitrestaurant.scorer.MitRestaurantEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/multinerd_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "MultiNERD", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.multinerd.data_loader.MultinerdDatasetLoader", 5 | "sampler_cls": "src.tasks.multinerd.data_loader.MultinerdSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "language": "en", 10 | "prompt_template": "templates/prompt.txt", 11 | "seed": [0, 24, 42], 12 | "label_noise_prob": [0.15, 0.50, 0.75], 13 | "task_configuration": { 14 | "NER": { 15 | "parallel_instances": 1, 16 | "max_guidelines": -1, 17 | "guideline_dropout": 0.15, 18 | "scorer": "src.tasks.multinerd.scorer.MultinerdEntityScorer", 19 | "paraphrase_train": true, 20 | "label_noise": 0.5 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /configs/data_configs/ncbidisease_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "NcbiDisease", 3 | "tasks": ["NER"], 4 | "dataloader_cls": 
"src.tasks.ncbidisease.data_loader.NcbiDiseaseDatasetLoader", 5 | "sampler_cls": "src.tasks.ncbidisease.data_loader.NcbiDiseaseSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.0, 17 | "scorer": "src.tasks.ncbidisease.scorer.NcbiDiseaseEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/ontonotes_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "OntoNotes5", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.ontonotes.data_loader.OntoNotesDatasetLoader", 5 | "sampler_cls": "src.tasks.ontonotes.data_loader.OntoNotesSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.ontonotes.scorer.OntoNotesEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/data_configs/rams_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "RAMS", 3 | "tasks": [ 4 | "EAE" 5 | ], 6 | "dataloader_cls": "src.tasks.rams.data_loader.RAMSDatasetLoader", 7 | "sampler_cls": "src.tasks.rams.data_loader.RAMSSampler", 8 | "train_file": "data/rams/train.jsonlines", 9 | "dev_file": 
"data/rams/dev.jsonlines", 10 | "test_file": "data/rams/test.jsonlines", 11 | "prompt_template": "templates/prompt_eae.txt", 12 | "seed": [0, 24, 42], 13 | "label_noise_prob": [0.15, 0.50, 0.75], 14 | "task_configuration": { 15 | "EAE": { 16 | "parallel_instances": 1, 17 | "max_guidelines": 3, 18 | "guideline_dropout": 0, 19 | "scorer": "src.tasks.rams.scorer.RAMSEventScorer", 20 | "sample_only_gold_guidelines": true 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /configs/data_configs/tacred_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "TACRED", 3 | "tasks": [ 4 | "SF" 5 | ], 6 | "dataloader_cls": "src.tasks.tacred.data_loader.TACREDDatasetLoader", 7 | "sampler_cls": "src.tasks.tacred.data_loader.TACREDSampler", 8 | "train_file": "data/tacred/train.json", 9 | "dev_file": "data/tacred/dev.json", 10 | "test_file": "data/tacred/test.json", 11 | "prompt_template": "templates/prompt_tacred.txt", 12 | "seed": [0, 24, 42], 13 | "label_noise_prob": [0.15, 0.50, 0.75], 14 | "task_configuration": { 15 | "SF": { 16 | "parallel_instances": 1, 17 | "max_guidelines": 1, 18 | "guideline_dropout": 0.0, 19 | "scorer": "src.tasks.tacred.scorer.TACREDTemplateScorer", 20 | "sample_only_gold_guidelines": true 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /configs/data_configs/wikievents_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "WikiEvents", 3 | "tasks": [ 4 | "NER", 5 | "EE", 6 | "EAE" 7 | ], 8 | "dataloader_cls": "src.tasks.wikievents.data_loader.WikiEventsDatasetLoader", 9 | "sampler_cls": "src.tasks.wikievents.data_loader.WikiEventsSampler", 10 | "train_file": "data/wikievents/train.sentence.jsonl", 11 | "dev_file": "data/wikievents/dev.sentence.jsonl", 12 | "test_file": "data/wikievents/test.sentence.jsonl", 13 | 
"prompt_template": "templates/prompt.txt", 14 | "seed": [0, 24, 42], 15 | "label_noise_prob": [0.15, 0.50, 0.75], 16 | "task_configuration": { 17 | "NER": { 18 | "group_by": "sentence", 19 | "parallel_instances": 1, 20 | "max_guidelines": -1, 21 | "guideline_dropout": 0.15, 22 | "scorer": "src.tasks.wikievents.scorer.WikiEventsEntityScorer" 23 | }, 24 | "EE": { 25 | "group_by": "sentence", 26 | "parallel_instances": 1, 27 | "max_guidelines": -1, 28 | "guideline_dropout": 0.15, 29 | "scorer": "src.tasks.wikievents.scorer.WikiEventsEventScorer" 30 | }, 31 | "EAE": { 32 | "group_by": "sentence", 33 | "parallel_instances": 1, 34 | "max_guidelines": -1, 35 | "sample_total_guidelines": 5, 36 | "guideline_dropout": 0.15, 37 | "scorer": "src.tasks.wikievents.scorer.WikiEventsEventArgumentScorer", 38 | "ensure_positives_on_train": true 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /configs/data_configs/wnut17_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "WNUT17", 3 | "tasks": ["NER"], 4 | "dataloader_cls": "src.tasks.wnut.data_loader.WnutDatasetLoader", 5 | "sampler_cls": "src.tasks.wnut.data_loader.WnutSampler", 6 | "train_file": "train", 7 | "dev_file": "validation", 8 | "test_file": "test", 9 | "prompt_template": "templates/prompt.txt", 10 | "seed": [0, 24, 42], 11 | "label_noise_prob": [0.15, 0.50, 0.75], 12 | "task_configuration": { 13 | "NER": { 14 | "parallel_instances": 1, 15 | "max_guidelines": -1, 16 | "guideline_dropout": 0.15, 17 | "scorer": "src.tasks.wnut.scorer.WnutEntityScorer", 18 | "paraphrase_train": true, 19 | "label_noise": 0.5 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/deepspeed_configs/deepspeed_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 
| "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto", 28 | "warmup_type": "linear" 29 | } 30 | }, 31 | "zero_optimization": { 32 | "stage": 2, 33 | "allgather_partitions": true, 34 | "allgather_bucket_size": 5e8, 35 | "overlap_comm": true, 36 | "reduce_scatter": true, 37 | "reduce_bucket_size": 5e8, 38 | "contiguous_gradients": true 39 | }, 40 | "gradient_accumulation_steps": "auto", 41 | "gradient_clipping": "auto", 42 | "steps_per_print": 2000, 43 | "train_batch_size": "auto", 44 | "train_micro_batch_size_per_gpu": "auto", 45 | "wall_clock_breakdown": false 46 | } 47 | -------------------------------------------------------------------------------- /configs/deepspeed_configs/deepspeed_zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto", 28 | "warmup_type": "linear" 29 | } 30 | }, 31 | "zero_optimization": { 32 | "stage": 2, 33 | "offload_optimizer": { 34 | "device": "cpu", 35 | "pin_memory": true 36 | }, 37 | "allgather_partitions": true, 38 | "allgather_bucket_size": 
5e8, 39 | "overlap_comm": true, 40 | "reduce_scatter": true, 41 | "reduce_bucket_size": 5e8, 42 | "contiguous_gradients": true 43 | }, 44 | "gradient_accumulation_steps": "auto", 45 | "gradient_clipping": "auto", 46 | "steps_per_print": 2000, 47 | "train_batch_size": "auto", 48 | "train_micro_batch_size_per_gpu": "auto", 49 | "wall_clock_breakdown": false 50 | } 51 | -------------------------------------------------------------------------------- /configs/deepspeed_configs/deepspeed_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto", 28 | "warmup_type": "cosine", 29 | "total_num_steps": "auto" 30 | } 31 | }, 32 | "zero_optimization": { 33 | "stage": 3, 34 | "overlap_comm": true, 35 | "contiguous_gradients": true, 36 | "sub_group_size": 1e9, 37 | "reduce_bucket_size": "auto", 38 | "stage3_prefetch_bucket_size": "auto", 39 | "stage3_param_persistence_threshold": "auto", 40 | "stage3_max_live_parameters": 1e9, 41 | "stage3_max_reuse_distance": 1e9, 42 | "stage3_gather_16bit_weights_on_model_save": true 43 | }, 44 | "gradient_accumulation_steps": "auto", 45 | "gradient_clipping": "auto", 46 | "steps_per_print": 2000, 47 | "train_batch_size": "auto", 48 | "train_micro_batch_size_per_gpu": "auto", 49 | "wall_clock_breakdown": false 50 | } -------------------------------------------------------------------------------- /configs/deepspeed_configs/deepspeed_zero3_offload.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto", 28 | "warmup_type": "cosine", 29 | "total_num_steps": "auto" 30 | } 31 | }, 32 | "zero_optimization": { 33 | "stage": 3, 34 | "offload_optimizer": { 35 | "device": "cpu", 36 | "pin_memory": true 37 | }, 38 | "offload_param": { 39 | "device": "cpu", 40 | "pin_memory": true 41 | }, 42 | "overlap_comm": true, 43 | "contiguous_gradients": true, 44 | "sub_group_size": 1e9, 45 | "reduce_bucket_size": "auto", 46 | "stage3_prefetch_bucket_size": "auto", 47 | "stage3_param_persistence_threshold": "auto", 48 | "stage3_max_live_parameters": 1e9, 49 | "stage3_max_reuse_distance": 1e9, 50 | "stage3_gather_16bit_weights_on_model_save": true 51 | }, 52 | "gradient_accumulation_steps": "auto", 53 | "gradient_clipping": "auto", 54 | "steps_per_print": 2000, 55 | "train_batch_size": "auto", 56 | "train_micro_batch_size_per_gpu": "auto", 57 | "wall_clock_breakdown": false 58 | } 59 | -------------------------------------------------------------------------------- /configs/model_configs/Baseline-7B_CodeLLaMA.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: null 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # 
dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/baseline 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re 20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/Baseline-7b_CodeLLaMA 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | do_train: true 82 | do_eval: true 83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 32 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 1 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 
| optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "Baseline-7b_CodeLLaMA" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | -------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-13B_CodeLLaMA.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-13b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: 4 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re 20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - 
broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-13b_CodeLLaMA 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | do_train: true 82 | do_eval: true 83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 16 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 2 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 | optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "GoLLIE+-13b_CodeLLaMA" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | 
-------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-34B_CodeLLaMA.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-34b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: 4 7 | gradient_checkpointing: true 8 | force_auto_device_map: true 9 | max_memory_MB: 80000 10 | use_flash_attention: true 11 | 12 | # dataset arguments 13 | dataset_dir: 14 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples 15 | train_tasks: 16 | - ace05.eae 17 | - ace05.ee 18 | - ace05.ner 19 | - ace05.rc 20 | - ace05.re 21 | - ace05.ver 22 | - bc5cdr.ner 23 | - conll03.ner 24 | - diann.ner 25 | - ncbidisease.ner 26 | - ontonotes5.ner 27 | - rams.eae 28 | - tacred.sf 29 | - wnut17.ner 30 | validation_tasks: 31 | - ace05.ee 32 | - conll03.ner 33 | test_tasks: 34 | - ace05.eae 35 | - ace05.ee 36 | - ace05.ner 37 | - ace05.rc 38 | - ace05.re 39 | - ace05.ver 40 | - bc5cdr.ner 41 | - conll03.ner 42 | - diann.ner 43 | - ncbidisease.ner 44 | - ontonotes5.ner 45 | - rams.eae 46 | - tacred.sf 47 | - wikievents.eae 48 | - wikievents.ee 49 | - wikievents.ner 50 | - wnut17.ner 51 | - e3c.ner 52 | - broadtwitter.ner 53 | - fabner.ner 54 | - harveyner.ner 55 | - multinerd.ner 56 | - casie.eae 57 | - casie.ee 58 | - mitmovie.ner 59 | - mitrestaurant.ner 60 | - crossner.crossner_ai 61 | - crossner.crossner_music 62 | - crossner.crossner_politics 63 | - crossner.crossner_literature 64 | - crossner.crossner_natural_science 65 | max_examples_per_task_train: 30000 66 | max_examples_per_task_val: 5000 67 | max_examples_per_task_test: null 68 | max_seq_length: 2048 69 | generation_max_length: 2048 70 | ignore_pad_token_for_loss: true 71 | prompt_loss_weight: 0.0 72 | 73 | # checkpoint settings 74 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-34b_CodeLLaMA 75 | overwrite_output_dir: true 76 | 
load_best_model_at_end: false 77 | save_strategy: "epoch" 78 | save_steps: 1000 79 | save_total_limit: 999 80 | 81 | # evaluation 82 | do_train: true 83 | do_eval: true 84 | do_predict: true 85 | evaluation_strategy: "epoch" 86 | eval_steps: 500 87 | eval_delay: 0 88 | predict_with_generate: true 89 | evaluate_all_checkpoints: false 90 | 91 | # batch size 92 | per_device_train_batch_size: 8 93 | per_device_eval_batch_size: 4 94 | gradient_accumulation_steps: 2 # 2 FOR 2 GPUs, 4 FOR 1 GPU (32 effective batch size) 95 | generation_num_beams: 1 96 | 97 | # optimizer settings 98 | optim: adamw_torch_fused 99 | learning_rate: 0.0003 100 | weight_decay: 0.0 101 | num_train_epochs: 3 102 | lr_scheduler_type: cosine 103 | warmup_ratio: 0.03 104 | adam_epsilon: 1e-7 105 | 106 | # lora settings 107 | lora_r: 8 108 | lora_alpha: 16 109 | lora_dropout: 0.05 110 | lora_target_modules: 111 | - all 112 | 113 | # reporting 114 | logging_strategy: steps 115 | logging_first_step: true 116 | logging_steps: 25 117 | report_to: wandb 118 | run_name: "GoLLIE+-34b_CodeLLaMA" 119 | disable_tqdm: false 120 | 121 | # hub settings 122 | push_to_hub: false 123 | resume_from_checkpoint: false 124 | 125 | # performance 126 | bf16: true 127 | fp16: false 128 | torch_compile: false 129 | ddp_find_unused_parameters: false 130 | -------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-7B_CodeLLaMA.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: null 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re
20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | do_train: true 82 | do_eval: true 83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 32 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 1 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 | optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 
104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "GoLLIE+-7b_CodeLLaMA" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | -------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-7B_CodeLLaMA_ablation_candiates.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: null 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/processed 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re 20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 
60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE-7b_CodeLLaMA 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | do_train: true 82 | do_eval: true 83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 32 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 1 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 | optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "GoLLIE-7b_CodeLLaMA" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | -------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml: 
-------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: null 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples_abl_dropout 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re 20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_abl_dropout 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | 
do_train: true 82 | do_eval: true 83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 32 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 1 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 | optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "GoLLIE+-7b_CodeLLaMA_abl_dropout" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | -------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-7B_CodeLLaMA_ablation_masking.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | torch_dtype: bfloat16 4 | use_lora: true 5 | quantization: 4 6 | quantization_inference: null 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples_abl_masking 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re 20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 
28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_abl_masking 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | do_train: true 82 | do_eval: true 83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 32 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 1 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 | optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 
112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "GoLLIE+-7b_CodeLLaMA_abl_masking" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | -------------------------------------------------------------------------------- /configs/model_configs/GoLLIE-7B_CodeLLaMA_train_full_model.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | torch_dtype: bfloat16 4 | use_lora: false 5 | quantization: null 6 | quantization_inference: null 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | deepspeed: configs/deepspeed_configs/deepspeed_zero3.json 11 | 12 | # dataset arguments 13 | dataset_dir: 14 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples 15 | train_tasks: 16 | - ace05.eae 17 | - ace05.ee 18 | - ace05.ner 19 | - ace05.rc 20 | - ace05.re 21 | - ace05.ver 22 | - bc5cdr.ner 23 | - conll03.ner 24 | - diann.ner 25 | - ncbidisease.ner 26 | - ontonotes5.ner 27 | - rams.eae 28 | - tacred.sf 29 | - wnut17.ner 30 | validation_tasks: 31 | - ace05.ee 32 | - conll03.ner 33 | test_tasks: 34 | - ace05.eae 35 | - ace05.ee 36 | - ace05.ner 37 | - ace05.rc 38 | - ace05.re 39 | - ace05.ver 40 | - bc5cdr.ner 41 | - conll03.ner 42 | - diann.ner 43 | - ncbidisease.ner 44 | - ontonotes5.ner 45 | - rams.eae 46 | - tacred.sf 47 | - wikievents.eae 48 | - wikievents.ee 49 | - wikievents.ner 50 | - wnut17.ner 51 | - e3c.ner 52 | - broadtwitter.ner 53 | - fabner.ner 54 | - harveyner.ner 55 | - multinerd.ner 56 | - casie.eae 57 | - casie.ee 58 | - mitmovie.ner 59 | - mitrestaurant.ner 60 | - crossner.crossner_ai 61 | - crossner.crossner_music 62 | - 
crossner.crossner_politics 63 | - crossner.crossner_literature 64 | - crossner.crossner_natural_science 65 | max_examples_per_task_train: 30000 66 | max_examples_per_task_val: 5000 67 | max_examples_per_task_test: null 68 | max_seq_length: 2048 69 | generation_max_length: 2048 70 | ignore_pad_token_for_loss: true 71 | prompt_loss_weight: 0.0 72 | 73 | # checkpoint settings 74 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_FULL_MODEL 75 | overwrite_output_dir: true 76 | load_best_model_at_end: false 77 | save_strategy: "epoch" 78 | save_steps: 1000 79 | save_total_limit: 999 80 | 81 | # evaluation 82 | do_train: true 83 | do_eval: true 84 | do_predict: true 85 | evaluation_strategy: "epoch" 86 | eval_steps: 500 87 | eval_delay: 0 88 | predict_with_generate: true 89 | evaluate_all_checkpoints: false 90 | 91 | # batch size: 16 batch size * 8 gradaccum * 1 GPUs = 128 92 | per_device_train_batch_size: 32 93 | per_device_eval_batch_size: 8 94 | gradient_accumulation_steps: 1 95 | generation_num_beams: 1 96 | 97 | # optimizer settings 98 | optim: adamw_torch 99 | learning_rate: 0.0001 100 | weight_decay: 0.0 101 | num_train_epochs: 3 102 | lr_scheduler_type: cosine 103 | warmup_ratio: 0.03 104 | adam_epsilon: 1e-7 105 | 106 | # lora settings 107 | lora_r: 8 108 | lora_alpha: 16 109 | lora_dropout: 0.05 110 | lora_target_modules: 111 | - all 112 | 113 | # reporting 114 | logging_strategy: steps 115 | logging_first_step: true 116 | logging_steps: 25 117 | report_to: wandb 118 | run_name: "GoLLIE+-7b_CodeLLaMA_FULL_MODEL" 119 | disable_tqdm: false 120 | 121 | # hub settings 122 | push_to_hub: false 123 | resume_from_checkpoint: false 124 | 125 | # performance 126 | bf16: true 127 | fp16: false 128 | torch_compile: false 129 | ddp_find_unused_parameters: false 130 | -------------------------------------------------------------------------------- /configs/model_configs/eval/GoLLIE-13B_CodeLLaMA.yaml: 
-------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: HiTZ/GoLLIE-13B 3 | torch_dtype: bfloat16 4 | use_lora: false 5 | quantization: 4 6 | quantization_inference: 4 7 | gradient_checkpointing: true 8 | force_auto_device_map: false 9 | use_flash_attention: true 10 | 11 | # dataset arguments 12 | dataset_dir: 13 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples 14 | train_tasks: 15 | - ace05.eae 16 | - ace05.ee 17 | - ace05.ner 18 | - ace05.rc 19 | - ace05.re 20 | - ace05.ver 21 | - bc5cdr.ner 22 | - conll03.ner 23 | - diann.ner 24 | - ncbidisease.ner 25 | - ontonotes5.ner 26 | - rams.eae 27 | - tacred.sf 28 | - wnut17.ner 29 | validation_tasks: 30 | - ace05.ee 31 | - conll03.ner 32 | test_tasks: 33 | - ace05.eae 34 | - ace05.ee 35 | - ace05.ner 36 | - ace05.rc 37 | - ace05.re 38 | - ace05.ver 39 | - bc5cdr.ner 40 | - conll03.ner 41 | - diann.ner 42 | - ncbidisease.ner 43 | - ontonotes5.ner 44 | - rams.eae 45 | - tacred.sf 46 | - wikievents.eae 47 | - wikievents.ee 48 | - wikievents.ner 49 | - wnut17.ner 50 | - e3c.ner 51 | - broadtwitter.ner 52 | - fabner.ner 53 | - harveyner.ner 54 | - multinerd.ner 55 | - casie.eae 56 | - casie.ee 57 | - mitmovie.ner 58 | - mitrestaurant.ner 59 | - crossner.crossner_ai 60 | - crossner.crossner_music 61 | - crossner.crossner_politics 62 | - crossner.crossner_literature 63 | - crossner.crossner_natural_science 64 | max_examples_per_task_train: 30000 65 | max_examples_per_task_val: 5000 66 | max_examples_per_task_test: null 67 | max_seq_length: 2048 68 | generation_max_length: 2048 69 | ignore_pad_token_for_loss: true 70 | prompt_loss_weight: 0.0 71 | 72 | # checkpoint settings 73 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-13b_CodeLLaMA 74 | overwrite_output_dir: true 75 | load_best_model_at_end: false 76 | save_strategy: "epoch" 77 | save_steps: 1000 78 | save_total_limit: 999 79 | 80 | # evaluation 81 | do_train: false 82 | do_eval: false 
83 | do_predict: true 84 | evaluation_strategy: "epoch" 85 | eval_steps: 500 86 | eval_delay: 0 87 | predict_with_generate: true 88 | evaluate_all_checkpoints: false 89 | 90 | # batch size 91 | per_device_train_batch_size: 16 92 | per_device_eval_batch_size: 8 93 | gradient_accumulation_steps: 2 94 | generation_num_beams: 1 95 | 96 | # optimizer settings 97 | optim: adamw_torch_fused 98 | learning_rate: 0.0003 99 | weight_decay: 0.0 100 | num_train_epochs: 3 101 | lr_scheduler_type: cosine 102 | warmup_ratio: 0.03 103 | adam_epsilon: 1e-7 104 | 105 | # lora settings 106 | lora_r: 8 107 | lora_alpha: 16 108 | lora_dropout: 0.05 109 | lora_target_modules: 110 | - all 111 | 112 | # reporting 113 | logging_strategy: steps 114 | logging_first_step: true 115 | logging_steps: 25 116 | report_to: wandb 117 | run_name: "GoLLIE+-13b_CodeLLaMA" 118 | disable_tqdm: false 119 | 120 | # hub settings 121 | push_to_hub: false 122 | resume_from_checkpoint: false 123 | 124 | # performance 125 | bf16: true 126 | fp16: false 127 | torch_compile: false 128 | ddp_find_unused_parameters: false 129 | -------------------------------------------------------------------------------- /configs/model_configs/eval/GoLLIE-34B_CodeLLaMA.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: HiTZ/GoLLIE-34B 3 | torch_dtype: bfloat16 4 | use_lora: false 5 | quantization: 4 6 | quantization_inference: 4 7 | gradient_checkpointing: true 8 | force_auto_device_map: true 9 | max_memory_MB: 80000 10 | use_flash_attention: true 11 | 12 | # dataset arguments 13 | dataset_dir: 14 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples 15 | train_tasks: 16 | - ace05.eae 17 | - ace05.ee 18 | - ace05.ner 19 | - ace05.rc 20 | - ace05.re 21 | - ace05.ver 22 | - bc5cdr.ner 23 | - conll03.ner 24 | - diann.ner 25 | - ncbidisease.ner 26 | - ontonotes5.ner 27 | - rams.eae 28 | - tacred.sf 29 | - wnut17.ner 30 | validation_tasks: 31 | - 
ace05.ee 32 | - conll03.ner 33 | test_tasks: 34 | - ace05.eae 35 | - ace05.ee 36 | - ace05.ner 37 | - ace05.rc 38 | - ace05.re 39 | - ace05.ver 40 | - bc5cdr.ner 41 | - conll03.ner 42 | - diann.ner 43 | - ncbidisease.ner 44 | - ontonotes5.ner 45 | - rams.eae 46 | - tacred.sf 47 | - wikievents.eae 48 | - wikievents.ee 49 | - wikievents.ner 50 | - wnut17.ner 51 | - e3c.ner 52 | - broadtwitter.ner 53 | - fabner.ner 54 | - harveyner.ner 55 | - multinerd.ner 56 | - casie.eae 57 | - casie.ee 58 | - mitmovie.ner 59 | - mitrestaurant.ner 60 | - crossner.crossner_ai 61 | - crossner.crossner_music 62 | - crossner.crossner_politics 63 | - crossner.crossner_literature 64 | - crossner.crossner_natural_science 65 | max_examples_per_task_train: 30000 66 | max_examples_per_task_val: 5000 67 | max_examples_per_task_test: null 68 | max_seq_length: 2048 69 | generation_max_length: 2048 70 | ignore_pad_token_for_loss: true 71 | prompt_loss_weight: 0.0 72 | 73 | # checkpoint settings 74 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-34b_CodeLLaMA 75 | overwrite_output_dir: true 76 | load_best_model_at_end: false 77 | save_strategy: "epoch" 78 | save_steps: 1000 79 | save_total_limit: 999 80 | 81 | # evaluation 82 | do_train: false 83 | do_eval: false 84 | do_predict: true 85 | evaluation_strategy: "epoch" 86 | eval_steps: 500 87 | eval_delay: 0 88 | predict_with_generate: true 89 | evaluate_all_checkpoints: false 90 | 91 | # batch size 92 | per_device_train_batch_size: 8 93 | per_device_eval_batch_size: 4 94 | gradient_accumulation_steps: 2 # 2 FOR 2 GPUs, 4 FOR 1 GPU (32 efective batch size) 95 | generation_num_beams: 1 96 | 97 | # optimizer settings 98 | optim: adamw_torch_fused 99 | learning_rate: 0.0003 100 | weight_decay: 0.0 101 | num_train_epochs: 3 102 | lr_scheduler_type: cosine 103 | warmup_ratio: 0.03 104 | adam_epsilon: 1e-7 105 | 106 | # lora settings 107 | lora_r: 8 108 | lora_alpha: 16 109 | lora_dropout: 0.05 110 | lora_target_modules: 111 | - all 112 | 113 | 
# reporting 114 | logging_strategy: steps 115 | logging_first_step: true 116 | logging_steps: 25 117 | report_to: wandb 118 | run_name: "GoLLIE+-34b_CodeLLaMA" 119 | disable_tqdm: false 120 | 121 | # hub settings 122 | push_to_hub: false 123 | resume_from_checkpoint: false 124 | 125 | # performance 126 | bf16: true 127 | fp16: false 128 | torch_compile: false 129 | ddp_find_unused_parameters: false 130 | -------------------------------------------------------------------------------- /configs/model_configs/eval/GoLLIE-7B_CodeLLaMA_ablation_dropout.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | lora_weights_name_or_path: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_abl_dropout 4 | torch_dtype: bfloat16 5 | use_lora: true 6 | quantization: 4 7 | quantization_inference: null 8 | gradient_checkpointing: true 9 | force_auto_device_map: false 10 | use_flash_attention: true 11 | 12 | # dataset arguments 13 | dataset_dir: 14 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples_abl_dropout 15 | train_tasks: 16 | - ace05.eae 17 | - ace05.ee 18 | - ace05.ner 19 | - ace05.rc 20 | - ace05.re 21 | - ace05.ver 22 | - bc5cdr.ner 23 | - conll03.ner 24 | - diann.ner 25 | - ncbidisease.ner 26 | - ontonotes5.ner 27 | - rams.eae 28 | - tacred.sf 29 | - wnut17.ner 30 | validation_tasks: 31 | - ace05.ee 32 | - conll03.ner 33 | test_tasks: 34 | - ace05.eae 35 | - ace05.ee 36 | - ace05.ner 37 | - ace05.rc 38 | - ace05.re 39 | - ace05.ver 40 | - bc5cdr.ner 41 | - conll03.ner 42 | - diann.ner 43 | - ncbidisease.ner 44 | - ontonotes5.ner 45 | - rams.eae 46 | - tacred.sf 47 | - wikievents.eae 48 | - wikievents.ee 49 | - wikievents.ner 50 | - wnut17.ner 51 | - e3c.ner 52 | - broadtwitter.ner 53 | - fabner.ner 54 | - harveyner.ner 55 | - multinerd.ner 56 | - casie.eae 57 | - casie.ee 58 | - mitmovie.ner 59 | - mitrestaurant.ner 60 | - crossner.crossner_ai 61 | - 
crossner.crossner_music 62 | - crossner.crossner_politics 63 | - crossner.crossner_literature 64 | - crossner.crossner_natural_science 65 | max_examples_per_task_train: 30000 66 | max_examples_per_task_val: 5000 67 | max_examples_per_task_test: null 68 | max_seq_length: 2048 69 | generation_max_length: 2048 70 | ignore_pad_token_for_loss: true 71 | prompt_loss_weight: 0.0 72 | 73 | # checkpoint settings 74 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_abl_dropout 75 | overwrite_output_dir: true 76 | load_best_model_at_end: false 77 | save_strategy: "epoch" 78 | save_steps: 1000 79 | save_total_limit: 999 80 | 81 | # evaluation 82 | do_train: false 83 | do_eval: false 84 | do_predict: true 85 | evaluation_strategy: "epoch" 86 | eval_steps: 500 87 | eval_delay: 0 88 | predict_with_generate: true 89 | evaluate_all_checkpoints: false 90 | 91 | # batch size 92 | per_device_train_batch_size: 32 93 | per_device_eval_batch_size: 8 94 | gradient_accumulation_steps: 1 95 | generation_num_beams: 1 96 | 97 | # optimizer settings 98 | optim: adamw_torch_fused 99 | learning_rate: 0.0003 100 | weight_decay: 0.0 101 | num_train_epochs: 3 102 | lr_scheduler_type: cosine 103 | warmup_ratio: 0.03 104 | adam_epsilon: 1e-7 105 | 106 | # lora settings 107 | lora_r: 8 108 | lora_alpha: 16 109 | lora_dropout: 0.05 110 | lora_target_modules: 111 | - all 112 | 113 | # reporting 114 | logging_strategy: steps 115 | logging_first_step: true 116 | logging_steps: 25 117 | report_to: wandb 118 | run_name: "GoLLIE+-7b_CodeLLaMA_abl_dropout" 119 | disable_tqdm: false 120 | 121 | # hub settings 122 | push_to_hub: false 123 | resume_from_checkpoint: false 124 | 125 | # performance 126 | bf16: true 127 | fp16: false 128 | torch_compile: false 129 | ddp_find_unused_parameters: false 130 | -------------------------------------------------------------------------------- /configs/model_configs/eval/GoLLIE-7B_CodeLLaMA_ablation_masking.yaml: 
-------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: codellama/CodeLlama-7b-hf 3 | lora_weights_name_or_path: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_abl_masking 4 | torch_dtype: bfloat16 5 | use_lora: true 6 | quantization: 4 7 | quantization_inference: null 8 | gradient_checkpointing: true 9 | force_auto_device_map: false 10 | use_flash_attention: true 11 | 12 | # dataset arguments 13 | dataset_dir: 14 | /ikerlariak/osainz006/GoLLIE/data/processed_w_examples_abl_masking 15 | train_tasks: 16 | - ace05.eae 17 | - ace05.ee 18 | - ace05.ner 19 | - ace05.rc 20 | - ace05.re 21 | - ace05.ver 22 | - bc5cdr.ner 23 | - conll03.ner 24 | - diann.ner 25 | - ncbidisease.ner 26 | - ontonotes5.ner 27 | - rams.eae 28 | - tacred.sf 29 | - wnut17.ner 30 | validation_tasks: 31 | - ace05.ee 32 | - conll03.ner 33 | test_tasks: 34 | - ace05.eae 35 | - ace05.ee 36 | - ace05.ner 37 | - ace05.rc 38 | - ace05.re 39 | - ace05.ver 40 | - bc5cdr.ner 41 | - conll03.ner 42 | - diann.ner 43 | - ncbidisease.ner 44 | - ontonotes5.ner 45 | - rams.eae 46 | - tacred.sf 47 | - wikievents.eae 48 | - wikievents.ee 49 | - wikievents.ner 50 | - wnut17.ner 51 | - e3c.ner 52 | - broadtwitter.ner 53 | - fabner.ner 54 | - harveyner.ner 55 | - multinerd.ner 56 | - casie.eae 57 | - casie.ee 58 | - mitmovie.ner 59 | - mitrestaurant.ner 60 | - crossner.crossner_ai 61 | - crossner.crossner_music 62 | - crossner.crossner_politics 63 | - crossner.crossner_literature 64 | - crossner.crossner_natural_science 65 | max_examples_per_task_train: 30000 66 | max_examples_per_task_val: 5000 67 | max_examples_per_task_test: null 68 | max_seq_length: 2048 69 | generation_max_length: 2048 70 | ignore_pad_token_for_loss: true 71 | prompt_loss_weight: 0.0 72 | 73 | # checkpoint settings 74 | output_dir: /ikerlariak/osainz006/models/GoLLIE/GoLLIE+-7b_CodeLLaMA_abl_masking 75 | overwrite_output_dir: true 76 | load_best_model_at_end: false 77 | 
save_strategy: "epoch" 78 | save_steps: 1000 79 | save_total_limit: 999 80 | 81 | # evaluation 82 | do_train: false 83 | do_eval: false 84 | do_predict: true 85 | evaluation_strategy: "epoch" 86 | eval_steps: 500 87 | eval_delay: 0 88 | predict_with_generate: true 89 | evaluate_all_checkpoints: false 90 | 91 | # batch size 92 | per_device_train_batch_size: 32 93 | per_device_eval_batch_size: 8 94 | gradient_accumulation_steps: 1 95 | generation_num_beams: 1 96 | 97 | # optimizer settings 98 | optim: adamw_torch_fused 99 | learning_rate: 0.0003 100 | weight_decay: 0.0 101 | num_train_epochs: 3 102 | lr_scheduler_type: cosine 103 | warmup_ratio: 0.03 104 | adam_epsilon: 1e-7 105 | 106 | # lora settings 107 | lora_r: 8 108 | lora_alpha: 16 109 | lora_dropout: 0.05 110 | lora_target_modules: 111 | - all 112 | 113 | # reporting 114 | logging_strategy: steps 115 | logging_first_step: true 116 | logging_steps: 25 117 | report_to: wandb 118 | run_name: "GoLLIE+-7b_CodeLLaMA_abl_masking" 119 | disable_tqdm: false 120 | 121 | # hub settings 122 | push_to_hub: false 123 | resume_from_checkpoint: false 124 | 125 | # performance 126 | bf16: true 127 | fp16: false 128 | torch_compile: false 129 | ddp_find_unused_parameters: false 130 | -------------------------------------------------------------------------------- /configs/pharapharse_config/LlaMA2-Chat.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: meta-llama/Llama-2-70b-chat-hf 3 | config_template: llama-2 4 | torch_dtype: bfloat16 5 | use_lora: false 6 | quantization: 4 7 | predict_with_generate: true 8 | do_predict: true 9 | per_device_eval_batch_size: 1 10 | use_flash_attention: true 11 | 12 | generation_args_json: /ikerlariak/igarcia945/CoLLIE/configs/pharapharse_config/generation_config.json 13 | output_dir: /ikerlariak/igarcia945/CoLLIE/paraphrase/Llama-2-70b-chat-hf 14 | 15 | 16 | # dataset arguments 17 | datasets: 18 | - ace05 19 | - 
rams 20 | - conll03 21 | - casie 22 | - tacred 23 | - ontonotes5 24 | - ncbidisease 25 | - bc5cdr 26 | - diann 27 | - wnut17 28 | - multinerd 29 | - wikievents 30 | - fabner 31 | - e3c 32 | - broadtwitter 33 | - harveyner 34 | - mitmovie 35 | - mitrestaurant 36 | - crossner 37 | 38 | language: en 39 | 40 | # reporting 41 | logging_strategy: steps 42 | logging_first_step: true 43 | logging_steps: 25 44 | report_to: none 45 | 46 | 47 | # hub settings 48 | push_to_hub: false 49 | resume_from_checkpoint: false 50 | 51 | # performance 52 | bf16: false 53 | fp16: false 54 | torch_compile: false 55 | ddp_find_unused_parameters: false -------------------------------------------------------------------------------- /configs/pharapharse_config/Vicunav1.3-33B.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: lmsys/vicuna-33b-v1.3 3 | config_template: vicuna_v1.1 4 | torch_dtype: "auto" 5 | use_lora: false 6 | quantization: 4 7 | predict_with_generate: true 8 | do_predict: true 9 | per_device_eval_batch_size: 4 10 | 11 | generation_args_json: /ikerlariak/igarcia945/CoLLIE/configs/pharapharse_config/generation_config.json 12 | output_dir: /ikerlariak/igarcia945/CoLLIE/paraphrase/vicunav1.3-33b 13 | 14 | 15 | # dataset arguments 16 | datasets: 17 | - ace05 18 | - rams 19 | - conll03 20 | - casie 21 | - tacred 22 | - ontonotes5 23 | - ncbidisease 24 | - bc5cdr 25 | - diann 26 | - wnut17 27 | - multinerd 28 | - wikievents 29 | - fabner 30 | - e3c 31 | - broadtwitter 32 | - harveyner 33 | - mitmovie 34 | - mitrestaurant 35 | - crossner 36 | 37 | language: en 38 | 39 | # reporting 40 | logging_strategy: steps 41 | logging_first_step: true 42 | logging_steps: 25 43 | report_to: none 44 | 45 | 46 | # hub settings 47 | push_to_hub: false 48 | resume_from_checkpoint: false 49 | 50 | # performance 51 | bf16: false 52 | fp16: false 53 | torch_compile: false 54 | ddp_find_unused_parameters: false 
-------------------------------------------------------------------------------- /configs/pharapharse_config/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "max_new_tokens": 1024, 4 | "min_new_tokens": 4, 5 | "do_sample": true, 6 | "num_beams": 1, 7 | "temperature": 0.7, 8 | "top_k": 50, 9 | "top_p": 0.95, 10 | "num_return_sequences": 4 11 | } 12 | -------------------------------------------------------------------------------- /configs/pharapharse_config/gpt2.yaml: -------------------------------------------------------------------------------- 1 | #Training args 2 | model_name_or_path: EleutherAI/gpt-neo-125m 3 | config_template: vicuna_v1.1 4 | torch_dtype: "auto" 5 | use_lora: false 6 | quantization: null 7 | predict_with_generate: true 8 | do_predict: true 9 | per_device_eval_batch_size: 4 10 | 11 | generation_args_json: /home/ikergarcia/Documents/CoLLIE/configs/pharapharse_config/generation_config.json 12 | output_dir: /home/ikergarcia/Documents/CoLLIE/paraphrase/vicuna-13b 13 | 14 | 15 | # dataset arguments 16 | datasets: 17 | - ace05 18 | - rams 19 | - conll03 20 | - casie 21 | - tacred 22 | - ontonotes5 23 | - ncbidisease 24 | - bc5cdr 25 | - diann 26 | - wnut17 27 | - multinerd 28 | - wikievents 29 | - fabner 30 | - e3c 31 | - broadtwitter 32 | - harveyner 33 | - mitmovie 34 | - mitrestaurant 35 | - crossner 36 | language: en 37 | 38 | 39 | 40 | 41 | 42 | # reporting 43 | logging_strategy: steps 44 | logging_first_step: true 45 | logging_steps: 25 46 | report_to: none 47 | 48 | 49 | # hub settings 50 | push_to_hub: false 51 | resume_from_checkpoint: false 52 | 53 | # performance 54 | bf16: false 55 | fp16: false 56 | torch_compile: false 57 | ddp_find_unused_parameters: false -------------------------------------------------------------------------------- /docs/_layouts/default.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | {% seo %} 9 | 10 | {% include head-custom.html %} 11 | 12 | 13 |
14 | 15 | 16 | 17 | {{ content }} 18 | 19 | {% if site.github.private != true and site.github.license %} 20 | 23 | {% endif %} 24 |
25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docs/assets/openai.svg: -------------------------------------------------------------------------------- 1 | 3 | 5 | OpenAI icon 6 | 7 | 9 | -------------------------------------------------------------------------------- /docs/assets/user.svg: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py37'] 4 | extend-exclude = ''' 5 | /( 6 | | data 7 | | venv 8 | )/ 9 | ''' 10 | preview = true 11 | 12 | [tool.ruff] 13 | # Never enforce `E501` (line length violations). 14 | ignore = ["C901", "E501", "E741", "W605"] 15 | select = ["C", "E", "F", "I", "W"] 16 | line-length = 119 17 | 18 | # Ignore import violations in all `__init__.py` files. 19 | [tool.ruff.per-file-ignores] 20 | "__init__.py" = ["E402", "F401", "F403", "F811"] 21 | "prompts.py" = ["F811"] 22 | 23 | [tool.ruff.isort] 24 | lines-after-imports = 2 25 | known-first-party = ["transformers"] -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import tasks 2 | 3 | 4 | __all__ = ["tasks"] 5 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/src/dataset/__init__.py -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/patch_models/README.md: -------------------------------------------------------------------------------- 1 | Adapted from Open-Assistant by LAION-AI 2 | We have removed the code that was not useful for CoLLIE 3 | 4 | URL: https://github.com/LAION-AI/Open-Assistant/tree/main/model/model_training/models -------------------------------------------------------------------------------- /src/model/patch_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/src/model/patch_models/__init__.py -------------------------------------------------------------------------------- /src/model/patch_models/patching_neox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import transformers 5 | from src.model.patch_models.patching_utils import compute_flash_attention 6 | 7 | 8 | def neox_forward_with_flash_attn( 9 | self: transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention, 10 | flash_attn: nn.Module, # flash_attn.modules.mha.FlashSelfAttention 11 | query: torch.Tensor, 12 | key: torch.Tensor, 13 | value: 
torch.Tensor, 14 | attention_mask=None, 15 | head_mask=None, 16 | ): 17 | # query, key, value: [bs, num_attention_heads, seq_len, attn_head_size] 18 | if query.shape == key.shape: 19 | flash_attn.train(self.training) 20 | out_dtype = value.dtype 21 | q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) 22 | if attention_mask is not None: 23 | attention_mask = attention_mask[:, 0, 0, :] 24 | out = compute_flash_attention(flash_attn, q, k, v, attention_mask) 25 | out = out.transpose(1, 2).to(out_dtype) 26 | return out, None 27 | else: 28 | return self.old_forward(query, key, value, attention_mask, head_mask) 29 | -------------------------------------------------------------------------------- /src/model/patch_models/patching_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def compute_flash_attention(flash_attn, q, k, v, attention_mask=None, head_mask=None): 6 | # q, k, v: [bs, seq_len, num_attention_heads, attn_head_size] 7 | # attention_mask (float): [bs, seq_len] 8 | batch_size, max_len = q.size(0), q.size(1) 9 | 10 | qkv = torch.stack([q, k, v], dim=2).to(torch.float16) # need to truncate in case input is fp32 11 | cu_seqlens, max_seqlen = None, None 12 | 13 | if attention_mask is None: 14 | return flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) 15 | else: 16 | # Limitation: non-contiguous attention mask will not be handled correctly 17 | # model will be able to pay attention between the first and last non-masked token, i.e. left- and right-side padding is supported. 
18 | csums = (attention_mask >= 0).cumsum(dim=1) 19 | ends = csums.argmax(dim=1) + 1 20 | starts = ends - csums.max(dim=1).values 21 | seqlens = ends - starts 22 | 23 | qkv = torch.cat([qkv[i, starts[i] : ends[i]] for i in range(batch_size)], dim=0) 24 | zero = torch.zeros_like(seqlens[:1]) # torch.tensor([0]) with correct dtype and device 25 | cu_seqlens = torch.cat([zero, seqlens.cumsum(dim=0)], dim=0).to(torch.int32) 26 | max_seqlen = seqlens.max().item() 27 | 28 | out = flash_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) 29 | # out: [num_unmasked_tokens, num_attention_heads, attn_head_size] 30 | 31 | seqs = [out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] 32 | # stack and pad sequences together 33 | padded_seqs = [ 34 | F.pad(seqs[i], (0, 0) * (seqs[i].dim() - 1) + (starts[i], max_len - ends[i]), value=0.0) 35 | for i in range(batch_size) 36 | ] 37 | out = torch.stack(padded_seqs) 38 | return out 39 | 40 | 41 | if __name__ == "__main__": 42 | from flash_attn.modules.mha import FlashSelfAttention 43 | 44 | flash_attn = FlashSelfAttention(causal=True) 45 | 46 | dtype = torch.float16 47 | device = torch.device("cuda:0") 48 | 49 | batch_size, seq_len, num_heads, head_size = 4, 18, 8, 32 50 | q = torch.randn(batch_size, seq_len, num_heads, head_size, dtype=dtype, device=device) 51 | k = torch.randn(batch_size, seq_len, num_heads, head_size, dtype=dtype, device=device) 52 | v = torch.randn(batch_size, seq_len, num_heads, head_size, dtype=dtype, device=device) 53 | 54 | attn_mask = torch.randn(batch_size, seq_len, dtype=dtype, device=device).abs().cumsum(dim=1) 55 | attn_mask = ((attn_mask > 3) & (attn_mask < 10)).int().log() 56 | 57 | out = compute_flash_attention(flash_attn, q, k, v, attention_mask=attn_mask) 58 | -------------------------------------------------------------------------------- /src/paraphrase/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/src/paraphrase/__init__.py -------------------------------------------------------------------------------- /src/paraphrase/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | 5 | @dataclass 6 | class DataInferenceArguments: 7 | """ 8 | Arguments pertaining to what data we are going to input our model for paraphrasing. 9 | """ 10 | 11 | datasets: List[str] = field( 12 | default=None, 13 | metadata={"help": "The tasks to train on. Can be a list of tasks or a single task."}, 14 | ) 15 | 16 | config_template: str = field( 17 | default="vicuna_v1.1", 18 | metadata={ 19 | "help": ( 20 | "The config template to use. Available templates: 'one_shot', 'vicuna_v1.1', 'koala_v1', 'dolly_v2'," 21 | " 'oasst_pythia', 'stablelm', 'baize', 'rwkv', 'openbuddy', 'phoenix', 'chatgpt', 'claude', 'mpt'" 22 | ) 23 | }, 24 | ) 25 | 26 | language: str = field( 27 | default="en", 28 | metadata={"help": "The language to do phrase paraphrasing."}, 29 | ) 30 | 31 | generation_args_json: str = field( 32 | default=None, 33 | metadata={"help": "The generation args json file."}, 34 | ) 35 | -------------------------------------------------------------------------------- /src/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/src/scripts/__init__.py -------------------------------------------------------------------------------- /src/scripts/compare_class_scores.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import Dict, List 5 | 6 | 7 | def compare_class_scores(model_paths: str, output_path: str): 8 | """ 9 | Compare the class scores of 
multiple models and save the results to a CSV file. 10 | 11 | Args: 12 | model_paths (`List[str]`): 13 | List of paths to the models to compare. 14 | output_path (`str`): 15 | Path to the output CSV file. 16 | """ 17 | class_scores: Dict[str, Dict[str, List[str]]] = {} 18 | for model_path in model_paths: 19 | task_scores_json = os.path.join(model_path, "task_scores.json") 20 | if not os.path.exists(task_scores_json): 21 | raise FileNotFoundError(f"task_scores.json not found in {model_path}") 22 | with open(task_scores_json, "r", encoding="utf8") as f: 23 | task_scores = json.load(f) 24 | for dataset, values in task_scores.items(): 25 | for task, values in values.items(): 26 | if "class_scores" in values: 27 | name = f"{dataset}.{task}" 28 | if name not in class_scores: 29 | class_scores[name] = {} 30 | 31 | for label, score in values["class_scores"].items(): 32 | if label not in class_scores[name]: 33 | class_scores[name][label] = [] 34 | class_scores[name][label].append(str(score["f1-score"])) 35 | 36 | # Convert dict into CSV 37 | with open(output_path, "w", encoding="utf8") as f: 38 | print(f"Label,{','.join(model_paths)}", file=f) 39 | for name, scores in class_scores.items(): 40 | print(name, file=f) 41 | for label, score in scores.items(): 42 | print(f"{label},{','.join(score)}", file=f) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--model_paths", type=str, nargs="+", required=True, help="Paths to the models to compare.") 48 | parser.add_argument("--output_path", type=str, required=True, help="Path to the output CSV file.") 49 | args = parser.parse_args() 50 | compare_class_scores(args.model_paths, args.output_path) 51 | -------------------------------------------------------------------------------- /src/scripts/plot_results.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | 
5 | 6 | sns.set_theme(style="whitegrid") 7 | 8 | # PromptNER https://arxiv.org/pdf/2305.15444.pdf 9 | GPT3_5 = [ 10 | 00.00, # Movie 11 | 00.00, # Restaurant 12 | 20.30, # Politics 13 | 31.30, # Literature 14 | 24.50, # Music 15 | 40.70, # AI 16 | 40.60, # Science 17 | ] 18 | 19 | # Instruct UIE https://arxiv.org/pdf/2304.08085.pdf 20 | InstructUIE = [ 21 | 63.00, # Movie 22 | 20.99, # Restaurant 23 | 49.00, # Politics 24 | 47.21, # Literature 25 | 53.16, # Music 26 | 48.15, # AI 27 | 49.30, # Science 28 | ] 29 | 30 | CoLLIE = [ 31 | 63.0, # Movie 32 | 43.4, # Restaurant 33 | 57.2, # Politics 34 | 62.7, # Literature 35 | 67.8, # Music 36 | 59.1, # AI 37 | 55.5, # Science 38 | ] 39 | 40 | 41 | def main(): 42 | fig, ax = plt.subplots(1, 7, figsize=(12, 4), sharey=True, layout="constrained") 43 | 44 | TASK_NAMES = ["Movie", "Restaurant", "Politics", "Literature", "Music", "AI", "Science"] 45 | 46 | for i, (gpt, iuie, collie, name) in enumerate(zip(GPT3_5, InstructUIE, CoLLIE, TASK_NAMES)): 47 | rect = ax[i].bar( 48 | [1], 49 | [np.round(gpt)], 50 | width=1.0, 51 | label="GPT-3", 52 | # color="#a40e26", 53 | color=sns.color_palette("crest", 3)[0], 54 | hatch="//", 55 | ) 56 | if gpt: 57 | ax[i].bar_label(rect, padding=3, fontsize=12) 58 | 59 | rect = ax[i].bar( 60 | [2], 61 | [np.round(iuie)], 62 | width=1.0, 63 | label="Instruct-UIE", 64 | # color="#6639ba", 65 | color=sns.color_palette("crest", 3)[1], 66 | hatch="/", 67 | ) 68 | ax[i].bar_label(rect, padding=3, fontsize=12) 69 | 70 | rect = ax[i].bar( 71 | [3], 72 | [np.round(collie)], 73 | width=1.0, 74 | label="CoLLIE", 75 | # color="#0a3069" 76 | color=sns.color_palette("crest", 3)[2], 77 | ) 78 | ax[i].bar_label(rect, padding=3, fontsize=12) 79 | 80 | ax[i].set_xticks([1, 2, 3]) 81 | ax[i].set_yticklabels([]) 82 | ax[i].set_xticklabels(["", name, ""], fontsize=14) 83 | # ax[i].set_title(name) 84 | ax[i].grid(False) 85 | ax[i].spines["top"].set_visible(False) 86 | ax[i].spines["right"].set_visible(False) 87 | 
ax[i].spines["bottom"].set_visible(False) 88 | ax[i].spines["left"].set_visible(False) 89 | 90 | fig.legend(["GPT-3.5", "SOTA", "GoLLIE"], loc="outside upper center", ncol=3, fontsize=14, frameon=False) 91 | # ax[3].legend(["GPT-3", "Instruct-UIE", "CoLLIE"], fontsize=12, ncol=3, bbox_to_anchor=(1.00, 1.15), loc="lower center") 92 | 93 | # plt.tight_layout() 94 | plt.savefig("assets/plots/zero_shot_results.pdf", dpi=300) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /src/scripts/visualize_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | from src.tasks import TASK_ID_TO_TASKS 6 | 7 | 8 | def main(args): 9 | # Create the directory if not exists 10 | os.makedirs(args.output_dir, exist_ok=True) 11 | 12 | with open(args.input_file) as f: 13 | lines = [json.loads(line) for line in f] 14 | 15 | basename = os.path.basename(args.input_file).rstrip(".jsonl") 16 | output_path = os.path.join(args.output_dir, f"{basename}.{args.row}.py") 17 | with open(output_path, "wt", encoding="utf-8") as f: 18 | line = lines[args.row] 19 | imports = TASK_ID_TO_TASKS[line["task_id"]] 20 | 21 | print(f"from {imports} import *", file=f) 22 | print("from src.tasks.utils_typing import Entity, Value, Relation, Event", file=f) 23 | print("from dataclasses import dataclass", file=f) 24 | print(line["text"], file=f) 25 | print(f"labels = {line['labels']}", file=f) 26 | 27 | os.system(f"black {output_path}") 28 | os.system(f"ruff check {output_path} --fix") 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = ArgumentParser("Visualize examples") 33 | 34 | parser.add_argument("-i", "--input_file", dest="input_file", type=str) 35 | parser.add_argument("-r", "--row", dest="row", type=int, default=0) 36 | parser.add_argument("-o", "--output_dir", dest="output_dir", 
default=".ignore/examples") 37 | 38 | args = parser.parse_args() 39 | main(args) 40 | -------------------------------------------------------------------------------- /src/tasks/ace/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/ace/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.ace.prompts import ( 4 | COARSE_EVENT_DEFINITIONS, 5 | COARSE_RELATION_DEFINITIONS, 6 | ENTITY_DEFINITIONS, 7 | EVENT_DEFINITIONS, 8 | RELATION_DEFINITIONS, 9 | VALUE_DEFINITIONS, 10 | ) 11 | from src.tasks.utils_scorer import EventScorer, RelationScorer, SpanScorer 12 | from src.tasks.utils_typing import Entity, Value 13 | 14 | 15 | class ACEEntityScorer(SpanScorer): 16 | """ACE Entity identification and classification scorer.""" 17 | 18 | valid_types: List[Type] = ENTITY_DEFINITIONS 19 | 20 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 21 | output = super().__call__(reference, predictions) 22 | return {"entities": output["spans"]} 23 | 24 | 25 | class ACEValueScorer(SpanScorer): 26 | """ACE Values identification and classification scorer.""" 27 | 28 | valid_types: List[Type] = VALUE_DEFINITIONS 29 | 30 | def __call__(self, reference: List[Value], predictions: List[Value]) -> Dict[str, Dict[str, float]]: 31 | output = super().__call__(reference, predictions) 32 | return {"values": output["spans"]} 33 | 34 | 35 | class ACECoarseRelationScorer(RelationScorer): 36 | """ACE Relation identification scorer.""" 37 | 38 | valid_types: List[Type] = COARSE_RELATION_DEFINITIONS 39 | 40 | 41 | class ACERelationScorer(RelationScorer): 42 | """ACE Relation identification scorer.""" 43 | 44 | 
valid_types: List[Type] = RELATION_DEFINITIONS 45 | 46 | 47 | class ACEEventScorer(EventScorer): 48 | """ACE Event and argument classification scorer.""" 49 | 50 | valid_types: List[Type] = COARSE_EVENT_DEFINITIONS 51 | 52 | 53 | class ACEEventArgumentScorer(EventScorer): 54 | """ACE Event and argument classification scorer.""" 55 | 56 | valid_types: List[Type] = EVENT_DEFINITIONS 57 | -------------------------------------------------------------------------------- /src/tasks/bc5cdr/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/bc5cdr/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "bc5cdr_disease": { 3 | "en": [ 4 | "A disease is a particular abnormal condition that negatively affects the structure or function of all or" 5 | " part of an organism, and that is not immediately due to any external injury. Disease mentions, include" 6 | " Specific Disease (i.e., clear-cell renal cell carcinoma), Disease Class (i.e., cystic kidney diseases)," 7 | " Composite mentions (i.e., prostatic, pancreas, skin, and lung cancer), and Modifier (i.e., hereditary" 8 | " breast cancer families). " 9 | ] 10 | }, 11 | "bc5cdr_chemical": { 12 | "en": [ 13 | ( 14 | "The basic rule for chemical entity annotation is that the chemical should have a specific structure." 
15 | " Chemicals which should be annotated are listed as follows: 1) Chemical Nouns convertible to: - A" 16 | " single chemical structure diagram: single atoms, ions, isotopes, pure elements and molecules such" 17 | " as: Calcium (Ca), Iron (Fe), Lithium (Li), Potassium (K), Oxygen (O2), - A general Markush diagram" 18 | " with R groups such as: Amino acids 2) General class names where the definition of the class includes" 19 | " information on some structural or elemental composition such as: steroids, sugars, fatty acids," 20 | " saturated fatty acids … 3) Small Biochemicals - Monosaccharides, disaccharides and trisaccharides:" 21 | " Glucose, Sucrose … - Peptides and proteins with less than 15 aminoacids: Angiotensin II … -" 22 | " Monomers, dimmers, trimmers of nucleotides: e.g. ATP, cAMP … - Fatty acids and their derivatives" 23 | " excluding polymeric structures. e.g. Cholesterol, glycerol, prostaglandin E1 … 4) Synthetic Polymers" 24 | " such as: Polyethylene glycol 5) Special chemicals having well-defined chemical compositions. E.g." 25 | " “ethanolic extract of Daucus carota seeds (DCE)” in 16755009; “grape seed proanthocyanidin extract”" 26 | " in 11334364." 
27 | ), 28 | ] 29 | }, 30 | } 31 | 32 | EXAMPLES = { 33 | "bc5cdr_disease_examples": { 34 | "en": [ 35 | "toxicity", 36 | "pain", 37 | "hypotension", 38 | "proteinuria", 39 | "seizures", 40 | "hypertension", 41 | "seizure", 42 | "myocardial infarction", 43 | "hepatitis", 44 | "bradycardia", 45 | ] 46 | }, 47 | "bc5cdr_chemical_examples": { 48 | "en": [ 49 | "cocaine", 50 | "dopamine", 51 | "morphine", 52 | "nicotine", 53 | "lithium", 54 | "haloperidol", 55 | "clonidine", 56 | "creatinine", 57 | "cisplatin", 58 | "lidocaine", 59 | ] 60 | }, 61 | } 62 | -------------------------------------------------------------------------------- /src/tasks/bc5cdr/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official BC5DR guidelines: 9 | https://biocreative.bioinformatics.udel.edu/media/store/files/2015/bc5_CDR_data_guidelines.pdf 10 | 11 | 12 | """ 13 | 14 | 15 | @dataclass 16 | class Disease(Entity): 17 | """{bc5cdr_disease}""" 18 | 19 | span: str # {bc5cdr_disease_examples} 20 | 21 | 22 | @dataclass 23 | class Chemical(Entity): 24 | """{bc5cdr_chemical}""" 25 | 26 | span: str # {bc5cdr_chemical_examples} 27 | 28 | 29 | ENTITY_DEFINITIONS: List[Entity] = [Disease, Chemical] 30 | 31 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 32 | -------------------------------------------------------------------------------- /src/tasks/bc5cdr/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.bc5cdr.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class Bc5cdrEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: 
List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/broadtwitter/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/broadtwitter/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "ner_person": {"en": ["first, middle and last names of people, animals and fictional characters aliases."]}, 3 | "ner_organization": { 4 | "en": [ 5 | "Companies (press agencies, studios, banks, stock\n" 6 | "markets, manufacturers, cooperatives) subdivisions of companies (newsrooms)\n" 7 | "brands political movements (political parties, terrorist organisations)\n" 8 | "government bodies (ministries, councils, courts, political unions of countries\n" 9 | "(e.g. 
the {\\it U.N.})) publications (magazines, newspapers, journals)\n" 10 | "musical companies (bands, choirs, opera companies, orchestras\n" 11 | "public organisations (schools, universities, charities other collections\n" 12 | "of people (sports clubs, sports teams, associations, theaters companies,\n" 13 | "religious orders, youth organisations.\n" 14 | ] 15 | }, 16 | "ner_location": { 17 | "en": [ 18 | "Roads (streets, motorways) trajectories regions (villages, towns, cities, provinces,\n" 19 | "countries, continents, dioceses, parishes) structures (bridges, ports, dams) natural locations\n" 20 | "(mountains, mountain ranges, woods, rivers, wells, fields, valleys, gardens, nature reserves,\n" 21 | "allotments, beaches, national parks) public places (squares, opera houses, museums, schools, markets,\n" 22 | "airports, stations, swimming pools, hospitals, sports facilities, youth centers, parks, town halls,\n" 23 | "theaters, cinemas, galleries, camping grounds, NASA launch pads, club houses, universities, libraries,\n" 24 | "churches, medical centers, parking lots, playgrounds, cemeteries) commercial places (chemists, pubs,\n" 25 | "restaurants, depots, hostels, hotels, industrial parks, nightclubs, music venues) assorted buildings\n" 26 | "(houses, monasteries, creches, mills, army barracks, castles, retirement homes, towers, halls, rooms,\n" 27 | "vicarages, courtyards) abstract ``places'' (e.g. 
{\\it the free world})\n" 28 | ] 29 | }, 30 | } 31 | -------------------------------------------------------------------------------- /src/tasks/broadtwitter/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "ner_person": {"en": ["first, middle and last names of people, animals and fictional characters aliases."]}, 3 | "ner_organization": { 4 | "en": [ 5 | "Companies (press agencies, studios, banks, stock\n" 6 | "markets, manufacturers, cooperatives) subdivisions of companies (newsrooms)\n" 7 | "brands political movements (political parties, terrorist organisations)\n" 8 | "government bodies (ministries, councils, courts, political unions of countries\n" 9 | "(e.g. the {\\it U.N.})) publications (magazines, newspapers, journals)\n" 10 | "musical companies (bands, choirs, opera companies, orchestras\n" 11 | "public organisations (schools, universities, charities other collections\n" 12 | "of people (sports clubs, sports teams, associations, theaters companies,\n" 13 | "religious orders, youth organisations.\n" 14 | ] 15 | }, 16 | "ner_location": { 17 | "en": [ 18 | "Roads (streets, motorways) trajectories regions (villages, towns, cities, provinces,\n" 19 | "countries, continents, dioceses, parishes) structures (bridges, ports, dams) natural locations\n" 20 | "(mountains, mountain ranges, woods, rivers, wells, fields, valleys, gardens, nature reserves,\n" 21 | "allotments, beaches, national parks) public places (squares, opera houses, museums, schools, markets,\n" 22 | "airports, stations, swimming pools, hospitals, sports facilities, youth centers, parks, town halls,\n" 23 | "theaters, cinemas, galleries, camping grounds, NASA launch pads, club houses, universities, libraries,\n" 24 | "churches, medical centers, parking lots, playgrounds, cemeteries) commercial places (chemists, pubs,\n" 25 | "restaurants, depots, hostels, hotels, industrial parks, nightclubs, music venues) assorted 
buildings\n" 26 | "(houses, monasteries, creches, mills, army barracks, castles, retirement homes, towers, halls, rooms,\n" 27 | "vicarages, courtyards) abstract ``places'' (e.g. {\\it the free world})\n" 28 | ] 29 | }, 30 | } 31 | EXAMPLES = { 32 | "ner_person_examples": { 33 | "en": [ 34 | "Obama", 35 | "President Obama", 36 | "@ SimoLove", 37 | "Kate Middleton", 38 | "Putin", 39 | "Prince William", 40 | "Cameron", 41 | "Harper", 42 | "@ RossMarowits", 43 | "James Foley", 44 | ] 45 | }, 46 | "ner_location_examples": { 47 | "en": ["UK", "Ukraine", "US", "U . S .", "Iraq", "Canada", "London", "Russia", "Australia", "Ontario"] 48 | }, 49 | "ner_organization_examples": { 50 | "en": [ 51 | "@ Independent", 52 | "Irish News", 53 | "Malaysia Airlines", 54 | "Twitter", 55 | "BBC", 56 | "Apple", 57 | "Getty", 58 | "Isis", 59 | "@ BBCBreaking", 60 | "Liverpool", 61 | ] 62 | }, 63 | } 64 | -------------------------------------------------------------------------------- /src/tasks/broadtwitter/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official ConLL2003 guidelines: 9 | https://www.clips.uantwerpen.be/conll2003/ner/ 10 | Based on: Nancy Chinchor, Erica Brown, Lisa Ferro, Patty Robinson, 11 | "1999 Named Entity Task Definition". MITRE and SAIC, 1999. 
12 | """ 13 | 14 | 15 | @dataclass 16 | class Person(Entity): 17 | """{ner_person}""" 18 | 19 | span: str # {ner_person_examples} 20 | 21 | 22 | @dataclass 23 | class Organization(Entity): 24 | """{ner_organization}""" 25 | 26 | span: str # {ner_organization_examples} 27 | 28 | 29 | @dataclass 30 | class Location(Entity): 31 | """{ner_location}""" 32 | 33 | span: str # {ner_location_examples} 34 | 35 | 36 | @dataclass 37 | class Miscellaneous(Entity): 38 | """{ner_miscellaneous}""" 39 | 40 | span: str # {ner_miscellaneous_examples} 41 | 42 | 43 | ENTITY_DEFINITIONS: List[Entity] = [ 44 | Person, 45 | Organization, 46 | Location, 47 | ] 48 | 49 | 50 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 51 | -------------------------------------------------------------------------------- /src/tasks/broadtwitter/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.broadtwitter.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class BroadTwitterEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/casie/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import data_loader, prompts_eae, prompts_ed, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts_ed", "prompts_eae", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/casie/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "databreach_attack_main": { 3 | "en": [ 4 | "An DatabreachAttack Event happens when an attacker compromises a system\n" 5 | " to later remove or expose the data, e.g., to sell, publish or make it accessible." 6 | ] 7 | }, 8 | "databreach_attack_mention": { 9 | "en": ["The text span that triggers the event, such as:\n 'attack', 'expose', 'publish', 'steal', ..."] 10 | }, 11 | "phising_attack_main": { 12 | "en": [ 13 | "A PhisingAttack Event happens when an attacker imitates another entity, in\n" 14 | " an attempt to get a victim to access malicious materials, such as a website or\n" 15 | " attachments." 16 | ] 17 | }, 18 | "phising_attack_mention": { 19 | "en": [ 20 | "The text span that triggers the event, such as:\n" 21 | " 'attack', 'purports to be', 'dupe', ...\n" 22 | " 'masquerading as', 'pretending to be', 'scam', ..." 23 | ] 24 | }, 25 | "ransom_attack_main": { 26 | "en": [ 27 | "A RansomAttack Event happens when n attacker breaks into a system and\n" 28 | " encrypts data, and will only decrypt the data for a ransom payment." 29 | ] 30 | }, 31 | "ransom_attack_mention": { 32 | "en": ["The text span that triggers the event, such as:\n 'attack', ransomware', 'selling', 'ransom', ..."] 33 | }, 34 | "vulnerability_discover_main": { 35 | "en": [ 36 | "A VulnerabilityDiscover Event happens when a security expert or other entity,\n" 37 | " like a company, finds a software vulnerability." 38 | ] 39 | }, 40 | "vulnerability_discover_mention": { 41 | "en": [ 42 | "The text span that triggers the event, such as:\n 'attack', 'found', 'exploit', 'vulnerability', ..." 
43 | ] 44 | }, 45 | "vulnerability_patch_main": { 46 | "en": [ 47 | "A VulnerabiltyPatch Event happens when software company addresses a known\n" 48 | " vulnerability by releasing or describing an appropriate update." 49 | ] 50 | }, 51 | "vulnerability_patch_mention": { 52 | "en": [ 53 | "The text span that triggers the event, such as:\n" 54 | " 'patch', 'fixed', 'addresses', 'implemented',\n" 55 | " 'released', ..." 56 | ] 57 | }, 58 | } 59 | -------------------------------------------------------------------------------- /src/tasks/casie/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "databreach_attack_main": { 3 | "en": [ 4 | "An DatabreachAttack Event happens when an attacker compromises a system\n" 5 | " to later remove or expose the data, e.g., to sell, publish or make it accessible." 6 | ] 7 | }, 8 | "databreach_attack_mention": { 9 | "en": ["The text span that triggers the event, such as:\n 'attack', 'expose', 'publish', 'steal', ..."] 10 | }, 11 | "phising_attack_main": { 12 | "en": [ 13 | "A PhisingAttack Event happens when an attacker imitates another entity, in\n" 14 | " an attempt to get a victim to access malicious materials, such as a website or\n" 15 | " attachments." 16 | ] 17 | }, 18 | "phising_attack_mention": { 19 | "en": [ 20 | "The text span that triggers the event, such as:\n" 21 | " 'attack', 'purports to be', 'dupe', ...\n" 22 | " 'masquerading as', 'pretending to be', 'scam', ..." 23 | ] 24 | }, 25 | "ransom_attack_main": { 26 | "en": [ 27 | "A RansomAttack Event happens when n attacker breaks into a system and\n" 28 | " encrypts data, and will only decrypt the data for a ransom payment." 
29 | ] 30 | }, 31 | "ransom_attack_mention": { 32 | "en": ["The text span that triggers the event, such as:\n 'attack', ransomware', 'selling', 'ransom', ..."] 33 | }, 34 | "vulnerability_discover_main": { 35 | "en": [ 36 | "A VulnerabilityDiscover Event happens when a security expert or other entity,\n" 37 | " like a company, finds a software vulnerability." 38 | ] 39 | }, 40 | "vulnerability_discover_mention": { 41 | "en": [ 42 | "The text span that triggers the event, such as:\n 'attack', 'found', 'exploit', 'vulnerability', ..." 43 | ] 44 | }, 45 | "vulnerability_patch_main": { 46 | "en": [ 47 | "A VulnerabiltyPatch Event happens when software company addresses a known\n" 48 | " vulnerability by releasing or describing an appropriate update." 49 | ] 50 | }, 51 | "vulnerability_patch_mention": { 52 | "en": [ 53 | "The text span that triggers the event, such as:\n" 54 | " 'patch', 'fixed', 'addresses', 'implemented',\n" 55 | " 'released', ..." 56 | ] 57 | }, 58 | } 59 | 60 | EXAMPLES = { 61 | "databreach_attack_examples": {"en": ["attack", "stole", "publish", "steal"]}, 62 | "phising_attack_examples": { 63 | "en": ["attack", "purports to be", "dupe", "masquerading as", "pretending to be", "scam"] 64 | }, 65 | "ransom_attack_examples": {"en": ["attack", "ransomware", "selling", "ransom"]}, 66 | "vulnerability_discover_examples": {"en": ["attack", "found", "exploit", "vulnerability"]}, 67 | "vulnerability_patch_examples": {"en": ["patch", "fixed", "addresses", "implemented", "released"]}, 68 | } 69 | -------------------------------------------------------------------------------- /src/tasks/casie/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Type 2 | 3 | from typing_extensions import override 4 | 5 | from src.tasks.casie.prompts_eae import EAE_EVENT_DEFINITIONS 6 | from src.tasks.casie.prompts_ed import ED_EVENT_DEFINITIONS 7 | from src.tasks.utils_scorer import EventScorer 8 | 9 
class CASIEEventScorer(EventScorer):
    """CASIE event identification and classification scorer.

    Scores against the event-detection (ED) definitions.
    """

    valid_types: List[Type] = ED_EVENT_DEFINITIONS

    def __init__(self, allow_partial_match: bool = True) -> None:
        """
        Args:
            allow_partial_match: if True, the partial-match flag is propagated
                to every annotation before scoring (see `__call__`).
        """
        super().__init__()

        self.allow_partial_match: bool = allow_partial_match

    @override
    def __call__(self, reference: Any, predictions: Any) -> Dict[str, Dict[str, float]]:
        # Accept either a single document (flat list of annotations) or a
        # batch (list of lists) and normalize to a batch. Short-circuit
        # evaluation makes the original `len(x) and ...` guard redundant:
        # `reference[0]` is only reached when the list is non-empty.
        if not len(reference) or not isinstance(reference[0], list):
            reference = [reference]
        if not len(predictions) or not isinstance(predictions[0], list):
            predictions = [predictions]

        # Propagate the matching policy to every annotation. The original
        # code shadowed the outer loop variable (`for ref in ref`); distinct
        # names keep the intent readable without changing behavior.
        for ref_doc in reference:
            for ref in ref_doc:
                ref._allow_partial_match = self.allow_partial_match
        for pred_doc in predictions:
            for pred in pred_doc:
                pred._allow_partial_match = self.allow_partial_match

        return super().__call__(reference, predictions)


class CASIEEventArgumentScorer(EventScorer):
    """CASIE event argument classification scorer.

    Scores against the event-argument-extraction (EAE) definitions.
    """

    valid_types: List[Type] = EAE_EVENT_DEFINITIONS

    def __init__(self, allow_partial_match: bool = True) -> None:
        """
        Args:
            allow_partial_match: if True, the partial-match flag is propagated
                to every annotation before scoring (see `__call__`).
        """
        super().__init__()

        self.allow_partial_match: bool = allow_partial_match

    @override
    def __call__(self, reference: Any, predictions: Any) -> Dict[str, Dict[str, float]]:
        # Same normalization and policy propagation as CASIEEventScorer:
        # wrap single documents into a batch, then flag each annotation.
        if not len(reference) or not isinstance(reference[0], list):
            reference = [reference]
        if not len(predictions) or not isinstance(predictions[0], list):
            predictions = [predictions]

        for ref_doc in reference:
            for ref in ref_doc:
                ref._allow_partial_match = self.allow_partial_match
        for pred_doc in predictions:
            for pred in pred_doc:
                pred._allow_partial_match = self.allow_partial_match

        return super().__call__(reference, predictions)

-------------------------------------------------------------------------------- /src/tasks/conll03/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/conll03/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official ConLL2003 guidelines: 9 | https://www.clips.uantwerpen.be/conll2003/ner/ 10 | Based on: Nancy Chinchor, Erica Brown, Lisa Ferro, Patty Robinson, 11 | "1999 Named Entity Task Definition". MITRE and SAIC, 1999. 12 | """ 13 | 14 | 15 | @dataclass 16 | class Person(Entity): 17 | """{ner_person}""" 18 | 19 | span: str # {ner_person_examples} 20 | 21 | 22 | @dataclass 23 | class Organization(Entity): 24 | """{ner_organization}""" 25 | 26 | span: str # {ner_organization_examples} 27 | 28 | 29 | @dataclass 30 | class Location(Entity): 31 | """{ner_location}""" 32 | 33 | span: str # {ner_location_examples} 34 | 35 | 36 | @dataclass 37 | class Miscellaneous(Entity): 38 | """{ner_miscellaneous}""" 39 | 40 | span: str # {ner_miscellaneous_examples} 41 | 42 | 43 | ENTITY_DEFINITIONS: List[Entity] = [ 44 | Person, 45 | Organization, 46 | Location, 47 | Miscellaneous, 48 | ] 49 | 50 | ENTITY_DEFINITIONS_woMISC: List[Entity] = [ 51 | Person, 52 | Organization, 53 | Location, 54 | ] 55 | 56 | 57 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 58 | -------------------------------------------------------------------------------- /src/tasks/conll03/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.conll03.prompts 
import ENTITY_DEFINITIONS, ENTITY_DEFINITIONS_woMISC 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class CoNLL03EntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | 17 | 18 | class CoNLL03EntityScorerNoMisc(SpanScorer): 19 | """CoNLL03 Entity identification and classification scorer.""" 20 | 21 | valid_types: List[Type] = ENTITY_DEFINITIONS_woMISC 22 | 23 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 24 | output = super().__call__(reference, predictions) 25 | return {"entities": output["spans"]} 26 | -------------------------------------------------------------------------------- /src/tasks/crossner/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | data_loader, 3 | prompts_ai, 4 | prompts_literature, 5 | prompts_music, 6 | prompts_natural_science, 7 | prompts_politics, 8 | scorer, 9 | ) 10 | 11 | 12 | __all__ = [ 13 | "data_loader", 14 | "prompts_politics", 15 | "prompts_natural_science", 16 | "prompts_music", 17 | "prompts_literature", 18 | "prompts_ai", 19 | "scorer", 20 | ] 21 | -------------------------------------------------------------------------------- /src/tasks/crossner/prompts_ai.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official CrossNER corpus guidelines: 9 | https://arxiv.org/pdf/2012.04373.pdf 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Field(Entity): 16 | """{crossner_ai_field}""" 17 | 18 | span: str # {crossner_ai_field_examples} 19 | 20 | 21 | @dataclass 22 | class Task(Entity): 23 | """{crossner_ai_task}""" 24 | 25 | span: str # {crossner_ai_task_examples} 26 | 27 | 28 | @dataclass 29 | class Product(Entity): 30 | """{crossner_ai_product}""" 31 | 32 | span: str # {crossner_ai_product_examples} 33 | 34 | 35 | @dataclass 36 | class Algorithm(Entity): 37 | """{crossner_ai_algorithm}""" 38 | 39 | span: str # {crossner_ai_algorithm_examples} 40 | 41 | 42 | @dataclass 43 | class Researcher(Entity): 44 | """{crossner_ai_researcher}""" 45 | 46 | span: str # {crossner_ai_researcher_examples} 47 | 48 | 49 | @dataclass 50 | class Metric(Entity): 51 | """{crossner_ai_metric}""" 52 | 53 | span: str # {crossner_ai_metric_examples} 54 | 55 | 56 | @dataclass 57 | class University(Entity): 58 | """{crossner_ai_university}""" 59 | 60 | span: str # {crossner_ai_university_examples} 61 | 62 | 63 | @dataclass 64 | class Country(Entity): 65 | """{crossner_ai_country}""" 66 | 67 | span: str # {crossner_ai_country_examples} 68 | 69 | 70 | @dataclass 71 | class Person(Entity): 72 | 
"""{crossner_ai_person}""" 73 | 74 | span: str # {crossner_ai_person_examples} 75 | 76 | 77 | @dataclass 78 | class Organization(Entity): 79 | """{crossner_ai_organization}""" 80 | 81 | span: str # {crossner_ai_organization_examples} 82 | 83 | 84 | @dataclass 85 | class Location(Entity): 86 | """{crossner_ai_location}""" 87 | 88 | span: str # {crossner_ai_location_examples} 89 | 90 | 91 | @dataclass 92 | class ProgrammingLanguage(Entity): 93 | """{crossner_ai_programminglanguage}""" 94 | 95 | span: str # {crossner_ai_programminglanguage_examples} 96 | 97 | 98 | @dataclass 99 | class Conference(Entity): 100 | """{crossner_ai_conference}""" 101 | 102 | span: str # {crossner_ai_conference_examples} 103 | 104 | 105 | @dataclass 106 | class Other(Entity): 107 | """{crossner_ai_miscellaneous}""" 108 | 109 | span: str # {crossner_ai_miscellaneous_examples} 110 | 111 | 112 | ENTITY_DEFINITIONS_AI: List[Entity] = [ 113 | Field, 114 | Task, 115 | Product, 116 | Algorithm, 117 | Researcher, 118 | Metric, 119 | University, 120 | Country, 121 | Person, 122 | Organization, 123 | Location, 124 | ProgrammingLanguage, 125 | Conference, 126 | Other, 127 | ] 128 | 129 | ENTITY_DEFINITIONS_AI_woMISC: List[Entity] = [ 130 | Field, 131 | Task, 132 | Product, 133 | Algorithm, 134 | Researcher, 135 | Metric, 136 | University, 137 | Country, 138 | Person, 139 | Organization, 140 | Location, 141 | ProgrammingLanguage, 142 | Conference, 143 | ] 144 | 145 | 146 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 147 | -------------------------------------------------------------------------------- /src/tasks/crossner/prompts_literature.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official CrossNER corpus guidelines: 9 | https://arxiv.org/pdf/2012.04373.pdf 10 | 11 | """ 12 | 13 | 14 | 
@dataclass 15 | class Book(Entity): 16 | """{crossner_literature_book}""" 17 | 18 | span: str # {crossner_literature_book_examples} 19 | 20 | 21 | @dataclass 22 | class Writer(Entity): 23 | """{crossner_literature_writer}""" 24 | 25 | span: str # {crossner_literature_writer_examples} 26 | 27 | 28 | @dataclass 29 | class Award(Entity): 30 | """{crossner_literature_award}""" 31 | 32 | span: str # {crossner_literature_award_examples} 33 | 34 | 35 | @dataclass 36 | class Poem(Entity): 37 | """{crossner_literature_poem}""" 38 | 39 | span: str # {crossner_literature_poem_examples} 40 | 41 | 42 | @dataclass 43 | class Event(Entity): 44 | """{crossner_literature_event}""" 45 | 46 | span: str # {crossner_literature_event_examples} 47 | 48 | 49 | @dataclass 50 | class Magazine(Entity): 51 | """{crossner_literature_magazine}""" 52 | 53 | span: str # {crossner_literature_magazine_examples} 54 | 55 | 56 | @dataclass 57 | class LiteraryGenre(Entity): 58 | """{crossner_literature_literarygenre}""" 59 | 60 | span: str # {crossner_literature_literarygenre_examples} 61 | 62 | 63 | @dataclass 64 | class Person(Entity): 65 | """{crossner_literature_person}""" 66 | 67 | span: str # {crossner_literature_person_examples} 68 | 69 | 70 | @dataclass 71 | class Location(Entity): 72 | """{crossner_literature_location}""" 73 | 74 | span: str # {crossner_literature_location_examples} 75 | 76 | 77 | @dataclass 78 | class Organization(Entity): 79 | """{crossner_literature_organization}""" 80 | 81 | span: str # {crossner_literature_organization_examples} 82 | 83 | 84 | @dataclass 85 | class Country(Entity): 86 | """{crossner_literature_country}""" 87 | 88 | span: str # {crossner_literature_country_examples} 89 | 90 | 91 | @dataclass 92 | class Other(Entity): 93 | """{crossner_literature_miscellaneous}""" 94 | 95 | span: str # {crossner_literature_miscellaneous_examples} 96 | 97 | 98 | ENTITY_DEFINITIONS_LITERATURE: List[Entity] = [ 99 | Book, 100 | Writer, 101 | Award, 102 | Poem, 103 | Event, 104 
| Magazine, 105 | LiteraryGenre, 106 | Person, 107 | Location, 108 | Organization, 109 | Country, 110 | Other, 111 | ] 112 | 113 | ENTITY_DEFINITIONS_LITERATURE_woMISC: List[Entity] = [ 114 | Book, 115 | Writer, 116 | Award, 117 | Poem, 118 | Event, 119 | Magazine, 120 | LiteraryGenre, 121 | Person, 122 | Location, 123 | Organization, 124 | Country, 125 | ] 126 | 127 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 128 | -------------------------------------------------------------------------------- /src/tasks/crossner/prompts_music.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official CrossNER corpus guidelines: 9 | https://arxiv.org/pdf/2012.04373.pdf 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class MusicGenre(Entity): 16 | """{crossner_music_musicgenre}""" 17 | 18 | span: str # {crossner_music_musicgenre_examples} 19 | 20 | 21 | @dataclass 22 | class Song(Entity): 23 | """{crossner_music_song}""" 24 | 25 | span: str # {crossner_music_song_examples} 26 | 27 | 28 | @dataclass 29 | class Band(Entity): 30 | """{crossner_music_band}""" 31 | 32 | span: str # {crossner_music_band_examples} 33 | 34 | 35 | @dataclass 36 | class Album(Entity): 37 | """{crossner_music_album}""" 38 | 39 | span: str # {crossner_music_album_examples} 40 | 41 | 42 | @dataclass 43 | class MusicalArtist(Entity): 44 | """{crossner_music_musicalartist}""" 45 | 46 | span: str # {crossner_music_musicalartist_examples} 47 | 48 | 49 | @dataclass 50 | class MusicalInstrument(Entity): 51 | """{crossner_music_musicalinstrument}""" 52 | 53 | span: str # {crossner_music_musicalinstrument_examples} 54 | 55 | 56 | @dataclass 57 | class Award(Entity): 58 | """{crossner_music_award}""" 59 | 60 | span: str # {crossner_music_award_examples} 61 | 62 | 63 | @dataclass 64 | class Event(Entity): 65 | 
"""{crossner_music_event}""" 66 | 67 | span: str # {crossner_music_event_examples} 68 | 69 | 70 | @dataclass 71 | class Country(Entity): 72 | """{crossner_music_country}""" 73 | 74 | span: str # {crossner_music_country_examples} 75 | 76 | 77 | @dataclass 78 | class Location(Entity): 79 | """{crossner_music_location}""" 80 | 81 | span: str # {crossner_music_location_examples} 82 | 83 | 84 | @dataclass 85 | class Organization(Entity): 86 | """{crossner_music_organization}""" 87 | 88 | span: str # {crossner_music_organization_examples} 89 | 90 | 91 | @dataclass 92 | class Person(Entity): 93 | """{crossner_music_person}""" 94 | 95 | span: str # {crossner_music_person_examples} 96 | 97 | 98 | @dataclass 99 | class Other(Entity): 100 | """{crossner_music_miscellaneous}""" 101 | 102 | span: str # {crossner_music_miscellaneous_examples} 103 | 104 | 105 | ENTITY_DEFINITIONS_MUSIC: List[Entity] = [ 106 | MusicGenre, 107 | Song, 108 | Band, 109 | Album, 110 | MusicalArtist, 111 | MusicalInstrument, 112 | Award, 113 | Event, 114 | Country, 115 | Location, 116 | Organization, 117 | Person, 118 | Other, 119 | ] 120 | 121 | ENTITY_DEFINITIONS_MUSIC_woMISC: List[Entity] = [ 122 | MusicGenre, 123 | Song, 124 | Band, 125 | Album, 126 | MusicalArtist, 127 | MusicalInstrument, 128 | Award, 129 | Event, 130 | Country, 131 | Location, 132 | Organization, 133 | Person, 134 | ] 135 | 136 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 137 | -------------------------------------------------------------------------------- /src/tasks/crossner/prompts_politics.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official CrossNER corpus guidelines: 9 | https://arxiv.org/pdf/2012.04373.pdf 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Person(Entity): 16 | """{crossner_politics_person}""" 
17 | 18 | span: str # {crossner_politics_person_examples} 19 | 20 | 21 | @dataclass 22 | class Organization(Entity): 23 | """{crossner_politics_organization}""" 24 | 25 | span: str # {crossner_politics_organization_examples} 26 | 27 | 28 | @dataclass 29 | class Location(Entity): 30 | """{crossner_politics_location}""" 31 | 32 | span: str # {crossner_politics_location_examples} 33 | 34 | 35 | @dataclass 36 | class Politician(Entity): 37 | """{crossner_politics_politician}""" 38 | 39 | span: str # {crossner_politics_politician_examples} 40 | 41 | 42 | @dataclass 43 | class PoliticalParty(Entity): 44 | """{crossner_politics_politicalparty}""" 45 | 46 | span: str # {crossner_politics_politicalparty_examples} 47 | 48 | 49 | @dataclass 50 | class Election(Entity): 51 | """{crossner_politics_election}""" 52 | 53 | span: str # {crossner_politics_election_examples} 54 | 55 | 56 | @dataclass 57 | class Event(Entity): 58 | """{crossner_politics_event}""" 59 | 60 | span: str # {crossner_politics_event_examples} 61 | 62 | 63 | @dataclass 64 | class Country(Entity): 65 | """{crossner_politics_country}""" 66 | 67 | span: str # {crossner_politics_country_examples} 68 | 69 | 70 | @dataclass 71 | class Other(Entity): 72 | """{crossner_politics_miscellaneous}""" 73 | 74 | span: str # {crossner_politics_miscellaneous_examples} 75 | 76 | 77 | ENTITY_DEFINITIONS_POLITICS: List[Entity] = [ 78 | Person, 79 | Organization, 80 | Location, 81 | Politician, 82 | PoliticalParty, 83 | Election, 84 | Event, 85 | Country, 86 | Other, 87 | ] 88 | 89 | ENTITY_DEFINITIONS_POLITICS_woMISC: List[Entity] = [ 90 | Person, 91 | Organization, 92 | Location, 93 | Politician, 94 | PoliticalParty, 95 | Election, 96 | Event, 97 | Country, 98 | ] 99 | 100 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 101 | -------------------------------------------------------------------------------- /src/tasks/diann/__init__.py: -------------------------------------------------------------------------------- 1 | from 
. import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/diann/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "diann_disability": { 3 | "en": [ 4 | "The term disability refers to an umbrella term covering impairments, limitations of activities and" 5 | " restrictions on participation. Lists of disability terms and lists of functions whose absence or" 6 | " limitation has been considered a disability: The Attention, Autonomy, Cognition Communication," 7 | " Behaviour, Day-to-day living activities Development, Emotions, Executive functioning Feeding, Functional" 8 | " capacity, Gait Hearing, Language, Learning Mental capabilities, Mobility, Perception Psychological" 9 | " capabilities, Sensory, Sight Sleep, Social cognition, Speech Swallowing, Occupational functioning. The" 10 | " modifiers affecting the disability, such as 'isolated aphasia', 'advanced dementia' and Acronyms" 11 | " referring to a disability such as Mild Cognitive Impairment (MCI) are included in the disability" 12 | " term." 
13 | ] 14 | }, 15 | "diann_negation": { 16 | "en": ["Negation triggers are annotated when it affects one or more disabilities."], 17 | }, 18 | } 19 | 20 | EXAMPLES = { 21 | "diann_disability_examples": { 22 | "en": [ 23 | "cognitive impairment", 24 | "hearing loss", 25 | "dementia", 26 | "mental disorders", 27 | "mental retardation", 28 | "sensorineural hearing loss", 29 | "mild cognitive impairment", 30 | "functional impairment", 31 | "mental disorder", 32 | "dysarthria", 33 | ] 34 | }, 35 | "diann_negation_examples": {"en": []}, 36 | } 37 | -------------------------------------------------------------------------------- /src/tasks/diann/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official Diann guidelines: 9 | http://nlp.uned.es/diann/#data 10 | """ 11 | 12 | 13 | @dataclass 14 | class Disability(Entity): 15 | """{diann_disability}""" 16 | 17 | span: str # {diann_disability_examples} 18 | 19 | 20 | """ 21 | @dataclass 22 | class Negation(Entity): 23 | \"""{diann_negation}\""" 24 | 25 | span: str # {diann_negation_examples} 26 | """ 27 | 28 | ENTITY_DEFINITIONS: List[Entity] = [Disability] 29 | 30 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 31 | -------------------------------------------------------------------------------- /src/tasks/diann/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.diann.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class DiannDiseaseEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: 
List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/e3c/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/e3c/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "e3c_disease": { 3 | "en": [ 4 | "A definite pathologic process with a characteristic set of signs and symptoms. Examples: tumor, tumour," 5 | " vomiting, swelling, epistaxis, ascites, headache, fever, cyst, diplopia." 6 | ] 7 | }, 8 | } 9 | -------------------------------------------------------------------------------- /src/tasks/e3c/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "e3c_disease": {"en": ["A definite pathologic process with a characteristic set of signs and symptoms."]}, 3 | } 4 | 5 | EXAMPLES = { 6 | "e3c_disease_examples": { 7 | "en": [ 8 | "tumor", 9 | "tumour", 10 | "vomiting", 11 | "swelling", 12 | "epistaxis", 13 | "ascites", 14 | "headache", 15 | "fever", 16 | "cyst", 17 | "diplopia", 18 | ] 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/tasks/e3c/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official E3C corpus guidelines: 9 | 
https://github.com/hltfbk/E3C-Corpus/blob/main/documentation/CLINICALENTITY_ANNOTATION_GUIDELINES_v1.1.pdf 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class ClinicalEntity(Entity): 16 | """{e3c_disease}""" 17 | 18 | span: str # {e3c_disease_examples} 19 | 20 | 21 | ENTITY_DEFINITIONS: List[Entity] = [ClinicalEntity] 22 | 23 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 24 | -------------------------------------------------------------------------------- /src/tasks/e3c/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.e3c.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class E3CEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/fabner/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/fabner/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "fabner_material": { 3 | "en": [ 4 | "Refers to any substance or compound used in manufacturing processes that has a significant impact on the" 5 | " process, properties, or characteristics of the end product. Materials can include metals, alloys," 6 | " polymers, ceramics, composites, and more." 
7 | ] 8 | }, 9 | "fabner_manufacturingprocess": { 10 | "en": ["Refers to a specific technique, method, or procedure used in the production of goods."] 11 | }, 12 | "fabner_machineequipment": { 13 | "en": [ 14 | "Any necessary items required for a particular purpose, including assembled parts and set of tools." 15 | " Equipments are required to set up not only the mechanical machines, but also for setting up any relevant" 16 | " operation or task." 17 | ] 18 | }, 19 | "fabner_application": { 20 | "en": [ 21 | "Refers to the specific field or industry in which a manufacturing process is utilized or intended to be" 22 | " utilized. It signifies the purpose, domain, or sector where the manufacturing process finds its use." 23 | ] 24 | }, 25 | "fabner_engineeringfeatures": { 26 | "en": [ 27 | "Encompass distinct terms or phrases that convey specific information related to the structural," 28 | " functional, or visual properties of a manufacturing process." 29 | ] 30 | }, 31 | "fabner_mechanicalproperties": { 32 | "en": [ 33 | "Refers to specific attributes or characteristics of materials that describe their behavior under" 34 | " mechanical forces and conditions. These properties play a crucial role in manufacturing processes and" 35 | " are often used to evaluate the suitability and performance of materials for different applications." 36 | ] 37 | }, 38 | "fabner_processcharacterization": { 39 | "en": [ 40 | "Refers to the detailed analysis and measurement of various aspects of manufacturing processes to" 41 | " understand their performance, quality, and characteristics." 42 | ] 43 | }, 44 | "fabner_processparameters": { 45 | "en": [ 46 | "Process parameters refer to specific quantitative values or attributes that play a significant role in" 47 | " manufacturing processes." 
48 | ] 49 | }, 50 | "fabner_enablingtechnology": { 51 | "en": [ 52 | "Specific tools, methods, processes, or technologies that play a pivotal role in facilitating or enhancing" 53 | " various aspects of manufacturing processes." 54 | ] 55 | }, 56 | "fabner_conceptprinciples": { 57 | "en": ["Refers to key concepts, principles, and elements related to manufacturing processes."] 58 | }, 59 | "fabner_manufacturingstandards": { 60 | "en": [ 61 | "Refer to specific guidelines, specifications, benchmarks, and terminology used within the manufacturing" 62 | " industry to ensure uniformity, quality, and compatibility of products, processes, and systems." 63 | ] 64 | }, 65 | "fabner_biomedical": {"en": ["Refers to any biomedical entity involved in the manufacturing process."]}, 66 | } 67 | -------------------------------------------------------------------------------- /src/tasks/fabner/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official fabner guidelines: 9 | https://par.nsf.gov/servlets/purl/10290810 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Material(Entity): 16 | """{fabner_material}""" 17 | 18 | span: str # {fabner_material_examples} 19 | 20 | 21 | @dataclass 22 | class ManufacturingProcess(Entity): 23 | """{fabner_manufacturingprocess}""" 24 | 25 | span: str # {fabner_manufacturingprocess_examples} 26 | 27 | 28 | @dataclass 29 | class MachineEquipment(Entity): 30 | """{fabner_machineequipment}""" 31 | 32 | span: str # {fabner_machineequipment_examples} 33 | 34 | 35 | @dataclass 36 | class Application(Entity): 37 | """{fabner_application}""" 38 | 39 | span: str # {fabner_application_examples} 40 | 41 | 42 | @dataclass 43 | class EngineeringFeatures(Entity): 44 | """{fabner_engineeringfeatures}""" 45 | 46 | span: str # 
{fabner_engineeringfeatures_examples} 47 | 48 | 49 | @dataclass 50 | class MechanicalProperties(Entity): 51 | """{fabner_mechanicalproperties}""" 52 | 53 | span: str # {fabner_mechanicalproperties_examples} 54 | 55 | 56 | @dataclass 57 | class ProcessCharacterization(Entity): 58 | """{fabner_processcharacterization}""" 59 | 60 | span: str # {fabner_processcharacterization_examples} 61 | 62 | 63 | @dataclass 64 | class ProcessParameters(Entity): 65 | """{fabner_processparameters}""" 66 | 67 | span: str # {fabner_processparameters_examples} 68 | 69 | 70 | @dataclass 71 | class EnablingTechnology(Entity): 72 | """{fabner_enablingtechnology}""" 73 | 74 | span: str # {fabner_enablingtechnology_examples} 75 | 76 | 77 | @dataclass 78 | class ConceptPrinciples(Entity): 79 | """{fabner_conceptprinciples}""" 80 | 81 | span: str # {fabner_conceptprinciples_examples} 82 | 83 | 84 | @dataclass 85 | class ManufacturingStandards(Entity): 86 | """{fabner_manufacturingstandards}""" 87 | 88 | span: str # {fabner_manufacturingstandards_examples} 89 | 90 | 91 | @dataclass 92 | class Biomedical(Entity): 93 | """{fabner_biomedical}""" 94 | 95 | span: str # {fabner_biomedical_examples} 96 | 97 | 98 | ENTITY_DEFINITIONS: List[Entity] = [ 99 | Material, 100 | ManufacturingProcess, 101 | MachineEquipment, 102 | Application, 103 | EngineeringFeatures, 104 | MechanicalProperties, 105 | ProcessCharacterization, 106 | ProcessParameters, 107 | EnablingTechnology, 108 | ConceptPrinciples, 109 | ManufacturingStandards, 110 | Biomedical, 111 | ] 112 | 113 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 114 | -------------------------------------------------------------------------------- /src/tasks/fabner/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.fabner.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 
| 7 | 8 | class FabNerEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/harveyner/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/harveyner/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "harveyner_point": { 3 | "en": [ 4 | "Refers to a location that is a building, a landmark, an intersection of two roads, an intersection of a" 5 | " river with a lake/reservoir/ocean, or a specifc address. Ignore generic company/franchise names unless" 6 | " it is accompanied with a precise location, for example, HEB at Kirkwood Drive. However, non-franchised" 7 | " small businesses with only one unique location are considered as a point. Ignore any locations in the" 8 | " Twitter username, unless @ does not refer to a Twitter account name. For example, I am @ XXX High" 9 | " School." 10 | ] 11 | }, 12 | "harveyner_area": { 13 | "en": [ 14 | "Refers to all the named entities of cities, neighborhoods, super neighborhoods, geographic divisions etc." 15 | " The following locations, Lake Houston, Barker Reservoir, and Addick’s Reservoir, are annotated as areas" 16 | " due to their significant size while all other lakes/reservoirs are not considered as areas." 
17 | ] 18 | }, 19 | "harveyner_road": { 20 | "en": [ 21 | "Refers to a road/avenue/street or a section of a road/avenue/street when the tweet does not provide an" 22 | " exact location on that road." 23 | ] 24 | }, 25 | "harveyner_river": { 26 | "en": [ 27 | "Refers to a river or a section of a river when the tweet does not imply there is an intersection between" 28 | " the river and other places." 29 | ] 30 | }, 31 | } 32 | -------------------------------------------------------------------------------- /src/tasks/harveyner/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "harveyner_point": { 3 | "en": [ 4 | "Refers to a location that is a building, a landmark, an intersection of two roads, an intersection of a" 5 | " river with a lake/reservoir/ocean, or a specifc address. Ignore generic company/franchise names unless" 6 | " it is accompanied with a precise location, for example, HEB at Kirkwood Drive. However, non-franchised" 7 | " small businesses with only one unique location are considered as a point. Ignore any locations in the" 8 | " Twitter username, unless @ does not refer to a Twitter account name. For example, I am @ XXX High" 9 | " School." 10 | ] 11 | }, 12 | "harveyner_area": { 13 | "en": [ 14 | "Refers to all the named entities of cities, neighborhoods, super neighborhoods, geographic divisions etc." 15 | " The following locations, Lake Houston, Barker Reservoir, and Addick’s Reservoir, are annotated as areas" 16 | " due to their significant size while all other lakes/reservoirs are not considered as areas." 17 | ] 18 | }, 19 | "harveyner_road": { 20 | "en": [ 21 | "Refers to a road/avenue/street or a section of a road/avenue/street when the tweet does not provide an" 22 | " exact location on that road." 
23 | ] 24 | }, 25 | "harveyner_river": { 26 | "en": [ 27 | "Refers to a river or a section of a river when the tweet does not imply there is an intersection between" 28 | " the river and other places." 29 | ] 30 | }, 31 | } 32 | 33 | EXAMPLES = { 34 | "harveyner_point_examples": { 35 | "en": [ 36 | "GRB", 37 | "GEORGE R. BROWN", 38 | "Lakewood Church", 39 | "Bayou Oaks", 40 | "Northgate Subdivision S of toll road", 41 | "TerryHS", 42 | "GRB Convention Center", 43 | "@GRBCC", 44 | "Addicks", 45 | "TERRY HIGH SCHOOL", 46 | ] 47 | }, 48 | "harveyner_area_examples": { 49 | "en": [ 50 | "Sienna Plantation", 51 | "Galveston", 52 | "Addicks", 53 | "Barker", 54 | "Fort Bend", 55 | "Chambers", 56 | "Fort Bend County", 57 | "Pecan Grove", 58 | "Brazoria", 59 | "Dickinson", 60 | ] 61 | }, 62 | "harveyner_river_examples": { 63 | "en": [ 64 | "Buffalo Bayou", 65 | "Brazos River", 66 | "Brays Bayou", 67 | "Cypress Creek", 68 | "Spring Creek", 69 | "Addicks", 70 | "Colorado River", 71 | "San Jacinto", 72 | "WhiteOakBayou", 73 | "Brazos", 74 | ] 75 | }, 76 | "harveyner_road_examples": { 77 | "en": ["Barker Cypress", "I-10", "US 290 WB", "SH-71", "I-45", "105", "59", "Fry", "US-59 South", "Summer St"] 78 | }, 79 | } 80 | -------------------------------------------------------------------------------- /src/tasks/harveyner/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official harveyner guidelines: 9 | https://aclanthology.org/2022.naacl-main.243.pdf 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Point(Entity): 16 | """{harveyner_point}""" 17 | 18 | span: str # {harveyner_point_examples} 19 | 20 | 21 | class Area(Entity): 22 | """{harveyner_area}""" 23 | 24 | span: str # {harveyner_area_examples} 25 | 26 | 27 | class Road(Entity): 28 | """{harveyner_road}""" 29 | 30 | 
span: str # {harveyner_road_examples} 31 | 32 | 33 | class River(Entity): 34 | """{harveyner_river}""" 35 | 36 | span: str # {harveyner_river_examples} 37 | 38 | 39 | ENTITY_DEFINITIONS: List[Entity] = [ 40 | Point, 41 | Area, 42 | Road, 43 | River, 44 | ] 45 | 46 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 47 | -------------------------------------------------------------------------------- /src/tasks/harveyner/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.harveyner.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class HarveyNEREntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/mitmovie/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/mitmovie/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | In the absence of public guidelines, the guidelines have been defined by the CoLLIE authors. 
9 | Dataset available at: https://groups.csail.mit.edu/sls/downloads/movie/ 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Actor(Entity): 16 | """{mit_actor}""" 17 | 18 | span: str # {mit_actor_examples} 19 | 20 | 21 | @dataclass 22 | class Character(Entity): 23 | """{mit_character}""" 24 | 25 | span: str # {mit_character_examples} 26 | 27 | 28 | @dataclass 29 | class Director(Entity): 30 | """{mit_director}""" 31 | 32 | span: str # {mit_director_examples} 33 | 34 | 35 | @dataclass 36 | class Genre(Entity): 37 | """{mit_genre}""" 38 | 39 | span: str # {mit_genre_examples} 40 | 41 | 42 | @dataclass 43 | class Plot(Entity): 44 | """{mit_plot}""" 45 | 46 | span: str # {mit_plot_examples} 47 | 48 | 49 | @dataclass 50 | class Rating(Entity): 51 | """{mit_rating}""" 52 | 53 | span: str # {mit_rating_examples} 54 | 55 | 56 | @dataclass 57 | class RatingsAverage(Entity): 58 | """{mit_ratings_average}""" 59 | 60 | span: str # {mit_ratings_average_examples} 61 | 62 | 63 | @dataclass 64 | class Review(Entity): 65 | """{mit_review}""" 66 | 67 | span: str # {mit_review_examples} 68 | 69 | 70 | @dataclass 71 | class Song(Entity): 72 | """{mit_song}""" 73 | 74 | span: str # {mit_song_examples} 75 | 76 | 77 | @dataclass 78 | class Tittle(Entity): 79 | """{mit_title}""" 80 | 81 | span: str # {mit_title_examples} 82 | 83 | 84 | @dataclass 85 | class Trailer(Entity): 86 | """{mit_trailer}""" 87 | 88 | span: str # {mit_trailer_examples} 89 | 90 | 91 | @dataclass 92 | class Year(Entity): 93 | """{mit_year}""" 94 | 95 | span: str # {mit_year_examples} 96 | 97 | 98 | ENTITY_DEFINITIONS: List[Entity] = [ 99 | Actor, 100 | Character, 101 | Director, 102 | Genre, 103 | Plot, 104 | Rating, 105 | RatingsAverage, 106 | Review, 107 | Song, 108 | Tittle, 109 | Trailer, 110 | Year, 111 | ] 112 | 113 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 114 | -------------------------------------------------------------------------------- /src/tasks/mitmovie/scorer.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.mitmovie.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class MitMovieEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/mitrestaurant/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/mitrestaurant/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | In the absence of public guidelines, the guidelines have been defined by the CoLLIE authors. 
9 | Dataset available at: https://groups.csail.mit.edu/sls/downloads/restaurant/ 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Rating(Entity): 16 | """{mit_rating}""" 17 | 18 | span: str # {mit_rating_examples} 19 | 20 | 21 | @dataclass 22 | class Amenity(Entity): 23 | """{mit_amenity}""" 24 | 25 | span: str # {mit_amenity_examples} 26 | 27 | 28 | @dataclass 29 | class Location(Entity): 30 | """{mit_location}""" 31 | 32 | span: str # {mit_location_examples} 33 | 34 | 35 | @dataclass 36 | class RestaurantName(Entity): 37 | """{mit_restaurantname}""" 38 | 39 | span: str # {mit_restaurantname_examples} 40 | 41 | 42 | @dataclass 43 | class Price(Entity): 44 | """{mit_price}""" 45 | 46 | span: str # {mit_price_examples} 47 | 48 | 49 | @dataclass 50 | class Hours(Entity): 51 | """{mit_hours}""" 52 | 53 | span: str # {mit_hours_examples} 54 | 55 | 56 | @dataclass 57 | class Dish(Entity): 58 | """{mit_dish}""" 59 | 60 | span: str # {mit_dish_examples} 61 | 62 | 63 | @dataclass 64 | class Cuisine(Entity): 65 | """{mit_cuisine}""" 66 | 67 | span: str # {mit_cuisine_examples} 68 | 69 | 70 | ENTITY_DEFINITIONS: List[Entity] = [Rating, Amenity, Location, RestaurantName, Price, Hours, Dish, Cuisine] 71 | 72 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 73 | -------------------------------------------------------------------------------- /src/tasks/mitrestaurant/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.mitrestaurant.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class MitRestaurantEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = 
super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/multinerd/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/multinerd/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "multinerd_person": {"en": ["People."]}, 3 | "multinerd_organization": { 4 | "en": ["Associations, companies, agencies, institutions, nationalities and religious or political groups."] 5 | }, 6 | "multinerd_location": { 7 | "en": [ 8 | "Physical locations (e.g. mountains, bodies of water), geopolitical entities (e.g. cities, states), and" 9 | " facilities (e.g. bridges, buildings, airports)" 10 | ] 11 | }, 12 | "multinerd_animal": {"en": ["Breeds of dogs, cats and other animals, including their scientific names."]}, 13 | "multinerd_biological": { 14 | "en": ["Genus of fungus, bacteria and protoctists, families of viruses, and other biological entities."] 15 | }, 16 | "multinerd_celestial": { 17 | "en": ["Planets, stars, asteroids, comets, nebulae, galaxies and other astronomical objects."] 18 | }, 19 | "multinerd_disease": { 20 | "en": [ 21 | "Physical, mental, infectious, non-infectious, deficiency, inherited, degenerative, social and" 22 | " self-inflicted diseases." 
23 | ] 24 | }, 25 | "multinerd_event": {"en": ["Sport events, battles, wars and other events."]}, 26 | "multinerd_food": {"en": ["Foods and drinks."]}, 27 | "multinerd_instrument": { 28 | "en": ["Technological instruments, mechanical instruments, musical instruments, and other tools."] 29 | }, 30 | "multinerd_media": { 31 | "en": ["Titles of films, books, magazines, songs and albums, fictional characters and languages."] 32 | }, 33 | "multinerd_plant": {"en": ["Types of trees, flowers, and other plants, including their scientific names."]}, 34 | "multinerd_mythological": {"en": ["Mythological and religious entities."]}, 35 | "multinerd_time": { 36 | "en": [ 37 | "Specific and well-defined time intervals, such as eras, historical periods, centuries, years and" 38 | " important days. No months and days of the week." 39 | ] 40 | }, 41 | "multinerd_vehicle": {"en": ["Cars, motorcycles and other vehicles."]}, 42 | } 43 | -------------------------------------------------------------------------------- /src/tasks/multinerd/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official Multi-NERD guidelines: 9 | https://aclanthology.org/2022.findings-naacl.60.pdf 10 | 11 | """ 12 | 13 | 14 | @dataclass 15 | class Person(Entity): 16 | """{multinerd_person}""" 17 | 18 | span: str # {multinerd_person_examples} 19 | 20 | 21 | @dataclass 22 | class Organization(Entity): 23 | """{multinerd_organization}""" 24 | 25 | span: str # {multinerd_organization_examples} 26 | 27 | 28 | @dataclass 29 | class Location(Entity): 30 | """{multinerd_location}""" 31 | 32 | span: str # {multinerd_location_examples} 33 | 34 | 35 | @dataclass 36 | class Animal(Entity): 37 | """{multinerd_animal}""" 38 | 39 | span: str # {multinerd_animal_examples} 40 | 41 | 42 | @dataclass 43 | class 
Biological(Entity): 44 | """{multinerd_biological}""" 45 | 46 | span: str # {multinerd_biological_examples} 47 | 48 | 49 | @dataclass 50 | class Celestial(Entity): 51 | """{multinerd_celestial}""" 52 | 53 | span: str # {multinerd_celestial_examples} 54 | 55 | 56 | @dataclass 57 | class Disease(Entity): 58 | """{multinerd_disease}""" 59 | 60 | span: str # {multinerd_disease_examples} 61 | 62 | 63 | @dataclass 64 | class Event(Entity): 65 | """{multinerd_event}""" 66 | 67 | span: str # {multinerd_event_examples} 68 | 69 | 70 | @dataclass 71 | class Food(Entity): 72 | """{multinerd_food}""" 73 | 74 | span: str # {multinerd_food_examples} 75 | 76 | 77 | @dataclass 78 | class Instrument(Entity): 79 | """{multinerd_instrument}""" 80 | 81 | span: str # {multinerd_instrument_examples} 82 | 83 | 84 | @dataclass 85 | class Media(Entity): 86 | """{multinerd_media}""" 87 | 88 | span: str # {multinerd_media_examples} 89 | 90 | 91 | @dataclass 92 | class Plant(Entity): 93 | """{multinerd_plant}""" 94 | 95 | span: str # {multinerd_plant_examples} 96 | 97 | 98 | @dataclass 99 | class Mythological(Entity): 100 | """{multinerd_mythological}""" 101 | 102 | span: str # {multinerd_mythological_examples} 103 | 104 | 105 | @dataclass 106 | class Time(Entity): 107 | """{multinerd_time}""" 108 | 109 | span: str # {multinerd_time_examples} 110 | 111 | 112 | @dataclass 113 | class Vehicle(Entity): 114 | """{multinerd_vehicle}""" 115 | 116 | span: str # {multinerd_vehicle_examples} 117 | 118 | 119 | ENTITY_DEFINITIONS: List[Entity] = [ 120 | Person, 121 | Location, 122 | Organization, 123 | Animal, 124 | Biological, 125 | Celestial, 126 | Disease, 127 | Event, 128 | Food, 129 | Instrument, 130 | Media, 131 | Plant, 132 | Mythological, 133 | Time, 134 | Vehicle, 135 | ] 136 | 137 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 138 | -------------------------------------------------------------------------------- /src/tasks/multinerd/scorer.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.multinerd.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class MultinerdEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/ncbidisease/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/ncbidisease/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "ncbi_disease": { 3 | "en": [ 4 | ( 5 | "A disease is a particular abnormal condition that negatively affects the structure or function of all" 6 | " or part of an \norganism, and that is not immediately due to any external injury. Disease mentions," 7 | " include Specific Disease (i.e., \nclear-cell renal cell carcinoma), Disease Class (i.e., cystic" 8 | " kidney diseases), Composite mentions (i.e., prostatic, \npancreas, skin, and lung cancer), and" 9 | " Modifier (i.e., hereditary breast cancer families)." 10 | ), 11 | ( 12 | "A disease refers to an unusual physiological state that detrimentally impacts the structure or" 13 | " function of all or \npart of an organism, and which isn't immediately caused by any external injury." 
14 | " Disease-related mentions encompass \nSpecific Disease (for example, clear-cell renal cell" 15 | " carcinoma), Disease Class (such as cystic kidney diseases), \nComposite mentions (like prostatic," 16 | " pancreas, skin, and lung cancer), and Modifier (for instance, hereditary breast cancer \nfamilies)." 17 | ), 18 | ( 19 | "A disease refers to a distinct anomalous state that detrimentally impacts the structure or function" 20 | " of an entire \norganism, not resulting from any immediate external injury. Disease mentions" 21 | " encompass Specific Disease (for example, \nclear-cell renal cell carcinoma), Disease Class (such as" 22 | " cystic kidney diseases), Composite mentions (like prostatic, \npancreas, skin, and lung cancer), and" 23 | " Modifier (for instance, hereditary breast cancer families)." 24 | ), 25 | ( 26 | "A disease refers to a distinct irregular condition that unfavorably impacts the structure or function" 27 | " of all or part \nof an organism, and is not immediately attributable to any external trauma. Disease" 28 | " mentions encompass Specific \nDisease (for example, clear-cell renal cell carcinoma), Disease Class" 29 | " (for instance, cystic kidney diseases), Composite \nmentions (such as prostatic, pancreas, skin, and" 30 | " lung cancer), and Modifier (for example, hereditary breast cancer \nfamilies)." 31 | ), 32 | ( 33 | "A disease refers to a distinct irregularity that unfavorably impacts the structure or function of an" 34 | " entire \norganism or a portion of it, and is not immediately attributable to any external injury." 35 | " Disease mentions encompass Specific \nDisease (for example, clear-cell renal cell carcinoma)," 36 | " Disease Class (such as cystic kidney diseases), Composite \nmentions (like prostatic, pancreas," 37 | " skin, and lung cancer), and Modifier (like hereditary breast cancer families)." 
38 | ), 39 | ] 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/tasks/ncbidisease/guidelines_gold.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "ncbi_disease": { 3 | "en": [ 4 | "A disease is a particular abnormal condition that negatively affects the structure or function of all or" 5 | " part of an organism, and that is not immediately due to any external injury. Disease mentions, include" 6 | " Specific Disease (i.e., clear-cell renal cell carcinoma), Disease Class (i.e., cystic kidney diseases)," 7 | " Composite mentions (i.e., prostatic, pancreas, skin, and lung cancer), and Modifier (i.e., hereditary" 8 | " breast cancer families). " 9 | ] 10 | }, 11 | } 12 | 13 | EXAMPLES = { 14 | "ncbi_disease_examples": { 15 | "en": ["DM", "DMD", "APC", "ALD", "PWS", "WAS", "myotonic dystrophy", "G6PD deficiency", "HD", "PKU"] 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/tasks/ncbidisease/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official NCBI-Disease guidelines: 9 | https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3951655/ 10 | Disease definition from: https://en.wikipedia.org/wiki/Disease which is based on the definitions from: 11 | [1] "Disease" at Dorland's Medical Dictionary 12 | [2] White, Tim (19 December 2014). "What is the Difference Between an "Injury" and "Disease" for Comcare Commonwealth 13 | Compensation Claims?". Tindall Gask Bentley. Archived from the original on 27 October 2017. Retrieved 6 November 2017. 
14 | 15 | 16 | """ 17 | 18 | 19 | @dataclass 20 | class Disease(Entity): 21 | """{ncbi_disease}""" 22 | 23 | span: str # {ncbi_disease_examples} 24 | 25 | 26 | ENTITY_DEFINITIONS: List[Entity] = [Disease] 27 | 28 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 29 | -------------------------------------------------------------------------------- /src/tasks/ncbidisease/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.ncbidisease.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class NcbiDiseaseEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/ner_tasks.csv: -------------------------------------------------------------------------------- 1 | Dataset, Implemented 2 | Conll, Yes 3 | Ncbi-disease, Yes 4 | bc5cdr, Yes 5 | diann, Yes 6 | E3C 7 | OntonotesV5, Yes 8 | WNUT 2017, Yes 9 | Open Entity, Too many entities 10 | Few-NERD 11 | multiNERD 12 | Multiconer1 13 | Multiconer2 14 | FabNER 15 | MIT Movie Review 16 | MIT Restaurant Review 17 | polyglot-NER 18 | broad_twitter_corpus 19 | CrossNER 20 | 21 | -------------------------------------------------------------------------------- /src/tasks/ontonotes/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/ontonotes/guidelines.py: -------------------------------------------------------------------------------- 1 | GUIDELINES = { 2 | "ontonotes_person": {"en": ["People, including fictional."]}, 3 | "ontonotes_norp": {"en": ["Nationalities or religious or political groups."]}, 4 | "ontonotes_facility": {"en": ["Buildings, airports, highways, bridges, etc."]}, 5 | "ontonotes_organization": {"en": ["Companies, agencies, institutions, etc."]}, 6 | "ontonotes_gpe": {"en": ["Countries, cities, states."]}, 7 | "ontonotes_location": {"en": ["Non-GPE locations, mountain ranges, bodies of water."]}, 8 | "ontonotes_product": {"en": ["Objects, vehicles, foods, etc. (Not services)."]}, 9 | "ontonotes_event": {"en": ["Named hurricanes, battles, wars, sports events, etc."]}, 10 | "ontonotes_work_of_art": {"en": ["Titles of books, songs, etc."]}, 11 | "ontonotes_law": {"en": ["Named documents made into laws."]}, 12 | "ontonotes_language": {"en": ["Any named language."]}, 13 | "ontonotes_date": {"en": ["Absolute or relative dates or periods."]}, 14 | "ontonotes_time": {"en": ["Times smaller than a day."]}, 15 | "ontonotes_percent": {"en": ["Percentage, including ”%“."]}, 16 | "ontonotes_money": {"en": ["Monetary values, including unit."]}, 17 | "ontonotes_quantity": {"en": ["Measurements, as of weight or distance."]}, 18 | "ontonotes_ordinal": {"en": ["first, second, third, First, fourth, fifth, Second, seventh, eighth, sixth."]}, 19 | "ontonotes_cardinal": {"en": ["two, one, three, One, four, five, six, seven, Two, half."]}, 20 | } 21 | -------------------------------------------------------------------------------- /src/tasks/ontonotes/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from 
src.tasks.ontonotes.prompts import ENTITY_DEFINITIONS 4 | from src.tasks.utils_scorer import SpanScorer 5 | from src.tasks.utils_typing import Entity 6 | 7 | 8 | class OntoNotesEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tasks/rams/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/rams/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type 2 | 3 | from src.tasks.rams.prompts import EVENT_DEFINITIONS 4 | from src.tasks.utils_scorer import EventScorer 5 | 6 | 7 | class RAMSEventScorer(EventScorer): 8 | """RAMS Argument classification scorer.""" 9 | 10 | valid_types: List[Type] = EVENT_DEFINITIONS 11 | -------------------------------------------------------------------------------- /src/tasks/tacred/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/tacred/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type 2 | 3 | from src.tasks.tacred.prompts import TEMPLATE_DEFINITIONS 4 | from src.tasks.utils_scorer import TemplateScorer 5 | 6 | 7 | class TACREDTemplateScorer(TemplateScorer): 8 | """TACRED Template scorer.""" 9 | 10 | valid_types: List[Type] = TEMPLATE_DEFINITIONS 11 | -------------------------------------------------------------------------------- /src/tasks/wikievents/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/wikievents/scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.utils_scorer import EventScorer, SpanScorer 4 | from src.tasks.utils_typing import Entity 5 | from src.tasks.wikievents.prompts import ( 6 | COARSE_EVENT_DEFINITIONS, 7 | ENTITY_DEFINITIONS, 8 | EVENT_DEFINITIONS, 9 | ) 10 | 11 | 12 | class WikiEventsEntityScorer(SpanScorer): 13 | """WikiEvents Entity identification and classification scorer.""" 14 | 15 | valid_types: List[Type] = ENTITY_DEFINITIONS 16 | 17 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 18 | output = super().__call__(reference, predictions) 19 | return {"entities": output["spans"]} 20 | 21 | 22 | class WikiEventsEventScorer(EventScorer): 23 | """WikiEvents Event and argument classification scorer.""" 24 | 25 | valid_types: List[Type] = COARSE_EVENT_DEFINITIONS 26 | 27 | 28 | class 
WikiEventsEventArgumentScorer(EventScorer): 29 | """WikiEvents Event and argument classification scorer.""" 30 | 31 | valid_types: List[Type] = EVENT_DEFINITIONS 32 | -------------------------------------------------------------------------------- /src/tasks/wnut/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data_loader, prompts, scorer 2 | 3 | 4 | __all__ = ["data_loader", "prompts", "scorer"] 5 | -------------------------------------------------------------------------------- /src/tasks/wnut/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ..utils_typing import Entity, dataclass 4 | 5 | 6 | """Entity definitions 7 | 8 | The entity definitions are derived from the official WNUT17 guidelines: 9 | https://aclanthology.org/W17-4418.pdf 10 | 11 | 12 | """ 13 | 14 | 15 | @dataclass 16 | class Person(Entity): 17 | """{wnut_person}""" 18 | 19 | span: str # {wnut_person_examples} 20 | 21 | 22 | @dataclass 23 | class Location(Entity): 24 | """{wnut_location}""" 25 | 26 | span: str # {wnut_location_examples} 27 | 28 | 29 | @dataclass 30 | class Corporation(Entity): 31 | """{wnut_corporation}""" 32 | 33 | span: str # {wnut_corporation_examples} 34 | 35 | 36 | @dataclass 37 | class Product(Entity): 38 | """{wnut_product}""" 39 | 40 | span: str # {wnut_product_examples} 41 | 42 | 43 | @dataclass 44 | class CreativeWork(Entity): 45 | """{wnut_creativework}""" 46 | 47 | span: str # {wnut_creativework_examples} 48 | 49 | 50 | @dataclass 51 | class Group(Entity): 52 | """{wnut_group}""" 53 | 54 | span: str # {wnut_group_examples} 55 | 56 | 57 | ENTITY_DEFINITIONS: List[Entity] = [Person, Location, Corporation, Product, CreativeWork, Group] 58 | 59 | # __all__ = list(map(str, [*ENTITY_DEFINITIONS])) 60 | -------------------------------------------------------------------------------- /src/tasks/wnut/scorer.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | from src.tasks.utils_scorer import SpanScorer 4 | from src.tasks.utils_typing import Entity 5 | from src.tasks.wnut.prompts import ENTITY_DEFINITIONS 6 | 7 | 8 | class WnutEntityScorer(SpanScorer): 9 | """CoNLL03 Entity identification and classification scorer.""" 10 | 11 | valid_types: List[Type] = ENTITY_DEFINITIONS 12 | 13 | def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]: 14 | output = super().__call__(reference, predictions) 15 | return {"entities": output["spans"]} 16 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitz-zentroa/GoLLIE/164c611743fdc1befe71bbdf03e08c5eb4e35957/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/test_prompts.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestEntityPrompts(unittest.TestCase): 5 | def test_ACE(self): 6 | pass 7 | -------------------------------------------------------------------------------- /templates/prompt.txt: -------------------------------------------------------------------------------- 1 | # The following lines describe the task definition 2 | {%- for definition in guidelines %} 3 | {{ definition }} 4 | {%- endfor %} 5 | 6 | # This is the text to analyze 7 | text = {{ text.__repr__() }} 8 | 9 | # The annotation instances that take place in the text above are listed here 10 | result = [ 11 | {%- for ann in annotations %} 12 | {{ ann }}, 13 | {%- endfor %} 14 | ] 15 | -------------------------------------------------------------------------------- /templates/prompt_ace_eae.txt: 
-------------------------------------------------------------------------------- 1 | # The following lines describe the task definition 2 | {%- for definition in guidelines %} 3 | {{ definition }} 4 | {%- endfor %} 5 | 6 | # This is the text to analyze 7 | text = {{ text.__repr__() }} 8 | 9 | # The list called result contains the instances for the following events according to the guidelines above: 10 | {%- for ann in gold %} 11 | # - "{{ann.mention}}" triggers a {{ann.__class__.__name__}} event. 12 | {%- endfor %} 13 | # 14 | result = [ 15 | {%- for ann in annotations %} 16 | {{ ann }}, 17 | {%- endfor %} 18 | ] 19 | -------------------------------------------------------------------------------- /templates/prompt_ace_rc.txt: -------------------------------------------------------------------------------- 1 | # The following lines describe the task definition 2 | {%- for definition in guidelines %} 3 | {{ definition }} 4 | {%- endfor %} 5 | 6 | # This is the text to analyze 7 | text = {{ text.__repr__() }} 8 | 9 | # The list called result contains the fine-grained relations for the following coarse-grained relations: 10 | {%- for ann in gold %} 11 | # - {{ ann.__repr__() }} 12 | {%- endfor %} 13 | # 14 | result = [ 15 | {%- for ann in annotations %} 16 | {{ ann }}, 17 | {%- endfor %} 18 | ] 19 | -------------------------------------------------------------------------------- /templates/prompt_ace_re.txt: -------------------------------------------------------------------------------- 1 | # The following lines describe the task definition 2 | {%- for definition in guidelines %} 3 | {{ definition }} 4 | {%- endfor %} 5 | 6 | # This is the text to analyze 7 | text = {{ text.__repr__() }} 8 | 9 | # The list called result contains the relation annotations for the following entities: 10 | {%- for ann in gold %} 11 | # - "{{ann.span}}": {{ann.__class__.__name__}} 12 | {%- endfor %} 13 | # 14 | result = [ 15 | {%- for ann in annotations %} 16 | {{ ann }}, 17 | {%- endfor 
%} 18 | ] 19 | -------------------------------------------------------------------------------- /templates/prompt_eae.txt: -------------------------------------------------------------------------------- 1 | # The following lines describe the task definition 2 | {%- for definition in guidelines %} 3 | {{ definition }} 4 | {%- endfor %} 5 | 6 | # This is the text to analyze 7 | text = {{ text.__repr__() }} 8 | 9 | # The list called result contains the instances for the following events: 10 | {%- for ann in annotations %} 11 | # - "{{ann.mention}}" triggers a {{ann.__class__.__name__}}.{{ann.subtype}} event. 12 | {%- endfor %} 13 | # 14 | result = [ 15 | {%- for ann in annotations %} 16 | {{ ann }}, 17 | {%- endfor %} 18 | ] 19 | -------------------------------------------------------------------------------- /templates/prompt_tacred.txt: -------------------------------------------------------------------------------- 1 | # The following lines describe the task definition 2 | {%- for definition in guidelines %} 3 | {{ definition }} 4 | {%- endfor %} 5 | 6 | # This is the text to analyze 7 | text = {{ text.__repr__() }} 8 | 9 | # The list called result contains the templates instances for the following entity queries: 10 | {%- for ann in gold %} 11 | # - {{ann.query}}: {{ann.__class__.__name__}} 12 | {%- endfor %} 13 | # 14 | result = [ 15 | {%- for ann in annotations %} 16 | {{ ann }}, 17 | {%- endfor %} 18 | ] 19 | --------------------------------------------------------------------------------