├── .dockerignore ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── assets └── figures │ └── diagram.png ├── configs ├── callbacks │ ├── default.yaml │ ├── none.yaml │ └── sample_generator_wandb.yaml ├── datamodule │ ├── mcts.yaml │ ├── out_ds_rebel.yaml │ ├── out_ds_rtp.yaml │ ├── out_ds_swissprot.yaml │ ├── out_ds_wmt14.yaml │ ├── rebel.yaml │ ├── resample.yaml │ ├── rtp.yaml │ ├── swissprot.yaml │ └── wmt14.yaml ├── debug │ ├── default.yaml │ ├── default_run.yaml │ ├── fast.yaml │ ├── fast_run.yaml │ ├── limited.yaml │ ├── limited_run.yaml │ ├── overfit.yaml │ └── overfit_run.yaml ├── evaluate_from_file_root.yaml ├── evaluate_root.yaml ├── evaluation │ ├── genie_cie_large.yaml │ ├── genie_cie_large_resample.yaml │ ├── genie_cie_small.yaml │ ├── gpt2_toxicgen.yaml │ ├── gpt2_toxicgen_resample.yaml │ ├── mbart_translation.yaml │ ├── mbart_translation_resample.yaml │ └── protgpt2_design.yaml ├── evaluation_from_file │ ├── cie.yaml │ ├── resample.yaml │ ├── solubility.yaml │ ├── toxicity.yaml │ ├── toxicity_bootstrap.yaml │ ├── translation.yaml │ └── translation_bootstrap.yaml ├── evaluation_model │ ├── cie_noisy_oracle.yaml │ ├── cie_oracle.yaml │ ├── detoxify.yaml │ ├── detoxify_noisy_oracle.yaml │ ├── detoxify_noisy_oracle_rmse_01976.yaml │ ├── detoxify_noisy_oracle_rmse_01991.yaml │ ├── detoxify_noisy_oracle_rmse_02165.yaml │ ├── detoxify_noisy_oracle_rmse_02242.yaml │ ├── detoxify_oracle.yaml │ ├── likelihood.yaml │ ├── mt_noisy_oracle.yaml │ ├── mt_noisy_oracle_b2.yaml │ ├── mt_noisy_oracle_b3.yaml │ ├── mt_oracle.yaml │ ├── perfect_oracle.yaml │ ├── solubility.yaml │ ├── solubility_noisy_oracle.yaml │ └── solubility_oracle.yaml ├── hparam_search │ ├── base_hparam_search.yaml │ ├── launcher │ │ └── local_launcher.yaml │ └── search │ │ └── grid_search.yaml ├── hparam_search_root.yaml ├── hydra │ ├── debug.yaml │ ├── evaluation.yaml │ ├── hparam_search.yaml │ └── training.yaml ├── logger │ ├── python_logger.yaml │ ├── wandb.yaml │ ├── wandb_group.yaml │ └── wandb_standalone.yaml ├── mdp │ └── search_tree_factory.yaml ├── metric │ ├── bleu-1.yaml │ ├── bleu-4.yaml │ ├── detoxify.yaml │ ├── protein_soluble.yaml │ ├── triplet_set_f1.yaml │ ├── triplet_set_precision.yaml │ └── triplet_set_recall.yaml ├── model │ ├── additional_logger │ │ └── reconstruction_logger.yaml │ ├── ckpt_generic.yaml │ ├── ckpt_genie_genre_r.yaml │ ├── collator │ │ └── rebel_collator.yaml │ ├── decoding │ │ ├── beam_sampling.yaml │ │ ├── beam_sampling_top_p.yaml │ │ ├── beam_search.yaml │ │ ├── genie_generic.yaml │ │ ├── gpt_generic.yaml │ │ ├── greedy_search.yaml │ │ ├── mbart_generic.yaml │ │ ├── mcts.yaml │ │ ├── mcts_qtransform │ │ │ ├── by_min_max.yaml │ │ │ ├── by_parent_and_siblings.yaml │ │ │ ├── completed_by_mix_value.yaml │ │ │ └── noop.yaml │ │ ├── pplmcts.yaml │ │ ├── protgpt2_generic.yaml │ │ ├── stochastic_beams.yaml │ │ └── value_guided_beam_search.yaml │ ├── genie_base.yaml │ ├── gpt2_lm.yaml │ ├── mbart_cd.yaml │ ├── optimizer │ │ ├── adam.yaml │ │ └── adam_w.yaml │ ├── protgpt2_lm.yaml │ ├── scheduler_config │ │ ├── linear.yaml │ │ ├── polynomial.yaml │ │ └── reduce_on_plateau.yaml │ └── tree_lm.yaml ├── testing │ ├── cie_task.yaml │ ├── stochastic_beams.yaml │ └── translation_task.yaml ├── train_root.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ └── gpu.yaml ├── data ├── prompts_sports │ ├── sports_understanding_8shot_direct_s0_all_probs │ ├── sports_understanding_8shot_direct_s0_ll │ ├── 
sports_understanding_8shot_direct_s0_predictions │ ├── sports_understanding_8shot_stream_s0_all_probs │ ├── sports_understanding_8shot_stream_s0_ll │ ├── sports_understanding_8shot_stream_s0_predictions │ ├── sports_understanding_8shot_stream_s0_targets │ ├── tnlgv2_zero-shot_sports.jsonl │ ├── tnlgv2_zero-shot_sports_all_probs │ ├── tnlgv2_zero-shot_sports_ll │ └── tnlgv2_zero-shot_sports_predictions ├── rebel │ ├── en_test_10k.jsonl │ └── en_test_15k.jsonl └── rtp │ └── test_10000_prompts.jsonl ├── detoxify ├── __init__.py ├── evaluate.py ├── main.py ├── src │ ├── __init__.py │ ├── data_loaders.py │ └── utils.py └── train.py ├── download_data.sh ├── notebooks └── experiment_logs.ipynb ├── pip_requirements.txt ├── requirements.yaml ├── run_detoxify_evaluation.py ├── run_evaluation.py ├── run_evaluation_from_file.py ├── run_hparam_search.py ├── run_training.py ├── scripts ├── __init__.py ├── get_analysis_tree.py ├── get_detoxify_evaluation_results.py ├── get_genie_ckpt.py ├── joint_plot.py ├── launch_evaluation_from_file.sh ├── launchers │ ├── greedy-beam_search-stochastic_beams.sh │ ├── mt_mcts_run1.sh │ ├── mt_mcts_run2.sh │ ├── mt_mcts_run3.sh │ ├── mt_mcts_run4.sh │ ├── mt_mcts_run5.sh │ ├── mt_mcts_run6.sh │ ├── mt_vgbs_run1.sh │ ├── mt_vgbs_run2.sh │ ├── mt_vgbs_run3.sh │ ├── mt_vgbs_run4.sh │ ├── mt_vgbs_run5.sh │ ├── mt_vgbs_run6.sh │ ├── toxicity_mcts_run1.sh │ ├── toxicity_mcts_run2.sh │ ├── toxicity_mcts_run3.sh │ ├── toxicity_mcts_run4.sh │ ├── toxicity_vgbs_run1.sh │ ├── toxicity_vgbs_run2.sh │ ├── toxicity_vgbs_run3.sh │ └── toxicity_vgbs_run4.sh ├── plotting │ ├── figure_2_RQ1.sh │ ├── figure_2_RQ1_unnormalized.sh │ ├── figure_3_RQ1.sh │ ├── figure_4_RQ2.sh │ └── figure_5_RQ3.sh ├── visualize_figure4.py ├── visualize_figure5.py └── visualize_joint_plots.py └── src ├── __init__.py ├── constrained_generation ├── __init__.py ├── ie_prefix_constraints.py └── trie.py ├── datamodules ├── __init__.py ├── abstract.py ├── helper_data_classes.py ├── rebel.py ├── rtp.py ├── swissprot.py ├── utils │ ├── __init__.py │ ├── resampling_utils.py │ └── triplet_utils.py └── wmt14.py ├── evaluation_from_file_pipeline.py ├── evaluation_pipeline.py ├── hparam_search ├── __init__.py ├── launchers │ ├── __init__.py │ ├── base_launcher.py │ └── local_launcher.py └── search │ ├── __init__.py │ ├── base_search.py │ └── grid_search.py ├── hparam_search_pipeline.py ├── loggers ├── __init__.py └── reconstruction_logger.py ├── mdp ├── __init__.py ├── evaluation_models │ ├── __init__.py │ ├── abstract.py │ ├── detoxify_evaluation_model.py │ ├── explicit_evaluation_model.py │ ├── likelihood_evaluation_model.py │ ├── noisy_oracle_detoxify.py │ ├── noisy_oracle_genie.py │ ├── noisy_oracle_mt.py │ ├── noisy_oracle_solubility.py │ ├── oracle_detoxify.py │ ├── oracle_genie.py │ ├── oracle_mt.py │ ├── oracle_solubility.py │ ├── solubility_evaluation_model.py │ └── utils.py ├── oracles │ ├── __init__.py │ └── perfect_oracle.py ├── search_tree_factory.py ├── transition_model.py ├── tree.py ├── tree_lm.py ├── trees │ ├── __init__.py │ ├── abstract.py │ ├── analysis_tree.py │ ├── explicit_tree.py │ └── lm_as_tree.py ├── utility_value_model.py └── utils.py ├── metrics ├── __init__.py ├── bleu.py ├── detoxify.py ├── protbert.py ├── triplet_set_f1.py ├── triplet_set_precision.py └── triplet_set_recall.py ├── models ├── __init__.py ├── collators │ ├── __init__.py │ └── rebel_collator.py ├── genie.py ├── gpt2_for_generation.py ├── mbart_for_conditional_generation.py ├── mcts_mixin.py ├── protgpt2_for_generation.py └── 
utils │ ├── __init__.py │ └── general.py ├── training_pipeline.py └── utils ├── __init__.py ├── additional_loggers.py ├── evaluation.py ├── general.py ├── hparam_search_utils.py ├── hydra_ad_hoc.py ├── hydra_custom_resolvers.py ├── oracles.py ├── score_helpers.py └── scorers.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | ** 3 | 4 | # Allow files and directories 5 | !pip_requirements.txt 6 | !requirements.yaml 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### macOS ### 2 | # General 3 | .DS_Store 4 | .AppleDouble 5 | .LSOverride 6 | 7 | # Icon must end with two \r 8 | Icon 9 | 10 | # Thumbnails 11 | ._* 12 | 13 | # Files that might appear in the root of a volume 14 | .DocumentRevisions-V100 15 | .fseventsd 16 | .Spotlight-V100 17 | .TemporaryItems 18 | .Trashes 19 | .VolumeIcon.icns 20 | .com.apple.timemachine.donotpresent 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | env/ 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit tests / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | .hypothesis/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | docs/build/ 95 | docs/docs/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # dotenv 113 | .env 114 | 115 | # virtualenv 116 | .venv 117 | venv/ 118 | ENV/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | 133 | # data folder 134 | data/* 135 | !data/rtp/test_10000_prompts.jsonl 136 | !data/prompts* 137 | 138 | # datafiles 139 | .xml 140 | .pkl 141 | 142 | # misc 143 | .idea/ 144 | .iml 145 | .dropbox 146 | 147 | # media files 148 | .png 149 | .jpg 150 | .pdf 151 | 152 | # logs folder 153 | logs/ 154 | wandb/ 155 | outputs/ 156 | 157 | # dev files 158 | notebooks/exploration.ipynb 159 | notebooks/logs/ 160 | test.ipynb 161 | unused_code.py 162 | 163 | logs/ 164 | plots/ 165 | configs/test_root.yaml 166 | temp* 167 | wandb/ 168 | scripts/*.json 169 | config_tree.txt 170 | /tests/debug_mcts_trees/ 171 | images/ 172 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mctx"] 2 | path = mctx 3 | url = https://github.com/epfl-dlab/mctx.git 4 | [submodule "transformers"] 5 | path = transformers 6 | url = https://github.com/epfl-dlab/transformers.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.4.0 4 | hooks: 5 | # list of supported hooks: https://pre-commit.com/hooks.html 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-json 10 | - id: check-added-large-files 11 | - id: detect-private-key 12 | 13 | # Black Python code formatting 14 | - repo: https://github.com/psf/black 15 | rev: 20.8b1 16 | hooks: 17 | - id: black 18 | args: [--line-length, "120"] 19 | 20 | # Yaml formatting 21 | - repo: https://github.com/pre-commit/mirrors-prettier 22 | rev: v2.3.0 23 | hooks: 24 | - id: prettier 25 | types: [yaml] 26 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ### TODO: Make sure that the cuda driver's version matches (also below) 2 | ### TODO: Make sure you have your .dockerignore file 3 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 4 | 5 | ### Use bash as the default shell 6 | RUN chsh -s /bin/bash 7 | SHELL ["bash", "-c"] 8 | 9 | ### Add new Nvidia keys 10 | RUN apt-key adv --fetch-keys 
https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 11 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 12 | 13 | ### Install basics 14 | RUN apt-get update && \ 15 | apt-get install -y openssh-server sudo nano screen wget bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 git mercurial subversion unzip graphviz graphviz-dev && \ 16 | apt-get clean 17 | 18 | # Install miniconda 19 | ENV PATH /opt/conda/bin:$PATH 20 | 21 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh -O ~/miniconda.sh && \ 22 | mkdir ~/.conda && \ 23 | /bin/bash ~/miniconda.sh -b -p /opt/conda && \ 24 | rm ~/miniconda.sh && \ 25 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 26 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 27 | find /opt/conda/ -follow -type f -name '*.a' -delete && \ 28 | find /opt/conda/ -follow -type f -name '*.js.map' -delete && \ 29 | /opt/conda/bin/conda clean -afy 30 | 31 | ### Install Environment 32 | ENV envname understanding_decoding 33 | RUN env "PATH=$PATH" conda update conda && \ 34 | conda create --name $envname python=3.8 35 | 36 | ### Install a version of pytorch that is compatible with the installed cudatoolkit 37 | RUN conda install -n $envname pytorch=1.8.0 torchvision torchaudio cudatoolkit=10.1 -c pytorch 38 | COPY requirements.yaml /tmp/requirements.yaml 39 | RUN conda env update --name $envname --file /tmp/requirements.yaml --prune 40 | COPY pip_requirements.txt /tmp/pip_requirements.txt 41 | RUN conda run -n $envname pip install -r /tmp/pip_requirements.txt 42 | RUN conda install -n $envname pygraphviz==1.9 -c conda-forge 43 | RUN echo "conda activate $envname" >> ~/.bashrc && \ 44 | conda run -n $envname python -m ipykernel install --user --name=$envname 45 | 46 | ### Setup-for installing apex 47 | ### TODO: Make sure that the cuda driver's version matches 48 | ENV CUDAVER cuda-10.1 49 | ENV PATH /usr/local/$CUDAVER/bin:$PATH 50 | ENV LD_LIBRARY_PATH /usr/local/$CUDAVER/lib:$LD_LIBRARY_PATH 51 | ENV LD_LIBRARY_PATH /usr/local/$CUDAVER/lib64:$LD_LIBRARY_PATH 52 | ENV CUDA_PATH /usr/local/$CUDAVER 53 | ENV CUDA_ROOT /usr/local/$CUDAVER 54 | ENV CUDA_HOME /usr/local/$CUDAVER 55 | ENV CUDA_HOST_COMPILER /usr/bin/gcc 56 | 57 | RUN mkdir /tmp/unique_for_apex 58 | WORKDIR /tmp/unique_for_apex 59 | RUN git clone https://github.com/NVIDIA/apex.git 60 | WORKDIR /tmp/unique_for_apex/apex 61 | RUN /opt/conda/envs/$envname/bin/pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 62 | 63 | WORKDIR / 64 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 epfl-dlab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/__init__.py -------------------------------------------------------------------------------- /assets/figures/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/assets/figures/diagram.png -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/loss" # name of the logged metric which determines when model is improving 4 | mode: "min" # can be "max" or "min" 5 | save_top_k: 3 # save k best models (determined by above metric) 6 | save_last: True # additionally always save model from last epoch 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "model-{step:04d}-{val/loss:.4f}" 10 | save_on_train_epoch_end: False 11 | auto_insert_metric_name: False 12 | 13 | learning_rate_monitor: 14 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 15 | logging_interval: "step" 16 | 17 | early_stopping: 18 | _target_: pytorch_lightning.callbacks.EarlyStopping 19 | monitor: "val/loss" # name of the logged metric which determines when model is improving 20 | mode: "min" # can be "max" or "min" 21 | patience: 3 # how many validation epochs (note that this might differ from training epochs) of not improving until training stops 22 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 23 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/sample_generator_wandb.yaml: -------------------------------------------------------------------------------- 1 | sample_generator_wandb: 2 | _target_: src.callbacks.generate_samples.GenerateTextSamplesCallback 3 | logging_batch_interval: ${trainer.val_check_interval} # Invoke wandb logger for sample generation every m training batches. 
4 | -------------------------------------------------------------------------------- /configs/datamodule/mcts.yaml: -------------------------------------------------------------------------------- 1 | dataset_parameters: 2 | train: 3 | dataloader: 4 | batch_size: 1 5 | 6 | val: 7 | dataloader: 8 | batch_size: 1 9 | 10 | test: 11 | dataloader: 12 | batch_size: 1 13 | -------------------------------------------------------------------------------- /configs/datamodule/out_ds_rebel.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.RebelOutputDataset 2 | exp_dir: ${..exp_dir} 3 | -------------------------------------------------------------------------------- /configs/datamodule/out_ds_rtp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.RTPOutputDataset 2 | exp_dir: ${..exp_dir} 3 | -------------------------------------------------------------------------------- /configs/datamodule/out_ds_swissprot.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.ProteinOutputDataset 2 | exp_dir: ${..exp_dir} 3 | -------------------------------------------------------------------------------- /configs/datamodule/out_ds_wmt14.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.WMT14OutputDataset 2 | exp_dir: ${..exp_dir} 3 | -------------------------------------------------------------------------------- /configs/datamodule/rebel.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.RebelDataModule 2 | num_workers: 10 3 | pin_memory: false 4 | seed: ${seed} 5 | 6 | debug: false 7 | debug_k: 12 # needs to be specified when debug is True 8 | 9 | data_dir: ${data_dir}rebel 10 | 11 | # Parameter that controls the quality of the matching allowed in the data 12 | # title: allows only entities and relations with guaranteed unique textual identifiers coming from the title 13 | # label: additionally allows for labels which could potentially not be unique to be used as identifiers (see the section covering the ID2Name dictionaries in the Demo notebook for details) 14 | matching_status: "title" # title or label 15 | 16 | # Ignores any samples with at least one relation that is not inside the relations_to_keep.jsonl file 17 | relations_to_keep: ${data_dir}world_definitions/complete_relations.jsonl 18 | 19 | # Keeps all of the samples that pass the relations_to_keep filter, but ignores some of them 20 | # relations_to_ignore: # ignores the relations in this file 21 | # relation_not_to_ignore: # ignores all of the relations in the set relations_to_keep - relation_not_to_ignore 22 | 23 | dataset_parameters: 24 | train: 25 | dataset: 26 | _target_: src.datamodules.RebelDataset 27 | seed: ${datamodule.seed} 28 | debug: ${datamodule.debug} 29 | debug_k: ${datamodule.debug_k} 30 | load_dataset_params: 31 | split: "train" 32 | data_dir: ${datamodule.data_dir} 33 | matching_status: ${datamodule.matching_status} 34 | relations_to_keep: ${datamodule.relations_to_keep} 35 | 36 | dataloader: 37 | batch_size: 4 38 | num_workers: ${datamodule.num_workers} 39 | 40 | val: 41 | dataset: 42 | _target_: src.datamodules.RebelDataset 43 | seed: ${datamodule.seed} 44 | debug: ${datamodule.debug} 45 | debug_k: ${datamodule.debug_k} 46 | load_dataset_params: 47 | split: "val" 48 | data_dir: ${datamodule.data_dir} 49 | 
matching_status: ${datamodule.matching_status} 50 | relations_to_keep: ${datamodule.relations_to_keep} 51 | 52 | dataloader: 53 | batch_size: 4 54 | num_workers: ${datamodule.num_workers} 55 | 56 | test: 57 | dataset: 58 | _target_: src.datamodules.RebelDataset 59 | seed: ${datamodule.seed} 60 | debug: ${datamodule.debug} 61 | debug_k: ${datamodule.debug_k} 62 | load_dataset_params: 63 | split: "test_10k" 64 | data_dir: ${datamodule.data_dir} 65 | matching_status: ${datamodule.matching_status} 66 | relations_to_keep: ${datamodule.relations_to_keep} 67 | 68 | dataloader: 69 | batch_size: 4 70 | num_workers: ${datamodule.num_workers} 71 | -------------------------------------------------------------------------------- /configs/datamodule/resample.yaml: -------------------------------------------------------------------------------- 1 | wandb_run_path: ??? 2 | exp_dir: None 3 | num_qs: ??? 4 | 5 | dataset_parameters: 6 | test: 7 | dataset: 8 | resample_exp: true 9 | exp_dir: ${datamodule.exp_dir} 10 | num_qs: ${datamodule.num_qs} 11 | 12 | dataloader: 13 | batch_size: 1 14 | -------------------------------------------------------------------------------- /configs/datamodule/rtp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.RTPDataModule 2 | num_workers: 24 3 | seed: ${seed} # not used 4 | 5 | debug: true 6 | debug_k: 12 # needs to be specified when debug is True 7 | 8 | dataset_data_dir: ${data_dir}rtp/ 9 | dataset_parameters: 10 | train: 11 | dataset: 12 | _target_: src.datamodules.RTPDataset 13 | seed: ${datamodule.seed} 14 | debug: ${datamodule.debug} 15 | debug_k: ${datamodule.debug_k} 16 | dataset_data_dir: ${data_dir}rtp/ 17 | split: "train" 18 | 19 | dataloader: 20 | batch_size: 5 21 | num_workers: ${datamodule.num_workers} 22 | 23 | val: 24 | dataset: 25 | _target_: src.datamodules.RTPDataset 26 | seed: ${datamodule.seed} 27 | debug: ${datamodule.debug} 28 | debug_k: ${datamodule.debug_k} 29 | dataset_data_dir: ${data_dir}rtp/ 30 | split: "validation" 31 | 32 | dataloader: 33 | batch_size: 5 34 | num_workers: ${datamodule.num_workers} 35 | 36 | test: 37 | dataset: 38 | _target_: src.datamodules.RTPDataset 39 | seed: ${datamodule.seed} 40 | debug: ${datamodule.debug} 41 | debug_k: ${datamodule.debug_k} 42 | dataset_data_dir: ${data_dir}rtp/ 43 | split: "test" 44 | 45 | dataloader: 46 | batch_size: 4 47 | num_workers: ${datamodule.num_workers} 48 | -------------------------------------------------------------------------------- /configs/datamodule/swissprot.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.ProteinDataModule 2 | num_workers: 24 3 | seed: ${seed} # not used 4 | 5 | debug: true 6 | debug_k: 12 # needs to be specified when debug is True 7 | 8 | prompt_length: 50 9 | 10 | dataset_parameters: 11 | train: 12 | dataset: 13 | _target_: src.datamodules.ProteinDataset 14 | seed: ${datamodule.seed} 15 | debug: ${datamodule.debug} 16 | debug_k: ${datamodule.debug_k} 17 | split: "train" 18 | prompt_length: 50 19 | 20 | dataloader: 21 | batch_size: 5 22 | num_workers: ${datamodule.num_workers} 23 | 24 | val: 25 | dataset: 26 | _target_: src.datamodules.ProteinDataset 27 | seed: ${datamodule.seed} 28 | debug: ${datamodule.debug} 29 | debug_k: ${datamodule.debug_k} 30 | split: "validation" 31 | prompt_length: 50 32 | 33 | dataloader: 34 | batch_size: 5 35 | num_workers: ${datamodule.num_workers} 36 | 37 | test: 38 | dataset: 39 | _target_: 
src.datamodules.ProteinDataset 40 | seed: ${datamodule.seed} 41 | debug: ${datamodule.debug} 42 | debug_k: ${datamodule.debug_k} 43 | split: "test" 44 | prompt_length: 50 45 | 46 | dataloader: 47 | batch_size: 4 48 | num_workers: ${datamodule.num_workers} 49 | -------------------------------------------------------------------------------- /configs/datamodule/wmt14.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.WMT14DataModule 2 | num_workers: 24 3 | seed: ${seed} # controls the seed for the noisy oracles 4 | 5 | debug: false 6 | debug_k: 12 # needs to be specified when debug is True 7 | cache_dir: ${data_dir}wmt14 # The default is ~/.cache/huggingface/datasets, which has rather limited space 8 | 9 | dataset_parameters: 10 | train: 11 | dataset: 12 | _target_: src.datamodules.WMT14Dataset 13 | seed: ${datamodule.seed} 14 | debug: ${datamodule.debug} 15 | debug_k: ${datamodule.debug_k} 16 | load_dataset_params: 17 | name: "fr-en" 18 | split: "train" 19 | cache_dir: ${datamodule.cache_dir} 20 | 21 | dataloader: 22 | batch_size: 5 23 | num_workers: ${datamodule.num_workers} 24 | 25 | val: 26 | dataset: 27 | _target_: src.datamodules.WMT14Dataset 28 | seed: ${datamodule.seed} 29 | debug: ${datamodule.debug} 30 | debug_k: ${datamodule.debug_k} 31 | load_dataset_params: 32 | name: "fr-en" 33 | split: "validation" 34 | cache_dir: ${datamodule.cache_dir} 35 | 36 | dataloader: 37 | batch_size: 5 38 | num_workers: ${datamodule.num_workers} 39 | 40 | test: 41 | dataset: 42 | _target_: src.datamodules.WMT14Dataset 43 | seed: ${datamodule.seed} 44 | debug: ${datamodule.debug} 45 | debug_k: ${datamodule.debug_k} 46 | load_dataset_params: 47 | name: "fr-en" 48 | split: "test" 49 | cache_dir: ${datamodule.cache_dir} 50 | 51 | dataloader: 52 | batch_size: 4 53 | num_workers: ${datamodule.num_workers} 54 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup which runs for 100 steps – a starting point for the other configs 4 | 5 | defaults: 6 | - override /hydra: debug 7 | 8 | num_steps: 100 9 | num_validation_runs: 3 10 | 11 | trainer: 12 | max_steps: ${num_steps} 13 | gpus: 0 # debuggers don't like gpus 14 | track_grad_norm: 2 # track gradient norm with loggers 15 | val_check_interval: ${floor_div:${num_steps}, ${num_validation_runs}} 16 | # Available in newer version of PL 17 | # detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 18 | 19 | datamodule: 20 | num_workers: 0 # debuggers don't like multiprocessing 21 | pin_memory: False # disable gpu memory pin 22 | 23 | # sets level of all command line loggers to 'DEBUG' 24 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 25 | hydra: 26 | verbose: True 27 | 28 | # config is already printed by hydra when `hydra/verbose: True` 29 | print_config: False 30 | -------------------------------------------------------------------------------- /configs/debug/default_run.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup which runs for 100 steps – a starting point for the other configs 4 | # doesn't override any trainer / accelerator parameters 5 | 6 | defaults: 7 | - override /hydra: debug 8 | 9 | num_steps: 100 10 | num_validation_runs: 1 11 | 12 | trainer: 13 | 
max_steps: ${num_steps} 14 | track_grad_norm: 2 # track gradient norm with loggers 15 | val_check_interval: ${floor_div:${num_steps}, ${num_validation_runs}} 16 | # Available in newer version of PL 17 | # detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 18 | 19 | # sets level of all command line loggers to 'DEBUG' 20 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 21 | hydra: 22 | verbose: True 23 | 24 | # config is already printed by hydra when `hydra/verbose: True` 25 | print_config: False 26 | -------------------------------------------------------------------------------- /configs/debug/fast.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # disables the callbacks and the loggers 4 | # runs `num_batches` train, validation and test steps 5 | 6 | # ideal for fast debugging 7 | 8 | defaults: 9 | - default 10 | 11 | trainer: 12 | fast_dev_run: ${num_batches} 13 | 14 | ckpt_path: null 15 | num_batches: 5 16 | -------------------------------------------------------------------------------- /configs/debug/fast_run.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # disables the callbacks and the loggers 4 | # runs `num_batches` train, validation and test steps 5 | 6 | # ideal for a fast debug run 7 | 8 | defaults: 9 | - default_run.yaml 10 | 11 | trainer: 12 | fast_dev_run: ${num_batches} 13 | 14 | ckpt_path: null 15 | num_batches: 5 16 | -------------------------------------------------------------------------------- /configs/debug/limited.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs `num_steps` of training with `num_validation_runs` validation cycles 4 | # runs testing 5 | # the train, validation, test are all limited to (the first) `num_batches` batches (in each dataloader) 6 | 7 | defaults: 8 | - default 9 | 10 | trainer: 11 | max_steps: ${num_steps} 12 | val_check_interval: ${floor_div:${num_steps}, ${num_validation_runs}} 13 | 14 | limit_train_batches: ${num_batches} 15 | limit_val_batches: ${num_batches} 16 | limit_test_batches: ${num_batches} 17 | 18 | ckpt_path: null # To run the testing with a specific checkpoint, pass the path to the .ckpt file 19 | num_steps: 40 20 | num_batches: 20 21 | num_validation_runs: 2 22 | -------------------------------------------------------------------------------- /configs/debug/limited_run.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs `num_steps` of training with `num_validation_runs` validation cycles 4 | # runs testing 5 | # the train, validation, test are all limited to (the first) `num_batches` batches (in each dataloader) 6 | 7 | defaults: 8 | - default_run 9 | 10 | trainer: 11 | max_steps: ${num_steps} 12 | val_check_interval: ${floor_div:${num_steps}, ${num_validation_runs}} 13 | 14 | limit_train_batches: ${num_batches} 15 | limit_val_batches: ${num_batches} 16 | limit_test_batches: ${num_batches} 17 | 18 | ckpt_path: null # To run the testing with a specific checkpoint, pass the path to the .ckpt file 19 | num_steps: 20 20 | num_batches: 5 21 | num_validation_runs: 2 22 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # 
runs `num_steps` of training with `num_validation_runs` validation cycles 4 | # uses the same (first) `num_batches` batches of the training dataloader for both training and validation 5 | 6 | defaults: 7 | - default.yaml 8 | 9 | trainer: 10 | max_steps: ${num_steps} 11 | overfit_batches: ${num_batches} 12 | val_check_interval: ${floor_div:${num_steps}, ${num_validation_runs}} 13 | 14 | num_steps: 20000 15 | num_batches: 5 16 | num_validation_runs: 4000 17 | -------------------------------------------------------------------------------- /configs/debug/overfit_run.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 2 epochs of training, validation 4 | # each with a limited number of `num_batches` batches 5 | 6 | defaults: 7 | - default_run.yaml 8 | 9 | trainer: 10 | max_steps: ${num_steps} 11 | overfit_batches: ${num_batches} 12 | val_check_interval: ${floor_div:${num_steps}, ${num_validation_runs}} 13 | 14 | num_steps: 20 15 | num_batches: 5 16 | num_validation_runs: 2 17 | -------------------------------------------------------------------------------- /configs/evaluate_from_file_root.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: null 6 | - evaluation_from_file: null 7 | - logger: null # set to null if you don't want to use any loggers 8 | 9 | # debugging config (enable through command line, e.g. `python train.py debug=fast) 10 | - debug: null 11 | 12 | # enable color logging 13 | - /logger@logger_config: python_logger 14 | # - override hydra/job_logging: colorlog 15 | 16 | wandb_run_path: ??? 17 | exp_dir: None 18 | 19 | # verbose explanation: Hydra hijacks the working directory by changing it to the current log directory, 20 | # so it's useful to have this path as a special variable 21 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 22 | work_dir: ${hydra:runtime.cwd} 23 | 24 | # path to folder with data 25 | data_dir: ${work_dir}/data/ 26 | 27 | # pretty print config at the start of the run using Rich library 28 | print_config: True 29 | 30 | # disable python warnings if they annoy you 31 | ignore_warnings: True 32 | 33 | # seed for random number generators in pytorch, numpy and python.random 34 | seed: 123 35 | -------------------------------------------------------------------------------- /configs/evaluate_root.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - trainer: gpu.yaml 6 | - model: mbart_cd 7 | - evaluation_model: null 8 | - datamodule: wmt14 9 | - callbacks: default # set this to null if you don't want to use callbacks 10 | - logger: wandb # set to null if you don't want to use any loggers 11 | - hydra: evaluation 12 | 13 | - evaluation: null 14 | 15 | # debugging config (enable through command line, e.g. 
`python train.py debug=fast) 16 | - debug: null 17 | 18 | # enable color logging 19 | - override hydra/hydra_logging: colorlog 20 | - override hydra/job_logging: colorlog 21 | 22 | #model: 23 | # checkpoint_path: ${ckpt_path} 24 | # write_testing_output: True 25 | 26 | # verbose explanation: Hydra hijacks the working directory by changing it to the current log directory, 27 | # so it's useful to have this path as a special variable 28 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 29 | work_dir: ${hydra:runtime.cwd} 30 | 31 | # path to folder with data 32 | data_dir: ${work_dir}/data/ 33 | 34 | # pretty print config at the start of the run using Rich library 35 | print_config: True 36 | 37 | # disable python warnings if they annoy you 38 | ignore_warnings: True 39 | 40 | # seed for random number generators in pytorch, numpy and python.random 41 | seed: 123 42 | 43 | # the experiment name – determines the logging folder's path 44 | run_name: ??? 45 | 46 | # passing checkpoint path is necessary 47 | ckpt_path: ??? 48 | -------------------------------------------------------------------------------- /configs/evaluation/genie_cie_large.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run_evaluation.py evaluation=gs-st_grid_identity 5 | 6 | defaults: 7 | - override /datamodule: rebel 8 | - override /model: ckpt_genie_genre_r 9 | - override /callbacks: null 10 | - override /trainer: gpu 11 | 12 | # the parameters below will be merged with parameters from the default configurations set above 13 | # this allows you to overwrite only a small/specific set of parameters 14 | datamodule: 15 | debug: False 16 | debug_k: 12 17 | 18 | model: 19 | checkpoint_path: ${data_dir}/models/genie_genre_r.ckpt 20 | save_testing_output: 21 | save_log_likelihood: True 22 | 23 | hparams_overrides: 24 | decoding: 25 | # Results in undesired behavior for decoding strategies with fixed `num_beams` (e.g. greedy search) 26 | # hf_generation_params: 27 | # num_beams: 10 28 | # num_return_sequences: ${.num_beams} 29 | free_generation: False # Set to true in development – loading the tries take time 30 | entity_trie_path: ${data_dir}/tries/large/entity_trie.pickle 31 | relation_trie_path: ${data_dir}/tries/large/relation_trie.pickle 32 | 33 | # name of the run determines folder name in logs 34 | run_name: "genie_cie_large" 35 | -------------------------------------------------------------------------------- /configs/evaluation/genie_cie_large_resample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - genie_cie_large 5 | - override /datamodule: [rebel, resample] 6 | - override /model/decoding: [genie_generic, stochastic_beams] 7 | 8 | # the parameters below will be merged with parameters from the default configurations set above 9 | # this allows you to overwrite only a small/specific set of parameters 10 | datamodule: 11 | debug: False 12 | debug_k: 12 13 | 14 | num_qs: 10 # Number of quantiles to select (e.g. 10 will produce 11 plots corresponding to the 0, 0.1, ... 1. 
quantiles) 15 | 16 | model: 17 | resample_exp: True 18 | n_sim: 10 19 | hparams_overrides: 20 | decoding: 21 | hf_generation_params: 22 | num_beams: 15 23 | 24 | # name of the run determines folder name in logs 25 | run_name: "genie_cie_large_resample" 26 | -------------------------------------------------------------------------------- /configs/evaluation/genie_cie_small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run_evaluation.py evaluation=gs-st_grid_identity 5 | 6 | defaults: 7 | - override /datamodule: rebel 8 | - override /model: ckpt_genie_genre_r 9 | - override /callbacks: null 10 | - override /trainer: gpu 11 | 12 | # the parameters below will be merged with parameters from the default configurations set above 13 | # this allows you to overwrite only a small/specific set of parameters 14 | datamodule: 15 | debug: True 16 | 17 | model: 18 | save_testing_output: 19 | save_log_likelihood: True 20 | 21 | hparams_overrides: 22 | decoding: 23 | free_generation: False 24 | entity_trie_path: ${data_dir}/tries/small/entity_trie.pickle 25 | relation_trie_path: ${data_dir}/tries/small/relation_trie.pickle 26 | 27 | # name of the run determines folder name in logs 28 | run_name: "genie_cie_small" 29 | -------------------------------------------------------------------------------- /configs/evaluation/gpt2_toxicgen.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run_evaluation.py evaluation=gpt2_toxicgen 5 | 6 | defaults: 7 | - override /datamodule: rtp 8 | - override /model: gpt2_lm 9 | - override /callbacks: default 10 | - override /logger: wandb 11 | - override /trainer: gpu 12 | 13 | # the parameters below will be merged with parameters from the default configurations set above 14 | # this allows you to overwrite only a small/specific set of parameters 15 | datamodule: 16 | debug: False 17 | 18 | model: 19 | save_testing_output: 20 | save_log_likelihood: True 21 | 22 | # name of the run determines folder name in logs 23 | run_name: "gpt2_toxicgen_v1" 24 | -------------------------------------------------------------------------------- /configs/evaluation/gpt2_toxicgen_resample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - gpt2_toxicgen 5 | - override /datamodule: [rtp, resample] 6 | - override /model/decoding: [gpt_generic, stochastic_beams] 7 | 8 | # the parameters below will be merged with parameters from the default configurations set above 9 | # this allows you to overwrite only a small/specific set of parameters 10 | datamodule: 11 | debug: False 12 | debug_k: 12 13 | 14 | num_qs: 10 # Number of quantiles to select (e.g. 10 will produce 11 plots corresponding to the 0, 0.1, ... 1. 
quantiles) 15 | 16 | model: 17 | resample_exp: True 18 | n_sim: 10 19 | hparams_overrides: 20 | decoding: 21 | hf_generation_params: 22 | num_beams: 15 23 | 24 | # name of the run determines folder name in logs 25 | run_name: "gpt2_toxicgen_v1_resample" 26 | -------------------------------------------------------------------------------- /configs/evaluation/mbart_translation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,beam_search] 5 | 6 | defaults: 7 | - override /datamodule: wmt14 8 | - override /model: mbart_cd 9 | - override /callbacks: null 10 | - override /trainer: gpu 11 | 12 | # the parameters below will be merged with parameters from the default configurations set above 13 | # this allows you to overwrite only a small/specific set of parameters 14 | datamodule: 15 | debug: False 16 | 17 | model: 18 | save_testing_output: 19 | save_log_likelihood: True 20 | 21 | # name of the run determines folder name in logs 22 | run_name: "mbart_translation" 23 | -------------------------------------------------------------------------------- /configs/evaluation/mbart_translation_resample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - mbart_translation 5 | - override /datamodule: [wmt14, resample] 6 | - override /model/decoding: [mbart_generic, stochastic_beams] 7 | 8 | # the parameters below will be merged with parameters from the default configurations set above 9 | # this allows you to overwrite only a small/specific set of parameters 10 | datamodule: 11 | debug: False 12 | debug_k: 12 13 | 14 | num_qs: 5 # Number of quantiles to select (e.g. 10 will produce 11 plots corresponding to the 0, 0.1, ... 1. 
quantiles) 15 | 16 | model: 17 | resample_exp: True 18 | n_sim: 10 19 | decoding: 20 | hf_generation_params: 21 | num_return_sequences: 10 22 | 23 | # name of the run determines folder name in logs 24 | run_name: "mbart_translation_resample" 25 | -------------------------------------------------------------------------------- /configs/evaluation/protgpt2_design.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run_evaluation.py evaluation=gs-st_grid_identity 5 | 6 | defaults: 7 | - override /datamodule: swissprot 8 | - override /model: protgpt2_lm 9 | - override /callbacks: default 10 | - override /logger: wandb 11 | - override /trainer: gpu 12 | 13 | # the parameters below will be merged with parameters from the default configurations set above 14 | # this allows you to overwrite only a small/specific set of parameters 15 | datamodule: 16 | debug: False 17 | 18 | model: 19 | save_testing_output: 20 | save_log_likelihood: True 21 | 22 | # name of the run determines folder name in logs 23 | run_name: "protgpt2_design_v1" 24 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/cie.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /metric: 5 | - triplet_set_precision 6 | - triplet_set_recall 7 | - triplet_set_f1 8 | - override /datamodule: out_ds_rebel 9 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/resample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | resample_exp: true 5 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/solubility.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /metric: 5 | - protein_soluble 6 | - override /datamodule: out_ds_swissprot 7 | 8 | seed: 123 9 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/toxicity.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /metric: 5 | - detoxify 6 | - override /datamodule: out_ds_rtp 7 | 8 | seed: 123 9 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/toxicity_bootstrap.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /metric: 5 | - detoxify 6 | - override /datamodule: out_ds_rtp 7 | 8 | bootstrap_n: 50 9 | num_workers: 1 10 | 11 | bootstrap_run_method: "from_scores_per_datapoint" 12 | confidence_level: 0.95 13 | 14 | seed: 123 15 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/translation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /metric: 5 | # - bleu-1 6 | - bleu-4 7 | - override /datamodule: out_ds_wmt14 8 | 9 | seed: 1233 10 | -------------------------------------------------------------------------------- /configs/evaluation_from_file/translation_bootstrap.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /metric: 5 | # - bleu-1 6 | - bleu-4 7 | - override /datamodule: out_ds_wmt14 8 | 9 | bootstrap_n: 50 10 | num_workers: 1 11 | 12 | bootstrap_run_method: "from_metric" 13 | confidence_level: 0.95 14 | 15 | seed: 1233 16 | -------------------------------------------------------------------------------- /configs/evaluation_model/cie_noisy_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.NoisyOracleGenIE 2 | 3 | noising_function_parameters: 4 | sigma: 0.4 # with 10 additional tokens, prob of flip ~ 0.48. Falls back to the perfect oracle for sigma: 0 5 | -------------------------------------------------------------------------------- /configs/evaluation_model/cie_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.OracleGenIE 2 | -------------------------------------------------------------------------------- /configs/evaluation_model/detoxify.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.DetoxifyEvaluationModel 2 | model_type: unbiased-small # To remove and make mandatory 3 | device: "cuda:0" 4 | -------------------------------------------------------------------------------- /configs/evaluation_model/detoxify_noisy_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.NoisyOracleDetoxify 2 | 3 | device: "cpu" 4 | huggingface_config_path: null 5 | batch_size: 40 6 | model_type: ??? 7 | checkpoint: ??? 8 | -------------------------------------------------------------------------------- /configs/evaluation_model/detoxify_noisy_oracle_rmse_01976.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - detoxify_noisy_oracle 3 | 4 | model_type: "noisy_oracle_rmse_01976" 5 | checkpoint: ${data_dir}detoxify/noisy_oracles/N-Step-Checkpoint_epoch=0_global_step=1200.ckpt -------------------------------------------------------------------------------- /configs/evaluation_model/detoxify_noisy_oracle_rmse_01991.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - detoxify_noisy_oracle 3 | 4 | model_type: "noisy_oracle_rmse_01991" 5 | checkpoint: ${data_dir}detoxify/noisy_oracles/N-Step-Checkpoint_epoch=0_global_step=800.ckpt 6 | -------------------------------------------------------------------------------- /configs/evaluation_model/detoxify_noisy_oracle_rmse_02165.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - detoxify_noisy_oracle 3 | 4 | model_type: "noisy_oracle_rmse_02165" 5 | checkpoint: ${data_dir}detoxify/noisy_oracles/N-Step-Checkpoint_epoch=0_global_step=400.ckpt 6 | -------------------------------------------------------------------------------- /configs/evaluation_model/detoxify_noisy_oracle_rmse_02242.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - detoxify_noisy_oracle 3 | 4 | model_type: "noisy_oracle_rmse_02242" 5 | checkpoint: ${data_dir}detoxify/noisy_oracles/N-Step-Checkpoint_epoch=0_global_step=200.ckpt -------------------------------------------------------------------------------- 
/configs/evaluation_model/detoxify_oracle.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - detoxify_noisy_oracle 3 | 4 | model_type: unbiased-small # To remove and make mandatory 5 | checkpoint: null 6 | # Same as 7 | # model_type: null 8 | # checkpoint: ${data_dir}detoxify/oracle/unbiased-albert-c8519128.ckpt 9 | -------------------------------------------------------------------------------- /configs/evaluation_model/likelihood.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.LikelihoodEvaluationModel -------------------------------------------------------------------------------- /configs/evaluation_model/mt_noisy_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.NoisyOracleMT 2 | 3 | bleu_score: 4 | _target_: src.metrics.BLEUScore 5 | lowercase: False 6 | max_ngram_order: 4 7 | tokenize: "none" 8 | smooth_method: "exp" 9 | 10 | noising_function_parameters: 11 | lambda: 0.1 # the weight associated to the true target's BLEU in the linear combination 12 | # lambda: 0.8 # the weight associated to the true target's BLEU in the linear combination 13 | 14 | operate_on: "ids" 15 | l_strip: 0 16 | -------------------------------------------------------------------------------- /configs/evaluation_model/mt_noisy_oracle_b2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mt_noisy_oracle 3 | 4 | bleu_score: 5 | max_ngram_order: 2 6 | 7 | l_strip: 1 -------------------------------------------------------------------------------- /configs/evaluation_model/mt_noisy_oracle_b3.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mt_noisy_oracle 3 | 4 | bleu_score: 5 | max_ngram_order: 3 6 | 7 | l_strip: 0 -------------------------------------------------------------------------------- /configs/evaluation_model/mt_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.OracleMT 2 | 3 | bleu_score: 4 | _target_: src.metrics.BLEUScore 5 | lowercase: False 6 | max_ngram_order: 4 7 | tokenize: "none" 8 | smooth_method: "exp" 9 | 10 | operate_on: "ids" 11 | -------------------------------------------------------------------------------- /configs/evaluation_model/perfect_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.oracles.PerfectOracle 2 | tree: ${model.tree} # ToDo: this should be more general. Currently model is HF model; should be PL module 3 | path: ??? 
4 | -------------------------------------------------------------------------------- /configs/evaluation_model/solubility.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.SolubilityEvaluationModel 2 | device: "cuda:0" 3 | -------------------------------------------------------------------------------- /configs/evaluation_model/solubility_noisy_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.NoisyOracleSolubility 2 | 3 | oracle_parameters: 4 | device: "cuda:0" 5 | batch_size: 1 6 | 7 | noising_function_parameters: 8 | probability_inversion: 0.2 # with 10 tokens, the Kendall's tau correlation with the original vector is 0.67 +- 0.29. Falls back to the perfect oracle for probability_inversion: 0 9 | -------------------------------------------------------------------------------- /configs/evaluation_model/solubility_oracle.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.evaluation_models.SolubilityEvaluationModel 2 | device: "cuda:0" 3 | batch_size: 1 -------------------------------------------------------------------------------- /configs/hparam_search/base_hparam_search.yaml: -------------------------------------------------------------------------------- 1 | run_name: ${num_files:${work_dir}/logs/hparam_search/${parent_run_name}} # DON'T INTERPOLATE! It changes with every call! 2 | search: ??? 3 | launcher: ??? 4 | parent_run_name: ??? -------------------------------------------------------------------------------- /configs/hparam_search/launcher/local_launcher.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.hparam_search.launchers.LocalLauncher 2 | work_dir: ${work_dir} -------------------------------------------------------------------------------- /configs/hparam_search/search/grid_search.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.hparam_search.search.GridSearch 2 | 3 | max_workers: 1 # maximum number of trials to run in parallel 4 | 5 | run_name_str: null 6 | run_name_vars: null 7 | 8 | scripts: ??? # path to the script to run 9 | search_space: ??? # Dictionary implemented using https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.ParameterGrid.html -------------------------------------------------------------------------------- /configs/hparam_search_root.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - trainer: cpu 4 | - logger: wandb_standalone # wandb #null # set logger here or use command line (e.g. 
`python run.py logger=wandb`) 5 | - hparam_search: hparams_bs_toxicity 6 | - hydra: hparam_search 7 | 8 | # enable color logging 9 | - override hydra/hydra_logging: colorlog 10 | - override hydra/job_logging: colorlog 11 | 12 | # verbose explanation: Hydra hijacks the working directory by changing it to the current log directory, 13 | # so it's useful to have this path as a special variable 14 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 15 | work_dir: ${hydra:runtime.cwd} 16 | 17 | # path to folder with data 18 | data_dir: ${work_dir}/data/ 19 | 20 | # pretty print config at the start of the run using Rich library 21 | print_config: False 22 | 23 | # disable python warnings if they annoy you 24 | ignore_warnings: True 25 | 26 | # seed for random number generators in pytorch, numpy and python.random 27 | seed: 123 28 | 29 | # the parent experiment (hparam search run) name -- determines the logging folder's parent 30 | parent_run_name: ??? 31 | 32 | # the experiment name (either_combination – determines the logging folder's path 33 | run_name: ??? 34 | 35 | search: ??? -------------------------------------------------------------------------------- /configs/hydra/debug.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | dir: logs/debug/runs/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 4 | 5 | sweep: 6 | dir: logs/debug/multiruns/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 7 | subdir: ${hydra.job.num} 8 | 9 | # you can set here environment variables that are universal for all users 10 | job: 11 | env_set: 12 | CUDA_DEVICE_ORDER: "PCI_BUS_ID" 13 | HYDRA_FULL_ERROR: "1" 14 | # Set cuda visible devices from command line: export CUDA_VISIBLE_DEVICES=0;python evaluate_kilt_dataset.py 15 | # Or python run.py +hydra.job.env_set.CUDA_VISIBLE_DEVICES="3' 16 | -------------------------------------------------------------------------------- /configs/hydra/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | dir: logs/evaluation/runs/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 4 | 5 | sweep: 6 | dir: logs/evaluation/multiruns/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 7 | subdir: ${hydra.job.num} 8 | 9 | # you can set here environment variables that are universal for all users 10 | job: 11 | env_set: 12 | CUDA_DEVICE_ORDER: "PCI_BUS_ID" 13 | HYDRA_FULL_ERROR: "1" 14 | # Set cuda visible devices from command line: export CUDA_VISIBLE_DEVICES=0;python evaluate_kilt_dataset.py 15 | # Or python run.py +hydra.job.env_set.CUDA_VISIBLE_DEVICES="3' 16 | -------------------------------------------------------------------------------- /configs/hydra/hparam_search.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | dir: logs/hparam_search/${parent_run_name}/${run_name} 4 | 5 | # will not be used 6 | sweep: 7 | dir: logs/hparams_search/multiruns/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | 10 | # you can set here environment variables that are universal for all users 11 | job: 12 | env_set: 13 | CUDA_DEVICE_ORDER: "PCI_BUS_ID" 14 | HYDRA_FULL_ERROR: "1" 15 | # Set cuda visible devices from command line: export CUDA_VISIBLE_DEVICES=0;python evaluate_kilt_dataset.py 16 | # Or python run.py +hydra.job.env_set.CUDA_VISIBLE_DEVICES="3' 
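One note on the ${num_files:...} interpolation used by the hparam_search configs above: it is a custom OmegaConf resolver (the entry-point scripts import src.utils.hydra_custom_resolvers, which is presumably where it is registered), and the "DON'T INTERPOLATE!" warning exists because the value is recomputed every time it is resolved. The snippet below is a hypothetical re-implementation to illustrate the mechanics, not the repo's actual resolver:

    import os

    from omegaconf import OmegaConf

    def num_files(path: str) -> int:
        # Count the entries already present in a directory (e.g. to derive the
        # index of the next hparam-search run); 0 if the folder does not exist yet.
        return len(os.listdir(path)) if os.path.isdir(path) else 0

    OmegaConf.register_new_resolver("num_files", num_files)

    cfg = OmegaConf.create({"run_name": "${num_files:logs/hparam_search/my_search}"})
    # Every resolution re-counts the directory, hence the warning against
    # interpolating this value more than once.
    print(OmegaConf.to_container(cfg, resolve=True))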
-------------------------------------------------------------------------------- /configs/hydra/training.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | dir: logs/training/runs/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 4 | 5 | sweep: 6 | dir: logs/training/multiruns/${run_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 7 | subdir: ${hydra.job.num} 8 | 9 | # you can set here environment variables that are universal for all users 10 | job: 11 | env_set: 12 | CUDA_DEVICE_ORDER: "PCI_BUS_ID" 13 | HYDRA_FULL_ERROR: "1" 14 | # Set cuda visible devices from command line: export CUDA_VISIBLE_DEVICES=0;python evaluate_kilt_dataset.py 15 | # Or python run.py +hydra.job.env_set.CUDA_VISIBLE_DEVICES="3' 16 | -------------------------------------------------------------------------------- /configs/logger/python_logger.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.job_logging 2 | # python logging configuration for tasks 3 | version: 1 4 | formatters: 5 | simple: 6 | format: "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" 7 | colorlog: 8 | "()": "colorlog.ColoredFormatter" 9 | format: "[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s" 10 | log_colors: 11 | DEBUG: purple 12 | INFO: green 13 | WARNING: yellow 14 | ERROR: red 15 | CRITICAL: red 16 | handlers: 17 | console: 18 | class: logging.StreamHandler 19 | formatter: colorlog 20 | stream: ext://sys.stdout 21 | root: 22 | level: INFO 23 | handlers: [console] 24 | 25 | disable_existing_loggers: false 26 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: "understanding-decoding" # The name of the project where you're sending the new run. Default: Uncategorized runs 6 | name: ${run_name} # A short display name for this run, which is how you'll identify this run in the UI. Default: Randomly generated two word name 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # A unique ID for this run, used for resuming! See guide for resuming runs... 10 | entity: null # set to name of your wandb team 11 | log_model: False 12 | job_type: "train" # Specify the type of run, which is useful when you're grouping runs together into larger experiments using group 13 | group: "" 14 | tags: [] 15 | notes: 16 | -------------------------------------------------------------------------------- /configs/logger/wandb_group.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: "understanding-decoding" # The name of the project where you're sending the new run. Default: Uncategorized runs 6 | name: ${run_name} # A short display name for this run, which is how you'll identify this run in the UI. Default: Randomly generated two word name 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # A unique ID for this run, used for resuming! See guide for resuming runs... 
10 | entity: epfl-dlab # set to name of your wandb team 11 | log_model: False 12 | job_type: "train" # Specify the type of run, which is useful when you're grouping runs together into larger experiments using group 13 | group: "" 14 | tags: [] 15 | notes: 16 | -------------------------------------------------------------------------------- /configs/logger/wandb_standalone.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | logger_name: wandb_standalone 4 | wandb: 5 | project: "understanding-decoding" # The name of the project where you're sending the new run. Default: Uncategorized runs 6 | name: ${run_name} # A short display name for this run, which is how you'll identify this run in the UI. Default: Randomly generated two word name 7 | # save_dir: "." # (only for PL WandbLogger) 8 | # offline: False # set True to store all logs only locally # (only for PL WandbLogger) 9 | id: null # A unique ID for this run, used for resuming! See guide for resuming runs.. 10 | entity: null # set to name of your wandb team 11 | # log_model: False # (only for PL WandbLogger) 12 | job_type: "train" # Specify the type of run, which is useful when you're grouping runs together into larger experiments using group 13 | group: "" # TODO: We might want to use this flag when using distributed training --> https://docs.wandb.ai/guides/track/advanced/grouping 14 | tags: [] 15 | notes: -------------------------------------------------------------------------------- /configs/mdp/search_tree_factory.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.mdp.SearchTreeFactory 2 | 3 | tree_structure_params: 4 | depth: 6 5 | branching_factor: 4 6 | decay_branching_factor: 1 7 | 8 | transition_model_params: 9 | alpha: 0.9 10 | 11 | utility_value_params: 12 | ul_alignment: 0.3 13 | uu_alignment: 1 14 | -------------------------------------------------------------------------------- /configs/metric/bleu-1.yaml: -------------------------------------------------------------------------------- 1 | bleu-1: 2 | _target_: src.metrics.BLEUScore 3 | lowercase: False 4 | max_ngram_order: 1 5 | tokenize: 13a 6 | smooth_method: 'exp' 7 | -------------------------------------------------------------------------------- /configs/metric/bleu-4.yaml: -------------------------------------------------------------------------------- 1 | bleu-4: 2 | _target_: src.metrics.BLEUScore 3 | lowercase: False 4 | max_ngram_order: 4 5 | tokenize: 13a 6 | smooth_method: "exp" 7 | -------------------------------------------------------------------------------- /configs/metric/detoxify.yaml: -------------------------------------------------------------------------------- 1 | detoxify: 2 | _target_: src.metrics.DetoxifyScore 3 | model_type: unbiased-small # To remove and make mandatory 4 | device: "cuda:0" 5 | -------------------------------------------------------------------------------- /configs/metric/protein_soluble.yaml: -------------------------------------------------------------------------------- 1 | protein_soluble: 2 | _target_: src.metrics.SolubleScore 3 | name: soluble 4 | device: 0 5 | -------------------------------------------------------------------------------- /configs/metric/triplet_set_f1.yaml: -------------------------------------------------------------------------------- 1 | triplet_set_f1: 2 | _target_: src.metrics.TSF1 3 | -------------------------------------------------------------------------------- 
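Each metric config in this group is a plain Hydra _target_ block nested under the metric's name, so the evaluation code can turn any of them into a metric object generically. A minimal sketch of that pattern, assuming the repo's src package is importable (the actual call site lives in the evaluation pipeline and may differ):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Mirror configs/metric/bleu-4.yaml as an in-memory config.
    cfg = OmegaConf.create({
        "bleu-4": {
            "_target_": "src.metrics.BLEUScore",
            "lowercase": False,
            "max_ngram_order": 4,
            "tokenize": "13a",
            "smooth_method": "exp",
        }
    })

    # `instantiate` resolves `_target_` to the class and passes the remaining
    # keys as constructor arguments.
    metric = instantiate(cfg["bleu-4"])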
/configs/metric/triplet_set_precision.yaml: -------------------------------------------------------------------------------- 1 | triplet_set_precision: 2 | _target_: src.metrics.TSPrecision 3 | -------------------------------------------------------------------------------- /configs/metric/triplet_set_recall.yaml: -------------------------------------------------------------------------------- 1 | triplet_set_recall: 2 | _target_: src.metrics.TSRecall 3 | -------------------------------------------------------------------------------- /configs/model/additional_logger/reconstruction_logger.yaml: -------------------------------------------------------------------------------- 1 | # ToDo: Incorporate into Translation model 2 | _target_: src.loggers.ReconstructionLogger 3 | 4 | logging_interval: 100 # Every 100 optimization steps 5 | samples_per_train_interval: 5 # Log 5 train samples 6 | 7 | samples_per_val_interval: 10 # Log 10 validation samples 8 | num_samples_per_z_to_consider: 3 9 | -------------------------------------------------------------------------------- /configs/model/ckpt_generic.yaml: -------------------------------------------------------------------------------- 1 | _target_: ??? # e.g. src.models.MBartForConditionalGeneration.load_from_checkpoint 2 | checkpoint_path: ??? 3 | -------------------------------------------------------------------------------- /configs/model/ckpt_genie_genre_r.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: [genie_generic, beam_search] # override from command line as model/decoding=[genie_generic,beam_sampling] 3 | 4 | _target_: src.models.GeniePL.load_from_checkpoint 5 | checkpoint_path: "${data_dir}/models/genie_genre_r.ckpt" 6 | -------------------------------------------------------------------------------- /configs/model/collator/rebel_collator.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.collators.RebelCollator 2 | tokenizer: null 3 | max_input_length: null 4 | max_output_length: null 5 | padding: True 6 | truncation: True 7 | target_padding_token_id: null 8 | -------------------------------------------------------------------------------- /configs/model/decoding/beam_sampling.yaml: -------------------------------------------------------------------------------- 1 | name: beam_sampling 2 | 3 | hf_generation_params: 4 | num_beams: 5 5 | num_return_sequences: ${.num_beams} 6 | num_beam_groups: 1 7 | do_sample: True 8 | #constraints: None 9 | 10 | top_k: 0 11 | top_p: 1.0 12 | 13 | early_stopping: false 14 | 15 | temperature: 1.0 16 | length_penalty: 1.0 17 | -------------------------------------------------------------------------------- /configs/model/decoding/beam_sampling_top_p.yaml: -------------------------------------------------------------------------------- 1 | name: beam_sampling_top_p 2 | 3 | hf_generation_params: 4 | num_beams: 5 5 | num_return_sequences: ${.num_beams} 6 | num_beam_groups: 1 7 | do_sample: True 8 | #constraints: None 9 | 10 | top_k: 0 11 | top_p: 0.92 12 | 13 | early_stopping: false 14 | 15 | temperature: 1.0 16 | length_penalty: 1.0 17 | -------------------------------------------------------------------------------- /configs/model/decoding/beam_search.yaml: -------------------------------------------------------------------------------- 1 | name: beam_search 2 | 3 | hf_generation_params: 4 | num_beams: 5 5 | num_return_sequences: ${.num_beams} 6 | num_beam_groups: 1 7 | 
do_sample: False 8 | 9 | early_stopping: false 10 | 11 | temperature: 1.0 12 | length_penalty: 1.0 13 | -------------------------------------------------------------------------------- /configs/model/decoding/genie_generic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: null # choose your decoding strategy by setting this to a decoding algorithm 3 | 4 | hf_generation_params: 5 | min_length: 0 6 | max_length: 256 7 | early_stopping: false 8 | 9 | encoder_no_repeat_ngram_size: 0 10 | no_repeat_ngram_size: 0 11 | 12 | temperature: 1.0 13 | length_penalty: 1.0 14 | forced_bos_token_id: 0 15 | 16 | seed: ${seed} 17 | verbose_flag_in_convert_to_triple: false 18 | # GenIE defaults 19 | #num_beams: 10 20 | #num_return_sequences: ${.num_beams} 21 | # 22 | #min_length: 0 23 | #max_length: ${...max_output_length} 24 | # 25 | #early_stopping: false 26 | # 27 | #encoder_no_repeat_ngram_size: 0 28 | #no_repeat_ngram_size: 0 29 | # 30 | #temperature: 1.0 31 | #length_penalty: 1.0 32 | -------------------------------------------------------------------------------- /configs/model/decoding/gpt_generic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: null # choose your decoding strategy by setting this to a decoding algorithm 3 | 4 | hf_generation_params: 5 | no_repeat_ngram_size: 2 6 | max_length: 256 7 | early_stopping: false 8 | 9 | temperature: 1.0 10 | length_penalty: 1.0 11 | 12 | seed: ${seed} 13 | -------------------------------------------------------------------------------- /configs/model/decoding/greedy_search.yaml: -------------------------------------------------------------------------------- 1 | name: greedy_search 2 | 3 | hf_generation_params: 4 | num_beams: 1 5 | num_beam_groups: 1 6 | do_sample: False 7 | #constraints: None 8 | -------------------------------------------------------------------------------- /configs/model/decoding/mbart_generic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: null # choose your decoding strategy by setting this to a decoding algorithm 3 | 4 | target_language_id: fr_XX 5 | seed: ${seed} 6 | -------------------------------------------------------------------------------- /configs/model/decoding/mcts.yaml: -------------------------------------------------------------------------------- 1 | name: mcts 2 | 3 | defaults: 4 | - mcts_qtransform@hf_generation_params.mcts_qtransform: by_parent_and_siblings 5 | 6 | hf_generation_params: 7 | use_cache: True 8 | use_mcts: True 9 | num_beams: 1 10 | num_beam_groups: 1 11 | num_return_sequences: 1 12 | 13 | mcts_debug_prints: False 14 | mcts_dirichlet_fraction: 0.25 15 | mcts_dirichlet_alpha: 0.3 16 | mcts_pb_c_init: 1.25 17 | mcts_pb_c_base: 19652 18 | mcts_num_simulations: 50 19 | mcts_topk_actions: 20 20 | reuse_cached_jitted_mcts: True 21 | -------------------------------------------------------------------------------- /configs/model/decoding/mcts_qtransform/by_min_max.yaml: -------------------------------------------------------------------------------- 1 | _partial_: true 2 | _target_: mctx.qtransform_by_min_max 3 | max_value: 1.0 4 | min_value: -1.0 -------------------------------------------------------------------------------- /configs/model/decoding/mcts_qtransform/by_parent_and_siblings.yaml: -------------------------------------------------------------------------------- 1 | _partial_: true 2 | _target_: 
mctx.qtransform_by_parent_and_siblings 3 | epsilon: 1e-8 -------------------------------------------------------------------------------- /configs/model/decoding/mcts_qtransform/completed_by_mix_value.yaml: -------------------------------------------------------------------------------- 1 | _partial_: true 2 | _target_: mctx.qtransform_completed_by_mix_value 3 | value_scale: 0.1 4 | maxvisit_init: 50.0 5 | rescale_values: True 6 | use_mixed_value: True 7 | epsilon: 1e-8 -------------------------------------------------------------------------------- /configs/model/decoding/mcts_qtransform/noop.yaml: -------------------------------------------------------------------------------- 1 | _partial_: true 2 | _target_: mctx.qtransform_noop -------------------------------------------------------------------------------- /configs/model/decoding/pplmcts.yaml: -------------------------------------------------------------------------------- 1 | name: mcts 2 | 3 | hf_generation_params: 4 | use_cache: True 5 | use_pplmcts: True 6 | num_beams: 1 7 | num_beam_groups: 1 8 | num_return_sequences: 1 9 | 10 | mcts_debug_prints: False 11 | 12 | mcts_pb_c_init: 3.0 13 | mcts_num_simulations: 50 14 | mcts_topk_actions: 20 15 | 16 | reuse_cached_jitted_mcts: null # not used by PPLMCTS 17 | mcts_pb_c_base: null # not used by PPLMCTS 18 | mcts_dirichlet_fraction: null # not used by PPLMCTS 19 | mcts_dirichlet_alpha: null # not used by PPLMCTS 20 | -------------------------------------------------------------------------------- /configs/model/decoding/protgpt2_generic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: null # choose your decoding strategy by setting this to a decoding algorithm 3 | 4 | hf_generation_params: 5 | max_length: 256 6 | early_stopping: false 7 | 8 | temperature: 1.0 9 | length_penalty: 1.0 10 | 11 | seed: ${seed} 12 | -------------------------------------------------------------------------------- /configs/model/decoding/stochastic_beams.yaml: -------------------------------------------------------------------------------- 1 | name: stochastic_beams 2 | 3 | hf_generation_params: 4 | num_beams: 5 5 | num_return_sequences: ${.num_beams} 6 | num_beam_groups: 1 7 | do_stochastic: true 8 | do_sample: false # has to be set to false, if do_stochastic is true 9 | 10 | early_stopping: false 11 | 12 | temperature: 1.0 # Does this have any effect? 
13 | length_penalty: 1.0 14 | -------------------------------------------------------------------------------- /configs/model/decoding/value_guided_beam_search.yaml: -------------------------------------------------------------------------------- 1 | name: value_guided_beam_search 2 | 3 | hf_generation_params: 4 | num_beams: 5 5 | num_return_sequences: ${.num_beams} 6 | num_beam_groups: 1 7 | early_stopping: false 8 | length_penalty: 1.0 9 | 10 | do_value_guided: true 11 | do_sample: false 12 | do_stochastic: false 13 | 14 | # top_hypothesis_factor: 2 15 | tokens_considered_by_value_processor: 10 # Must be a number larger than 2 * num_beams 16 | contribution_factor: 0.1 # The weight of the likelihood the linear interpolation between the likelihood and the value 17 | # contribution_factor: 0.7 # The weight of the likelihood the linear interpolation between the likelihood and the value -------------------------------------------------------------------------------- /configs/model/genie_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: [genie_generic, beam_search] 3 | - collator: rebel_collator 4 | 5 | _target_: src.models.GeniePL 6 | 7 | random_initialization: False 8 | from_checkpoint: False 9 | pretrained_model_name_or_path: "martinjosifoski/genie-rw" 10 | 11 | max_input_length: 256 12 | max_output_length: 256 13 | bos_as_first_token_generated: True 14 | 15 | hf_config: 16 | _target_: transformers.BartConfig.from_pretrained 17 | pretrained_model_name_or_path: ${..pretrained_model_name_or_path} 18 | 19 | hf_model: 20 | pretrained_model_name_or_path: ${..pretrained_model_name_or_path} 21 | config: ${..hf_config} 22 | 23 | tokenizer: 24 | _target_: transformers.BartTokenizer.from_pretrained 25 | pretrained_model_name_or_path: ${..pretrained_model_name_or_path} 26 | 27 | save_testing_output: 28 | save_log_likelihood: False 29 | 30 | optimizer: 31 | # --- Learning rates and Schedulers --- 32 | lr: 3.0e-05 33 | weight_decay: 0.01 34 | schedule_name: "polynomial" 35 | lr_end: 0.0 36 | 37 | warmup_updates: 1000 38 | total_num_updates: ${trainer.max_steps} 39 | 40 | # --- Optimizer --- 41 | adam_eps: 1.0e-08 42 | 43 | eps: 0.1 44 | -------------------------------------------------------------------------------- /configs/model/gpt2_lm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: [gpt_generic, beam_search] 3 | 4 | _target_: src.models.GPT2ForGeneration 5 | 6 | random_initialization: False 7 | from_checkpoint: False 8 | 9 | hf_config: 10 | _target_: transformers.GPT2Config.from_pretrained 11 | pretrained_model_name_or_path: "gpt2" 12 | pad_token_id: 50256 13 | 14 | hf_model: 15 | pretrained_model_name_or_path: "gpt2" 16 | config: ${..hf_config} 17 | 18 | tokenizer: 19 | _target_: transformers.GPT2Tokenizer.from_pretrained 20 | pretrained_model_name_or_path: "gpt2" 21 | padding_side: "left" 22 | pad_token: "<|endoftext|>" 23 | 24 | toxicity_metric: 25 | _binary: True 26 | model_name: "original" # options: original or unbiased 27 | 28 | save_testing_output: 29 | save_log_likelihood: False 30 | -------------------------------------------------------------------------------- /configs/model/mbart_cd.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: [mbart_generic, beam_search] # override from command line as model/decoding=[mbart_generic, beam_sampling] 3 | 4 | _target_: 
src.models.MBartForConditionalGeneration 5 | 6 | random_initialization: False 7 | from_checkpoint: False 8 | 9 | hf_config: 10 | _target_: transformers.MBartConfig.from_pretrained # config class 11 | pretrained_model_name_or_path: "facebook/mbart-large-50-one-to-many-mmt" 12 | 13 | hf_model: 14 | pretrained_model_name_or_path: "facebook/mbart-large-50-one-to-many-mmt" 15 | config: ${..hf_config} 16 | 17 | tokenizer: 18 | _target_: transformers.MBart50TokenizerFast.from_pretrained 19 | pretrained_model_name_or_path: "facebook/mbart-large-50-one-to-many-mmt" 20 | src_lang: "en_XX" 21 | tgt_lang: "fr_XX" 22 | # activation_dropout: 3 # overrides given as kwargs 23 | 24 | save_testing_output: 25 | save_log_likelihood: False 26 | -------------------------------------------------------------------------------- /configs/model/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | 3 | lr: 0.001 # 3e-05 4 | weight_decay: 0.00 5 | eps: 1e-08 6 | betas: [0.9, 0.999] 7 | -------------------------------------------------------------------------------- /configs/model/optimizer/adam_w.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | 3 | lr: 0.001 # 3e-05 4 | weight_decay: 0.01 5 | eps: 1e-08 6 | betas: [0.9, 0.999] 7 | -------------------------------------------------------------------------------- /configs/model/protgpt2_lm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: [protgpt2_generic, beam_search] 3 | 4 | _target_: src.models.ProtGPT2ForGeneration 5 | 6 | random_initialization: False 7 | from_checkpoint: False 8 | 9 | hf_config: 10 | _target_: transformers.AutoConfig.from_pretrained 11 | pretrained_model_name_or_path: "nferruz/ProtGPT2" 12 | pad_token_id: 50256 13 | 14 | hf_model: 15 | pretrained_model_name_or_path: "nferruz/ProtGPT2" 16 | config: ${..hf_config} 17 | 18 | tokenizer: 19 | _target_: transformers.AutoTokenizer.from_pretrained 20 | pretrained_model_name_or_path: "nferruz/ProtGPT2" 21 | padding_side: "left" 22 | pad_token: "<|endoftext|>" 23 | 24 | metric: 25 | _target_: transformers.AutoModelForSequenceClassification.from_pretrained 26 | pretrained_model_name_or_path: "Rostlab/prot_bert_bfd_membrane" 27 | 28 | save_testing_output: 29 | save_log_likelihood: False 30 | -------------------------------------------------------------------------------- /configs/model/scheduler_config/linear.yaml: -------------------------------------------------------------------------------- 1 | scheduler: 2 | _target_: transformers.get_linear_schedule_with_warmup 3 | num_warmup_steps: 1 4 | num_training_steps: 200 5 | 6 | scheduler_dict: 7 | scheduler: null # the schedule instance defined above – will be passed from the code (in configure_optimizer) 8 | interval: "epoch" # The unit of the scheduler's step size. 
'step' or 'epoch' 9 | frequency: 1 # corresponds to updating the learning rate after every `frequency` epoch/step 10 | name: LearningRateScheduler_${model.scheduler_config.scheduler._target_} 11 | 12 | -------------------------------------------------------------------------------- /configs/model/scheduler_config/polynomial.yaml: -------------------------------------------------------------------------------- 1 | scheduler: 2 | _target_: transformers.get_polynomial_decay_schedule_with_warmup 3 | num_warmup_steps: 500 4 | num_training_steps: 200000 5 | lr_end: 0.0 6 | 7 | scheduler_dict: 8 | scheduler: null # the schedule instance defined above – will be passed from the code (in configure_optimizer) 9 | interval: "step" # The unit of the scheduler's step size. 'step' or 'epoch' 10 | frequency: 1 # corresponds to updating the learning rate after every `frequency` epoch/step 11 | name: LearningRateScheduler_${model.scheduler_config.scheduler._target_} 12 | 13 | -------------------------------------------------------------------------------- /configs/model/scheduler_config/reduce_on_plateau.yaml: -------------------------------------------------------------------------------- 1 | scheduler: 2 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 3 | mode: "min" 4 | factor: 0.8 5 | patience: 2 6 | threshold: 0.001 7 | threshold_mode: "abs" 8 | cooldown: 1 9 | min_lr: 1e-6 10 | eps: 1e-8 11 | verbose: True 12 | 13 | scheduler_dict: 14 | scheduler: null # the schedule instance defined above – will be passed from the code (in configure_optimizer) 15 | interval: "epoch" # The unit of the scheduler's step size. 'step' or 'epoch' 16 | frequency: 1 # corresponds to updating the learning rate after every `frequency` epoch/step 17 | monitor: train/loss # Used by a LearningRateMonitor callback when ReduceLROnPlateau is used 18 | name: LearningRateScheduler_${model.scheduler_config.scheduler._target_} 19 | -------------------------------------------------------------------------------- /configs/model/tree_lm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - decoding: beam_search 3 | 4 | _target_: src.mdp.tree_lm.TreeLanguageModel 5 | 6 | tree: 7 | _target_: src.mdp.trees.ExplicitTree 8 | path: ??? # e.g. tests/toy_mdp_configs/greedy_fail_tm.txt 9 | 10 | config: 11 | _target_: src.mdp.tree_lm.TreeLanguageModelConfig 12 | 13 | # Note that this pattern will create a new object of tree.
See https://github.com/facebookresearch/hydra/issues/1393 14 | # An alternative solution is to pass the tree in the constructor 15 | tree: ${..tree} 16 | decoding: ${..decoding} 17 | -------------------------------------------------------------------------------- /configs/testing/cie_task.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /trainer: gpu 5 | - override /model: ckpt_genie_genre_r 6 | - override /datamodule: rebel 7 | 8 | - override /callbacks: null # set to null if you don't want to use callbacks 9 | - override /logger: null # set to null if you don't want to use any loggers 10 | - override /training: null 11 | 12 | - override /hydra: debug 13 | - _self_ 14 | 15 | datamodule: 16 | debug: True 17 | num_workers: 0 # debuggers don't like multiprocessing 18 | pin_memory: False # disable gpu memory pin 19 | 20 | model: 21 | hparams_overrides: 22 | decoding: 23 | free_generation: False # Set to true in development – loading the tries take time 24 | entity_trie_path: ${data_dir}/tries/large/entity_trie.pickle 25 | relation_trie_path: ${data_dir}/tries/large/relation_trie.pickle 26 | 27 | # set to True to run model evaluation (for cleaner logs, running evaluation using the evaluation script is preferred) 28 | test: False 29 | 30 | # verbose explanation: Hydra hijacks the working directory by changing it to the current log directory, 31 | # so it's useful to have this path as a special variable 32 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 33 | work_dir: ${hydra:runtime.cwd}/../ 34 | 35 | # path to folder with data 36 | data_dir: ${work_dir}/data/ 37 | 38 | # pretty print config at the start of the run using Rich library 39 | print_config: True 40 | 41 | # disable python warnings if they annoy you 42 | ignore_warnings: True 43 | 44 | # seed for random number generators in pytorch, numpy and python.random 45 | seed: 123 46 | 47 | # the experiment name – determines the logging folder's path 48 | run_name: test_cie 49 | -------------------------------------------------------------------------------- /configs/testing/stochastic_beams.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /trainer: cpu 5 | - override /model: tree_lm 6 | - override /model/decoding: stochastic_beams 7 | - override /datamodule: null 8 | 9 | - override /callbacks: null # set to null if you don't want to use callbacks 10 | - override /logger: null # set to null if you don't want to use any loggers 11 | - override /training: null 12 | 13 | - override /hydra: debug 14 | - _self_ 15 | #model.decoding.hf_generation_params: 16 | # encoder_no_repeat_ngram_size: 0 17 | # no_repeat_ngram_size: 0 18 | # 19 | # temperature: 1.0 20 | # length_penalty: 1.0 21 | # forced_bos_token_id: 0 22 | -------------------------------------------------------------------------------- /configs/testing/translation_task.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /trainer: gpu 5 | - override /model: mbart_cd 6 | - override /datamodule: wmt14 7 | 8 | - override /callbacks: null # set to null if you don't want to use callbacks 9 | - override /logger: null # set to null if you don't want to use any loggers 10 | - override /training: null 11 | 12 | - override /hydra: debug 13 | - _self_ 14 | 15 | datamodule: 16 
| debug: True 17 | num_workers: 0 # debuggers don't like multiprocessing 18 | pin_memory: False # disable gpu memory pin 19 | 20 | # set to True to run model evaluation (for cleaner logs, running evaluation using the evaluation script is preferred) 21 | test: False 22 | 23 | # verbose explanation: Hydra hijacks the working directory by changing it to the current log directory, 24 | # so it's useful to have this path as a special variable 25 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 26 | work_dir: ${hydra:runtime.cwd}/../ 27 | 28 | # path to folder with data 29 | data_dir: ${work_dir}/data/ 30 | 31 | # pretty print config at the start of the run using Rich library 32 | print_config: True 33 | 34 | # disable python warnings if they annoy you 35 | ignore_warnings: True 36 | 37 | # seed for random number generators in pytorch, numpy and python.random 38 | seed: 123 39 | 40 | # the experiment name – determines the logging folder's path 41 | run_name: test_translation 42 | -------------------------------------------------------------------------------- /configs/train_root.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - trainer: gpu 6 | - model: mbart_cd 7 | - datamodule: wmt14 8 | - callbacks: default # set to null if you don't want to use callbacks 9 | - logger: wandb # set to null if you don't want to use any loggers 10 | - hydra: training 11 | 12 | - training: null 13 | 14 | # debugging config (enable through command line, e.g. `python train.py debug=default) 15 | - debug: null 16 | 17 | # enable color logging 18 | - override hydra/hydra_logging: colorlog 19 | - override hydra/job_logging: colorlog 20 | 21 | # set to True to run model evaluation (for cleaner logs, running evaluation using the evaluation script is preferred) 22 | test: False 23 | 24 | # verbose explanation: Hydra hijacks the working directory by changing it to the current log directory, 25 | # so it's useful to have this path as a special variable 26 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 27 | work_dir: ${hydra:runtime.cwd} 28 | 29 | # path to folder with data 30 | data_dir: ${work_dir}/data/ 31 | 32 | # pretty print config at the start of the run using Rich library 33 | print_config: True 34 | 35 | # disable python warnings if they annoy you 36 | ignore_warnings: True 37 | 38 | # seed for random number generators in pytorch, numpy and python.random 39 | seed: 123 40 | 41 | # the experiment name – determines the logging folder's path 42 | run_name: ??? 
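train_root.yaml above only composes the pieces (trainer, model, datamodule, callbacks, logger, hydra); the actual work happens in src.training_pipeline.train, which run_training.py below calls with the composed config. A hedged sketch of what such an entry point typically does with this layout; the repo's implementation is the authoritative version and also wires callbacks, loggers, and checkpoint resumption:

    import pytorch_lightning as pl
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    def train(cfg: DictConfig) -> None:
        # Reproducibility: train_root.yaml sets `seed: 123` by default.
        pl.seed_everything(cfg.seed)

        # Every group config carries a `_target_`, so instantiation is uniform.
        datamodule = instantiate(cfg.datamodule)
        model = instantiate(cfg.model)
        trainer = instantiate(cfg.trainer)

        trainer.fit(model=model, datamodule=datamodule)
        if cfg.test:
            trainer.test(model=model, datamodule=datamodule)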
43 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 0 4 | 5 | accumulate_grad_batches: 1 6 | 7 | max_steps: 2000 8 | 9 | # val_check_interval can be a float 0.5 (2 times per training epoch) or an integer 1000 (every 1000 training steps) 10 | # Currently, if you use int for val_check_interval, it has to be smaller than the number of batches per epoch, but it is being addressed here: https://github.com/PyTorchLightning/pytorch-lightning/pull/11993 11 | # ToDo: Verify that it works once the update lands 12 | val_check_interval: 0.5 # 1000 13 | #val_check_interval: ${mult_int:${.accumulate_grad_batches}, 1000} 14 | 15 | weights_summary: null 16 | progress_bar_refresh_rate: 5 17 | resume_from_checkpoint: null 18 | 19 | gradient_clip_val: 0.1 20 | gradient_clip_algorithm: "norm" 21 | #log_every_n_steps: ${mult_int:${.accumulate_grad_batches}, 50} # log every 50 optimizer steps 22 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 4 4 | accelerator: "ddp" 5 | replace_sampler_ddp: True 6 | num_sanity_val_steps: 0 7 | 8 | max_steps: 200000 9 | val_check_interval: 40000 # 5000 * accumulate_grad_batch 10 | 11 | weights_summary: null 12 | progress_bar_refresh_rate: 5 13 | resume_from_checkpoint: null 14 | 15 | gradient_clip_val: 1 16 | gradient_clip_algorithm: "norm" 17 | 18 | accumulate_grad_batches: 4 19 | # See notes on effective learning rate in ddp: 20 | # https://github.com/epfl-dlab/dlab-wiki/blob/main/PyTorchLightning_trainer.md 21 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # set `1` to train on GPU, `0` to train on CPU only 4 | gpus: 1 5 | 6 | accumulate_grad_batches: 1 7 | 8 | max_steps: 2000 9 | 10 | # val_check_interval can be a float 0.5 (2 times per training epoch) or an integer 1000 (every 1000 training steps) 11 | # Currently, if you use int for val_check_interval, it has to be smaller than the number of batches per epoch, but it is being addressed here: https://github.com/PyTorchLightning/pytorch-lightning/pull/11993 12 | # ToDo: Verify that it works once the update lands 13 | val_check_interval: 0.5 # 1000 14 | #val_check_interval: ${mult_int:${.accumulate_grad_batches}, 1000} 15 | 16 | weights_summary: null 17 | progress_bar_refresh_rate: 5 18 | resume_from_checkpoint: null 19 | 20 | gradient_clip_val: 0.1 21 | gradient_clip_algorithm: "norm" 22 | #log_every_n_steps: ${mult_int:${.accumulate_grad_batches}, 50} # log every 50 optimizer steps 23 | -------------------------------------------------------------------------------- /detoxify/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import Detoxify 2 | -------------------------------------------------------------------------------- /detoxify/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/detoxify/src/__init__.py 
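The vendored detoxify package above exposes the Detoxify class that the toxicity metric and oracle configs point at with model_type: unbiased-small. A minimal usage sketch with the pip-installed detoxify==0.5.0 pinned further down; the vendored variant in detoxify/main.py has its own constructor and may differ in detail:

    from detoxify import Detoxify

    # "unbiased-small" matches model_type in configs/metric/detoxify.yaml;
    # pass device="cuda:0" as in the configs when a GPU is available.
    scorer = Detoxify("unbiased-small", device="cpu")
    scores = scorer.predict(["an innocuous example sentence"])
    print(scores["toxicity"])  # one toxicity probability per input sentence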
-------------------------------------------------------------------------------- /detoxify/src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | 5 | def move_to(obj, device): 6 | """Function to move objects of different types 7 | containing a tensor to device. 8 | """ 9 | if torch.is_tensor(obj): 10 | return obj.to(device) 11 | elif isinstance(obj, dict): 12 | res = {} 13 | for k, v in obj.items(): 14 | res[k] = move_to(v, device) 15 | return res 16 | elif isinstance(obj, list): 17 | res = [] 18 | for v in obj: 19 | res.append(move_to(v, device)) 20 | return res 21 | else: 22 | raise TypeError("Invalid type for move_to") 23 | 24 | 25 | def get_model_and_tokenizer(model_type, model_name, tokenizer_name, num_classes): 26 | model = getattr(transformers, model_name).from_pretrained(model_type, num_labels=num_classes) 27 | tokenizer = getattr(transformers, tokenizer_name).from_pretrained(model_type) 28 | 29 | return model, tokenizer -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Unless the data_directory is passed as an argument, download the files in the `data` directory 4 | path_to_data_dir=${1:-data} 5 | 6 | mkdir -p $path_to_data_dir 7 | cd $path_to_data_dir 8 | 9 | ################################# 10 | ##### Download Pre-Trained Models 11 | ################################# 12 | mkdir -p models 13 | cd models 14 | 15 | # Models initialized with a pretrained entity linker GENRE (GenIE - PLM) 16 | wget https://zenodo.org/record/6139236/files/genie_genre_r.ckpt # Trained on Rebel 17 | 18 | cd .. 19 | ################################# 20 | 21 | ################ 22 | # Download Tries 23 | ################ 24 | wget https://zenodo.org/record/6139236/files/tries.zip 25 | unzip tries.zip && rm tries.zip 26 | ################### 27 | 28 | ########################### 29 | # Download World Definitions 30 | ########################### 31 | wget https://zenodo.org/record/6139236/files/world_definitions.zip 32 | unzip world_definitions.zip && rm world_definitions.zip 33 | ################## 34 | 35 | ##################################### 36 | # Download Noisy Oracles for Toxicity 37 | ##################################### 38 | mkdir -p detoxify/noisy_oracles 39 | cd detoxify/noisy_oracles 40 | 41 | wget "https://huggingface.co/martinjosifoski/detoxify_noisy_oracles/resolve/main/N-Step-Checkpoint_epoch%3D0_global_step%3D200.ckpt" 42 | wget "https://huggingface.co/martinjosifoski/detoxify_noisy_oracles/resolve/main/N-Step-Checkpoint_epoch%3D0_global_step%3D400.ckpt" 43 | wget "https://huggingface.co/martinjosifoski/detoxify_noisy_oracles/resolve/main/N-Step-Checkpoint_epoch%3D0_global_step%3D800.ckpt" 44 | wget "https://huggingface.co/martinjosifoski/detoxify_noisy_oracles/resolve/main/N-Step-Checkpoint_epoch%3D0_global_step%3D1200.ckpt" 45 | ####################################### 46 | 47 | echo "The data was downloaded to '$path_to_data_dir'."
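After download_data.sh finishes, the data directory should contain the pretrained GenIE checkpoint, the unpacked tries and world definitions, and the four noisy-oracle detoxify checkpoints. An optional sanity check; the unpacked folder names for the two zip archives are assumptions (the tries path is the one referenced by configs/testing/cie_task.yaml):

    from pathlib import Path

    data_dir = Path("data")
    expected = [
        data_dir / "models" / "genie_genre_r.ckpt",
        data_dir / "tries",              # from tries.zip (referenced as ${data_dir}/tries/... in the configs)
        data_dir / "world_definitions",  # from world_definitions.zip (assumed folder name)
        data_dir / "detoxify" / "noisy_oracles",
    ]
    for path in expected:
        status = "ok" if path.exists() else "MISSING"
        print(f"{status:7s} {path}")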
48 | -------------------------------------------------------------------------------- /pip_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.3 2 | jsonlines==2.0.0 3 | torchmetrics==0.7.3 4 | pytorch-lightning==1.4.9 5 | pandas==1.3.2 6 | pywikibot==6.5.0 7 | wikitextparser==0.47.5 8 | wandb==0.12.2 9 | omegaconf==2.2.2 10 | hydra-core==1.2.0 11 | hydra-colorlog==1.1.0 12 | rich==10.10.0 13 | pre-commit==2.15.0 14 | matplotlib==3.4.2 15 | tqdm==4.62.2 16 | jupyter==1.0.0 17 | ipykernel==6.4.1 18 | jupyterlab==3.1.7 19 | einops==0.4.1 20 | scikit-learn==1.1.0 21 | pyyaml==6.0 22 | runai==0.4.1 23 | datasets==2.2.2 24 | detoxify==0.5.0 25 | pytest-pycharm==0.7.0 26 | pytest-subtests==0.8.0 27 | seaborn==0.11.2 28 | sacrebleu==2.2.0 29 | chex==0.1.3 30 | networkx==2.8 31 | scipy==1.8.1 32 | jax==0.3.15 33 | jaxlib==0.3.15 34 | sentencepiece==0.1.96 35 | transformers==4.10.0 36 | protobuf==3.20 37 | setuptools==59.5.0 38 | -------------------------------------------------------------------------------- /requirements.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python=3.8 7 | - pip 8 | - numpy 9 | - matplotlib 10 | - tqdm 11 | - jupyter 12 | - matplotlib 13 | - ipykernel 14 | - jupyterlab 15 | -------------------------------------------------------------------------------- /run_detoxify_evaluation.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import argparse 4 | import sys 5 | 6 | 7 | def evaluate_checkpoints(checkpoints, test_csv, num_gpus): 8 | with concurrent.futures.ThreadPoolExecutor(max_workers=num_gpus) as executor: 9 | for i, ckpt_path in enumerate(checkpoints): 10 | executor.submit(evaluate_checkpoint, ckpt_path, test_csv, f"cuda:{i % num_gpus}") 11 | 12 | 13 | def evaluate_checkpoint(checkpoint_path, test_csv, device): 14 | cmd = f"{sys.executable} detoxify/evaluate.py --checkpoint {checkpoint_path} --test_csv {test_csv} --device {device}" 15 | result_code = os.system(cmd) 16 | print("*" * 80, "\n Finished running command:", cmd, "\nResult code:", result_code, "\n") 17 | 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser(description="Evaluate .") 21 | parser.add_argument( 22 | "--checkpoints_folder", type=str, required=True, 23 | help="The directory containing the checkpoints to be evaluated." 24 | ) 25 | parser.add_argument( 26 | "--num_gpus", type=int, required=True, help="The directory containing the old checkpoint." 27 | ) 28 | parser.add_argument( 29 | "--test_csv", type=str, default="data/detoxify/test_public_expanded.csv", help="The path to the test dataset." 
30 | ) 31 | 32 | args = parser.parse_args() 33 | checkpoints_to_evaluate = [os.path.join(args.checkpoints_folder, path) 34 | for path in os.listdir(args.checkpoints_folder)] 35 | 36 | evaluate_checkpoints(checkpoints_to_evaluate, args.test_csv, args.num_gpus) 37 | -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | from src.utils import hydra_custom_resolvers 2 | import hydra 3 | from omegaconf import DictConfig 4 | #import wandb 5 | #wandb.init(project="my-test-project", entity="debjitpaul") 6 | 7 | # python run_evaluation.py evaluation= ckpt_path= run_name= 8 | 9 | 10 | @hydra.main(version_base="1.2", config_path="configs", config_name="evaluate_root") 11 | def main(hydra_config: DictConfig): 12 | # Imports should be nested inside @hydra.main to optimize tab completion 13 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 14 | 15 | import src.utils.general as utils 16 | from src.evaluation_pipeline import evaluate 17 | 18 | # Applies optional utilities: 19 | # - disabling python warnings 20 | # - prints config 21 | utils.extras(hydra_config) 22 | 23 | # Evaluate model 24 | evaluate(hydra_config) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /run_evaluation_from_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | import hydra 6 | 7 | 8 | # python run_evaluation_from_file.py --exp_dir "/home/martin_vm/understanding_decoding/logs/debug/runs/mbart_translation_v1/2022-05-24_11-02-22" --overrides evaluation_from_file=translation 9 | 10 | 11 | def main(overrides): 12 | # Imports should be nested inside @hydra.main to optimize tab completion 13 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 14 | with hydra.initialize(version_base="1.2", config_path="configs"): 15 | hydra_config = hydra.compose(config_name="evaluate_from_file_root", overrides=overrides) 16 | 17 | hydra.core.utils.configure_log(hydra_config.logger_config, False) 18 | 19 | # print(OmegaConf.to_yaml(hydra_config, resolve=True)) 20 | # # Imports should be nested inside @hydra.main to optimize tab completion 21 | # # Read more here: https://github.com/facebookresearch/hydra/issues/934 22 | # 23 | import src.utils.general as utils 24 | from src.evaluation_from_file_pipeline import evaluate_from_file 25 | 26 | # Applies optional utilities: 27 | # - disabling python warnings 28 | # - prints config 29 | utils.extras(hydra_config) 30 | 31 | # Run evaluation on an output file 32 | evaluate_from_file(hydra_config) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser(description="Run evaluation on the output from a file.") 37 | parser.add_argument("--wandb_run_path", type=str, required=False, 38 | help="The path to the wandb run that produced the output_files. " 39 | "Specify either the `wandb_run_path` or `exp_dir`.") 40 | 41 | parser.add_argument("--exp_dir", type=str, required=False, 42 | help="The path to the experiment log files. 
" 43 | "Specify either the `wandb_run_path` or `exp_dir`.") 44 | parser.add_argument("--overrides", type=str, nargs="+", help="Space separated overrides") 45 | args = parser.parse_args() 46 | 47 | if args.overrides is not None: 48 | overrides = args.overrides 49 | else: 50 | overrides = [] 51 | 52 | if args.wandb_run_path is not None: 53 | wandb_run_path = args.wandb_run_path 54 | else: 55 | assert args.exp_dir is not None 56 | 57 | # Extract the wandb_run_path (i.e. the unique wandb run identifier) from the correct log file in exp_dir 58 | wandb_debug_file_path = os.path.join(args.exp_dir, "wandb", "latest-run", "logs", "debug.log") 59 | 60 | KEYWORD = "finishing run" 61 | with open(wandb_debug_file_path, "r") as file: 62 | relevant_lines = [line for line in file if re.search(KEYWORD, line)] 63 | 64 | assert len(relevant_lines) == 1 65 | relevant_line = relevant_lines[0] 66 | wandb_run_path = relevant_line[relevant_line.find(KEYWORD) + len(KEYWORD):].strip() 67 | 68 | print(wandb_run_path) 69 | overrides.append(f"wandb_run_path={wandb_run_path}") 70 | overrides.append(f"work_dir={os.path.abspath(os.getcwd())}") 71 | overrides.append(f"data_dir={os.path.abspath(os.getcwd())}/data/") 72 | print(f"overrides={overrides}") 73 | main(overrides) 74 | -------------------------------------------------------------------------------- /run_hparam_search.py: -------------------------------------------------------------------------------- 1 | from src.utils import hydra_custom_resolvers 2 | import hydra 3 | from omegaconf import OmegaConf, DictConfig 4 | 5 | 6 | @hydra.main(version_base="1.2", config_path="configs", config_name="hparam_search_root") 7 | def main(hydra_config: DictConfig): 8 | # Imports should be nested inside @hydra.main to optimize tab completion 9 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 10 | 11 | import src.utils.general as utils 12 | from src.hparam_search_pipeline import hparam_search 13 | 14 | # Applies optional utilities: 15 | # - disabling python warnings 16 | # - prints config 17 | utils.extras(hydra_config) 18 | 19 | # Run the hyperparameter search model 20 | hparam_search(hydra_config) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | from src.utils import hydra_custom_resolvers 2 | import hydra 3 | from omegaconf import DictConfig 4 | 5 | 6 | # python run_training.py training= ckpt_path= run_name= 7 | 8 | 9 | @hydra.main(version_base="1.2", config_path="configs", config_name="train_root") 10 | def main(hydra_config: DictConfig): 11 | # Imports should be nested inside @hydra.main to optimize tab completion 12 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 13 | 14 | import src.utils.general as utils 15 | from src.training_pipeline import train 16 | 17 | # Applies optional utilities: 18 | # - disabling python warnings 19 | # - prints config 20 | utils.extras(hydra_config) 21 | 22 | # Train model 23 | train(hydra_config) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/scripts/__init__.py 
-------------------------------------------------------------------------------- /scripts/get_detoxify_evaluation_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | 5 | import json 6 | import os 7 | 8 | 9 | def get_model_name(file_name): 10 | parts = file_name.split("_") 11 | for part in parts: 12 | if part.startswith("step"): 13 | return part.split(".")[0] 14 | 15 | if "unbiased-albert" in file_name: 16 | return "oracle_unbiased-albert" 17 | 18 | raise Exception("Unexpected file name") 19 | 20 | 21 | def get_results(path): 22 | with open(path) as f: 23 | results = json.load(f) 24 | return results 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser(description="Aggregate the detoxify noisy-oracle evaluation results into a single CSV.") 29 | parser.add_argument( 30 | "--results_folder", type=str, required=True, 31 | help="The directory containing the JSON files from the evaluations." 32 | ) 33 | args = parser.parse_args() 34 | 35 | model_name2results = {get_model_name(file_name): get_results(os.path.join(args.results_folder, file_name)) 36 | for file_name in os.listdir(args.results_folder) 37 | if file_name.startswith('results_N') or file_name.startswith("results_unbiased")} 38 | 39 | column_names = model_name2results['oracle_unbiased-albert'].keys() 40 | data = {model_name: [results[key] for key in column_names] for model_name, results in model_name2results.items()} 41 | results = pd.DataFrame.from_dict(data, orient='index', columns=column_names) 42 | results = results.iloc[np.argsort( 43 | list(map(lambda x: 10e20 if "oracle" in x else int(x.split("=")[1]), list(results.index.to_numpy()))))] 44 | results.to_csv("detoxify_noisy_oracle_results.csv", index=False) 45 | -------------------------------------------------------------------------------- /scripts/get_genie_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import torch 4 | 5 | sys.path.append("..") 6 | 7 | from src.models.collators import RebelCollator 8 | from pathlib import Path 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="Convert old GenIE checkpoint to new setting.") 12 | parser.add_argument( 13 | "--input_ckpt_dir", type=str, required=True, help="The directory containing the old checkpoint." 14 | ) 15 | parser.add_argument("--output_ckpt_dir", type=str, required=True, help="The directory for the new checkpoint.") 16 | parser.add_argument( 17 | "--ckpt_name", 18 | type=str, 19 | default="genie_genre_r.ckpt", 20 | help="The full checkpoint name (e.g.
genie_genre_r.ckpt).", 21 | ) 22 | args = parser.parse_args() 23 | 24 | model = torch.load(f"{args.input_ckpt_dir}/{args.ckpt_name}") 25 | hparams = model["hyper_parameters"] 26 | 27 | hparams["decoding"] = hparams.pop("inference") 28 | hparams["from_checkpoint"] = True 29 | hparams["random_initialization"] = False 30 | hparams.pop("model_name_or_path", None) 31 | hparams["pretrained_model_name_or_path"] = "martinjosifoski/genie-rw" 32 | hparams["hf_config"].update( 33 | { 34 | "min_length": 0, 35 | "max_length": 256, 36 | "early_stopping": False, 37 | "encoder_no_repeat_ngram_size": 0, 38 | "no_repeat_ngram_size": 0, 39 | "temperature": 1.0, 40 | "length_penalty": 1.0, 41 | "forced_bos_token_id": 0, 42 | } 43 | ) 44 | hparams["tokenizer"].model_max_length = 256 45 | 46 | collator = RebelCollator( 47 | hparams["tokenizer"], 48 | max_input_length=hparams["max_input_length"], 49 | max_output_length=hparams["max_output_length"], 50 | padding=True, 51 | truncation=True, 52 | ) 53 | hparams["collator"] = collator 54 | Path(args.output_ckpt_dir).mkdir(exist_ok=True, parents=True) 55 | torch.save(model, f"{args.output_ckpt_dir}/{args.ckpt_name}") 56 | print("Checkpoint saved at:", f"{args.output_ckpt_dir}/{args.ckpt_name}") 57 | -------------------------------------------------------------------------------- /scripts/launch_evaluation_from_file.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | #PRINT=true 6 | PRINT=false 7 | 8 | # Greedy + BS 9 | array=( "epfl-dlab/understanding-decoding/3410qxd0" "epfl-dlab/understanding-decoding/1y6yme2k" ) 10 | array2=( "translation" "translation" ) 11 | 12 | # Greedy + BS 13 | array=( "epfl-dlab/understanding-decoding/34kq9nsf" "epfl-dlab/understanding-decoding/6yf8p424" ) 14 | array2=( "toxicity" "toxicity" ) 15 | 16 | # Greedy + BS 17 | array=( "epfl-dlab/understanding-decoding/28vcj2b0" "epfl-dlab/understanding-decoding/9lbtlj0l") 18 | array2=( "cie" "cie") 19 | 20 | # Greedy + BS 21 | array=( "epfl-dlab/understanding-decoding/2fhbfm2e" "epfl-dlab/understanding-decoding/1k2x9th9" ) 22 | array2=( "solubility" "solubility") 23 | 24 | # Stochastic Beams (all tasks) 25 | array=("epfl-dlab/understanding-decoding/gigllz7f" "epfl-dlab/understanding-decoding/2mgcqsyk" "epfl-dlab/understanding-decoding/2cqpyxit" "epfl-dlab/understanding-decoding/3vh6a5h8" ) 26 | array2=( "translation" "toxicity" "cie" "solubility" ) 27 | 28 | # MT + VGBS 29 | array=( "epfl-dlab/understanding-decoding/2xogfbel" "epfl-dlab/understanding-decoding/38vrpbjt" "epfl-dlab/understanding-decoding/2gchwiy1" "epfl-dlab/understanding-decoding/3jtpx0zz" "epfl-dlab/understanding-decoding/3g2q6r6x" "epfl-dlab/understanding-decoding/yqyctpgx" ) 30 | array2=( "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" ) 31 | 32 | # MT + MCTS 33 | array=( "epfl-dlab/understanding-decoding/2xdt1zgn" "epfl-dlab/understanding-decoding/7y0pr9sw" "epfl-dlab/understanding-decoding/13niq0lx" "epfl-dlab/understanding-decoding/uykq2pgp" "epfl-dlab/understanding-decoding/37wx7n8u" "epfl-dlab/understanding-decoding/2yc65rt8" ) 34 | array2=( "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" "translation_bootstrap" ) 35 | 36 | # Toxicity + VGBS 37 | array=( "epfl-dlab/understanding-decoding/3a5ydpr4" "epfl-dlab/understanding-decoding/2lveviea" 
"epfl-dlab/understanding-decoding/2an1uakn" "epfl-dlab/understanding-decoding/1yc972vw" ) 38 | array2=( "toxicity_bootstrap" "toxicity_bootstrap" "toxicity_bootstrap" "toxicity_bootstrap" ) 39 | 40 | # Toxicity + MCTS 41 | array=( "epfl-dlab/understanding-decoding/ilkuz1w4" "epfl-dlab/understanding-decoding/2ah4esvq" "epfl-dlab/understanding-decoding/1jl4dhvh" "epfl-dlab/understanding-decoding/2flwrn8g" ) 42 | array2=( "toxicity_bootstrap" "toxicity_bootstrap" "toxicity_bootstrap" "toxicity_bootstrap" ) 43 | 44 | for i in "${!array[@]}"; 45 | do 46 | if [ $PRINT == true ] 47 | then 48 | echo python -m run_evaluation_from_file --wandb_run_path ${array[i]} --overrides evaluation_from_file=${array2[i]} 49 | else 50 | python -m run_evaluation_from_file --wandb_run_path ${array[i]} --overrides evaluation_from_file=${array2[i]} 51 | fi 52 | done 53 | -------------------------------------------------------------------------------- /scripts/launchers/mt_mcts_run1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3'" # More GPUs can be added if available. 17 | NUM_THREADS=8 # Can be increases to the number of GPUs or more if the GPUs fit more models at once. 18 | DATAMODULE_NUM_WORKERS=2 19 | BATCH_SIZE=1 # Can be increases as much as the GPUs / compute allows. 20 | 21 | # Task Specific Parameters 22 | EVALUATION_MODEL="mt_noisy_oracle_b2" 23 | 24 | # Experiment Parameters 25 | LAMBDA=0.01 26 | PB_C_INIT=1.25 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 36 | trainer.progress_bar_refresh_rate=1 \ 37 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 38 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 39 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 43 | else 44 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 45 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 46 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 49 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 50 | trainer.progress_bar_refresh_rate=1 \ 51 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 52 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 53 | 
model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 54 | logger=$LOGGER \ 55 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 56 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 57 | fi 58 | -------------------------------------------------------------------------------- /scripts/launchers/mt_mcts_run2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3'" # More GPUs can be added if available. 17 | NUM_THREADS=8 # Can be increases to the number of GPUs or more if the GPUs fit more models at once. 18 | DATAMODULE_NUM_WORKERS=2 19 | BATCH_SIZE=1 # Can be increases as much as the GPUs / compute allows. 20 | 21 | # Task Specific Parameters 22 | EVALUATION_MODEL="mt_noisy_oracle_b2" 23 | 24 | # Experiment Parameters 25 | LAMBDA=0.15 26 | PB_C_INIT=0.25 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 36 | trainer.progress_bar_refresh_rate=1 \ 37 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 38 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 39 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 43 | else 44 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 45 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 46 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 49 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 50 | trainer.progress_bar_refresh_rate=1 \ 51 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 52 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 53 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 54 | logger=$LOGGER \ 55 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 56 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 57 | fi 58 | -------------------------------------------------------------------------------- /scripts/launchers/mt_mcts_run3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an 
effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3'" # More GPUs can be added if available. 17 | NUM_THREADS=8 # Can be increases to the number of GPUs or more if the GPUs fit more models at once. 18 | DATAMODULE_NUM_WORKERS=2 19 | BATCH_SIZE=1 # Can be increases as much as the GPUs / compute allows. 20 | 21 | # Task Specific Parameters 22 | EVALUATION_MODEL="mt_noisy_oracle_b2" 23 | 24 | # Experiment Parameters 25 | LAMBDA=0.25 26 | PB_C_INIT=0.25 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 36 | trainer.progress_bar_refresh_rate=1 \ 37 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 38 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 39 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 43 | else 44 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 45 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 46 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 49 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 50 | trainer.progress_bar_refresh_rate=1 \ 51 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 52 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 53 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 54 | logger=$LOGGER \ 55 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 56 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 57 | fi 58 | -------------------------------------------------------------------------------- /scripts/launchers/mt_mcts_run4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3'" # More GPUs can be added if available. 17 | NUM_THREADS=8 # Can be increases to the number of GPUs or more if the GPUs fit more models at once. 18 | DATAMODULE_NUM_WORKERS=2 19 | BATCH_SIZE=1 # Can be increases as much as the GPUs / compute allows. 
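# Note: the six mt_mcts_run*.sh launchers appear to differ only in the experiment parameters set
# below -- LAMBDA, the noise level applied to the simulated MT oracle
# (evaluation_model.noising_function_parameters.lambda), and PB_C_INIT, the exploration constant
# forwarded to the MCTS decoder (mcts_pb_c_init); LAMBDA ranges from 0.01 (run1) to 0.99 (run6).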
20 | 21 | # Task Specific Parameters 22 | EVALUATION_MODEL="mt_noisy_oracle_b2" 23 | 24 | # Experiment Parameters 25 | LAMBDA=0.35 26 | PB_C_INIT=0.5 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 36 | trainer.progress_bar_refresh_rate=1 \ 37 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 38 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 39 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 43 | else 44 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 45 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 46 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 49 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 50 | trainer.progress_bar_refresh_rate=1 \ 51 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 52 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 53 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 54 | logger=$LOGGER \ 55 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 56 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 57 | fi 58 | -------------------------------------------------------------------------------- /scripts/launchers/mt_mcts_run5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3'" # More GPUs can be added if available. 17 | NUM_THREADS=8 # Can be increases to the number of GPUs or more if the GPUs fit more models at once. 18 | DATAMODULE_NUM_WORKERS=2 19 | BATCH_SIZE=1 # Can be increases as much as the GPUs / compute allows. 
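# Note: the trainer overrides used below (trainer=ddp trainer.gpus=0 trainer.accelerator='cpu'
# +trainer.devices=$NUM_THREADS +model.scatter_accross_gpus=True) appear to start NUM_THREADS
# CPU-side DDP workers and leave GPU placement to the model itself, which is why NUM_THREADS may
# exceed the number of GPUs in CUDA_VISIBLE_DEVICES when several models fit on one device.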
20 | 21 | # Task Specific Parameters 22 | EVALUATION_MODEL="mt_noisy_oracle_b2" 23 | 24 | # Experiment Parameters 25 | LAMBDA=0.5 26 | PB_C_INIT=0.5 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 36 | trainer.progress_bar_refresh_rate=1 \ 37 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 38 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 39 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 43 | else 44 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 45 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 46 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 49 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 50 | trainer.progress_bar_refresh_rate=1 \ 51 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 52 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 53 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 54 | logger=$LOGGER \ 55 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 56 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 57 | fi 58 | -------------------------------------------------------------------------------- /scripts/launchers/mt_mcts_run6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3'" # More GPUs can be added if available. 17 | NUM_THREADS=8 # Can be increases to the number of GPUs or more if the GPUs fit more models at once. 18 | DATAMODULE_NUM_WORKERS=2 19 | BATCH_SIZE=1 # Can be increases as much as the GPUs / compute allows. 
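# Note: the executing branch below prefixes the command with TOKENIZERS_PARALLELISM='false' to
# silence the Hugging Face tokenizers fork warning and avoid parallelism issues when the tokenizer
# is used from multiple dataloader worker processes.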
20 | 21 | # Task Specific Parameters 22 | EVALUATION_MODEL="mt_noisy_oracle_b2" 23 | 24 | # Experiment Parameters 25 | LAMBDA=0.99 26 | PB_C_INIT=1.25 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 36 | trainer.progress_bar_refresh_rate=1 \ 37 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 38 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 39 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 43 | else 44 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=mbart_translation model/decoding=[mbart_generic,pplmcts] \ 45 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 46 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 49 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 50 | trainer.progress_bar_refresh_rate=1 \ 51 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 52 | evaluation_model.noising_function_parameters.lambda=$LAMBDA \ 53 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 54 | logger=$LOGGER \ 55 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 56 | run_name=mbart_translation_mcts_lambda_${LAMBDA}_pb_c_init_${PB_C_INIT} 57 | fi 58 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_mcts_run1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=96 # Doesn't have an effect if DEBUG is false. 
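# Note: the toxicity_mcts_run*.sh launchers appear to share this setup and vary only the
# EVALUATION_MODEL (the exact detoxify_oracle here vs. noisy detoxify_noisy_oracle_rmse_* variants
# in runs 2-4) and the MCTS exploration constant PB_C_INIT (0.25 here, 1.25 in the other runs).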
13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=1 18 | BATCH_SIZE=4 19 | EVALUATION_MODEL_BATCH_SIZE=20 20 | 21 | NUM_THREADS=16 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_oracle" 25 | PB_C_INIT=0.25 26 | 27 | if [ $PRINT == true ] 28 | then 29 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 30 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 31 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 32 | evaluation_model=$EVALUATION_MODEL \ 33 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | evaluation_model.device=cpu \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 38 | trainer.progress_bar_refresh_rate=1 \ 39 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 40 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 41 | logger=$LOGGER \ 42 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 43 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 44 | else 45 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 46 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 47 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 48 | evaluation_model=$EVALUATION_MODEL \ 49 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 50 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 51 | evaluation_model.device=cpu \ 52 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 53 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 54 | trainer.progress_bar_refresh_rate=1 \ 55 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 56 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 57 | logger=$LOGGER \ 58 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 59 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 60 | fi 61 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_mcts_run2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=96 # Doesn't have an effect if DEBUG is false. 
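# Note: the rmse_01976 suffix presumably encodes the RMSE (~0.1976) of this noisy value model with
# respect to the exact Detoxify scores; the rmse_02165 and rmse_02242 variants used by the sibling
# launchers are correspondingly noisier approximations of the oracle.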
13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=1 18 | BATCH_SIZE=4 19 | EVALUATION_MODEL_BATCH_SIZE=20 20 | 21 | NUM_THREADS=16 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_noisy_oracle_rmse_01976" 25 | PB_C_INIT=1.25 26 | 27 | if [ $PRINT == true ] 28 | then 29 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 30 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 31 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 32 | evaluation_model=$EVALUATION_MODEL \ 33 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | evaluation_model.device=cpu \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 38 | trainer.progress_bar_refresh_rate=1 \ 39 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 40 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 41 | logger=$LOGGER \ 42 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 43 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 44 | else 45 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 46 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 47 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 48 | evaluation_model=$EVALUATION_MODEL \ 49 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 50 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 51 | evaluation_model.device=cpu \ 52 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 53 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 54 | trainer.progress_bar_refresh_rate=1 \ 55 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 56 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 57 | logger=$LOGGER \ 58 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 59 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 60 | fi 61 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_mcts_run3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=96 # Doesn't have an effect if DEBUG is false. 
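# Note: ABSOLUTE_PATH_TO_OLD_EXP_DIR below points at an earlier run of this exact configuration;
# it does not appear to be referenced by the command and is presumably kept only for reference.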
13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=1 18 | BATCH_SIZE=4 19 | EVALUATION_MODEL_BATCH_SIZE=20 20 | 21 | NUM_THREADS=16 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_noisy_oracle_rmse_02242" 25 | PB_C_INIT=1.25 26 | ABSOLUTE_PATH_TO_OLD_EXP_DIR="/vc_data_1/users/bapatra/code/understanding-decoding/logs/evaluation/runs/gpt2_toxicity_mcts_em_detoxify_noisy_oracle_rmse_02242_pb_c_init_1.25/2022-07-30_18-04-31" 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 31 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 32 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 35 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 36 | evaluation_model.device=cpu \ 37 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 38 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 39 | trainer.progress_bar_refresh_rate=1 \ 40 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 41 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 42 | logger=$LOGGER \ 43 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 44 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 45 | else 46 | TOKENIZERS_PARALLELISM='false' python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 47 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 48 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 49 | evaluation_model=$EVALUATION_MODEL \ 50 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 51 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 52 | evaluation_model.device=cpu \ 53 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 54 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 55 | trainer.progress_bar_refresh_rate=1 \ 56 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 57 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 58 | logger=$LOGGER \ 59 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 60 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 61 | fi 62 | 63 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_mcts_run4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=96 # Doesn't have an effect if DEBUG is false. 
13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=1 18 | BATCH_SIZE=4 19 | EVALUATION_MODEL_BATCH_SIZE=20 20 | 21 | NUM_THREADS=16 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_noisy_oracle_rmse_02165" 25 | PB_C_INIT=1.25 26 | 27 | if [ $PRINT == true ] 28 | then 29 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 30 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 31 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 32 | evaluation_model=$EVALUATION_MODEL \ 33 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | evaluation_model.device=cpu \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 38 | trainer.progress_bar_refresh_rate=1 \ 39 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 40 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 41 | logger=$LOGGER \ 42 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 43 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 44 | else 45 | python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,pplmcts] \ 46 | model.decoding.hf_generation_params.mcts_topk_actions=20 \ 47 | model.decoding.hf_generation_params.mcts_num_simulations=50 \ 48 | evaluation_model=$EVALUATION_MODEL \ 49 | evaluation_model.device=cpu \ 50 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 51 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 52 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 53 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 54 | trainer.progress_bar_refresh_rate=1 \ 55 | trainer=ddp trainer.gpus=0 +trainer.devices=$NUM_THREADS trainer.accelerator='cpu' +model.scatter_accross_gpus=True \ 56 | model.decoding.hf_generation_params.mcts_pb_c_init=$PB_C_INIT \ 57 | logger=$LOGGER \ 58 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 59 | run_name=gpt2_toxicity_mcts_em_${EVALUATION_MODEL}_pb_c_init_${PB_C_INIT} 60 | fi 61 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_vgbs_run1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=3 18 | BATCH_SIZE=6 # Increase if there is more memory, decrease if it is too much. 
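# Note: the factor of 100 below presumably reflects that value-guided beam search scores up to
# num_beams (5) x tokens_considered_by_value_processor (20) = 100 candidate continuations per input
# sequence at every step, so the evaluation model processes roughly 100x the dataloader batch size.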
19 | EVALUATION_MODEL_BATCH_SIZE=$((BATCH_SIZE * 100)) 20 | 21 | NUM_GPUS=8 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_oracle" 25 | CONTRIBUTION_FACTOR=0.25 26 | 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 31 | model.decoding.hf_generation_params.num_beams=5 \ 32 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 38 | trainer.progress_bar_refresh_rate=1 \ 39 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 40 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 41 | logger=$LOGGER \ 42 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 43 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 44 | else 45 | python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 46 | model.decoding.hf_generation_params.num_beams=5 \ 47 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 48 | evaluation_model=$EVALUATION_MODEL \ 49 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 50 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 51 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 52 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 53 | trainer.progress_bar_refresh_rate=1 \ 54 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 55 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 56 | logger=$LOGGER \ 57 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 58 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 59 | fi -------------------------------------------------------------------------------- /scripts/launchers/toxicity_vgbs_run2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=3 18 | BATCH_SIZE=6 # Increase if there is more memory, decrease if it is too much. 
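# Note: CONTRIBUTION_FACTOR (fixed at 0.25 across toxicity_vgbs_run1-4.sh) appears to control how
# strongly the value model's toxicity estimate is mixed into the beam scores relative to the
# language-model log-probabilities; these four launchers otherwise differ mainly in EVALUATION_MODEL.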
19 | EVALUATION_MODEL_BATCH_SIZE=$((BATCH_SIZE * 100)) 20 | 21 | NUM_GPUS=8 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_noisy_oracle_rmse_01976" 25 | CONTRIBUTION_FACTOR=0.25 26 | 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 31 | model.decoding.hf_generation_params.num_beams=5 \ 32 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | trainer.progress_bar_refresh_rate=1 \ 38 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 39 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 43 | else 44 | python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 45 | model.decoding.hf_generation_params.num_beams=5 \ 46 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 49 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 50 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 51 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 52 | trainer.progress_bar_refresh_rate=1 \ 53 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 54 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 55 | logger=$LOGGER \ 56 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 57 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 58 | fi 59 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_vgbs_run3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=3 18 | BATCH_SIZE=6 # Increase if there is more memory, decrease if it is too much. 
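# Note: in this launcher (as in runs 2 and 4) the PRINT branch appears to omit the
# +datamodule.dataset_parameters.test.dataset.subsample=True override that the executing branch
# passes, so the echoed command differs slightly from the one that is actually run.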
19 | EVALUATION_MODEL_BATCH_SIZE=$((BATCH_SIZE * 100)) 20 | 21 | NUM_GPUS=8 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_noisy_oracle_rmse_02242" 25 | CONTRIBUTION_FACTOR=0.25 26 | 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 31 | model.decoding.hf_generation_params.num_beams=5 \ 32 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | trainer.progress_bar_refresh_rate=1 \ 38 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 39 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 43 | else 44 | python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 45 | model.decoding.hf_generation_params.num_beams=5 \ 46 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 49 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 50 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 51 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 52 | trainer.progress_bar_refresh_rate=1 \ 53 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 54 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 55 | logger=$LOGGER \ 56 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 57 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 58 | fi 59 | -------------------------------------------------------------------------------- /scripts/launchers/toxicity_vgbs_run4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | # General parameters 6 | #PRINT=true 7 | PRINT=false 8 | 9 | #DEBUG=true 10 | DEBUG=false 11 | 12 | NUM_DATAPOINTS=36 # Doesn't have an effect if DEBUG is false. 13 | LOGGER=wandb_group 14 | 15 | # GPUs and Multiprocessing 16 | VISIBLE_GPUS_STRING="'0,1,2,3,4,5,6,7'" 17 | DATAMODULE_NUM_WORKERS=3 18 | BATCH_SIZE=6 # Increase if there is more memory, decrease if it is too much. 
19 | EVALUATION_MODEL_BATCH_SIZE=$((BATCH_SIZE * 100)) 20 | 21 | NUM_GPUS=8 22 | 23 | # Experiment Parameters 24 | EVALUATION_MODEL="detoxify_noisy_oracle_rmse_02165" 25 | CONTRIBUTION_FACTOR=0.25 26 | 27 | 28 | if [ $PRINT == true ] 29 | then 30 | echo python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 31 | model.decoding.hf_generation_params.num_beams=5 \ 32 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 33 | evaluation_model=$EVALUATION_MODEL \ 34 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 35 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 36 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 37 | trainer.progress_bar_refresh_rate=1 \ 38 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 39 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 40 | logger=$LOGGER \ 41 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 42 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 43 | else 44 | python -m run_evaluation evaluation=gpt2_toxicgen model/decoding=[gpt_generic,value_guided_beam_search] \ 45 | model.decoding.hf_generation_params.num_beams=5 \ 46 | model.decoding.hf_generation_params.tokens_considered_by_value_processor=20 \ 47 | evaluation_model=$EVALUATION_MODEL \ 48 | evaluation_model.batch_size=$EVALUATION_MODEL_BATCH_SIZE \ 49 | datamodule.dataset_parameters.test.dataloader.batch_size=$BATCH_SIZE \ 50 | datamodule.debug=$DEBUG datamodule.debug_k=$NUM_DATAPOINTS datamodule.num_workers=$DATAMODULE_NUM_WORKERS \ 51 | +datamodule.dataset_parameters.test.dataset.subsample=True \ 52 | trainer.progress_bar_refresh_rate=1 \ 53 | trainer=ddp trainer.gpus=$NUM_GPUS +model.scatter_accross_gpus=True \ 54 | model.decoding.hf_generation_params.contribution_factor=$CONTRIBUTION_FACTOR \ 55 | logger=$LOGGER \ 56 | +hydra.job.env_set.CUDA_VISIBLE_DEVICES=$VISIBLE_GPUS_STRING \ 57 | run_name=gpt2_toxicity_vgbs_em_${EVALUATION_MODEL}_cf_${CONTRIBUTION_FACTOR} 58 | fi -------------------------------------------------------------------------------- /scripts/plotting/figure_2_RQ1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | PRINT=false 6 | 7 | GRIDSIZE=10 8 | 9 | ENTITY=epfl-dlab 10 | ROOT=epfl-dlab/understanding-decoding 11 | VERSION=paper-v1 12 | 13 | datasets=("rtp" "wmt14" "rebel" "swissprot") 14 | paths=( 15 | "$ROOT/6yf8p424 $ROOT/34kq9nsf $ROOT/2mgcqsyk" 16 | "$ROOT/1y6yme2k $ROOT/3410qxd0 $ROOT/gigllz7f" 17 | "$ROOT/9lbtlj0l $ROOT/28vcj2b0 $ROOT/2cqpyxit" 18 | "$ROOT/1k2x9th9 $ROOT/2fhbfm2e $ROOT/3vh6a5h8" 19 | ) 20 | 21 | for i in "${!datasets[@]}"; 22 | do 23 | if [ $PRINT == true ] 24 | then 25 | echo python scripts/visualize_joint_plots.py \ 26 | --dataset ${datasets[i]} \ 27 | --gridsize $GRIDSIZE \ 28 | --share_colorbar \ 29 | --wandb_entity_for_report $ENTITY \ 30 | --wandb_run_paths ${paths[i]} \ 31 | --wandb_id_for_report Fig2-$VERSION 32 | else 33 | python scripts/visualize_joint_plots.py \ 34 | --dataset ${datasets[i]} \ 35 | --gridsize $GRIDSIZE \ 36 | --share_colorbar \ 37 | --wandb_entity_for_report $ENTITY \ 38 | --wandb_run_paths ${paths[i]} \ 39 | --wandb_id_for_report Fig2-$VERSION 40 | fi 41 | done 42 | -------------------------------------------------------------------------------- 
/scripts/plotting/figure_2_RQ1_unnormalized.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | PRINT=false 6 | 7 | GRIDSIZE=10 8 | 9 | ENTITY=epfl-dlab 10 | ROOT=epfl-dlab/understanding-decoding 11 | VERSION=paper-unnomralized-v1 12 | 13 | datasets=("rtp" "wmt14" "rebel" "swissprot") 14 | paths=( 15 | "$ROOT/6yf8p424 $ROOT/34kq9nsf $ROOT/2mgcqsyk" 16 | "$ROOT/1y6yme2k $ROOT/3410qxd0 $ROOT/gigllz7f" 17 | "$ROOT/9lbtlj0l $ROOT/28vcj2b0 $ROOT/2cqpyxit" 18 | "$ROOT/1k2x9th9 $ROOT/2fhbfm2e $ROOT/3vh6a5h8" 19 | ) 20 | 21 | for i in "${!datasets[@]}"; 22 | do 23 | if [ $PRINT == true ] 24 | then 25 | echo python scripts/visualize_joint_plots.py \ 26 | --dataset ${datasets[i]} \ 27 | --gridsize $GRIDSIZE \ 28 | --share_colorbar \ 29 | --wandb_entity_for_report $ENTITY \ 30 | --wandb_run_paths ${paths[i]} \ 31 | --wandb_id_for_report Fig2-$VERSION \ 32 | --no_target_normalization 33 | else 34 | python scripts/visualize_joint_plots.py \ 35 | --dataset ${datasets[i]} \ 36 | --gridsize $GRIDSIZE \ 37 | --share_colorbar \ 38 | --wandb_entity_for_report $ENTITY \ 39 | --wandb_run_paths ${paths[i]} \ 40 | --wandb_id_for_report Fig2-$VERSION \ 41 | --no_target_normalization 42 | fi 43 | done 44 | -------------------------------------------------------------------------------- /scripts/plotting/figure_3_RQ1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | PRINT=false 6 | 7 | GRIDSIZE=10 8 | 9 | ENTITY=epfl-dlab 10 | ROOT=epfl-dlab/understanding-decoding 11 | VERSION=paper-v1 12 | 13 | datasets=("rtp" "wmt14" "rebel" "swissprot") 14 | paths=( 15 | "$ROOT/6yf8p424 $ROOT/34kq9nsf $ROOT/2mgcqsyk" 16 | "$ROOT/1y6yme2k $ROOT/3410qxd0 $ROOT/gigllz7f" 17 | "$ROOT/9lbtlj0l $ROOT/28vcj2b0 $ROOT/2cqpyxit" 18 | "$ROOT/1k2x9th9 $ROOT/2fhbfm2e $ROOT/3vh6a5h8" 19 | ) 20 | 21 | for i in "${!datasets[@]}"; 22 | do 23 | if [ $PRINT == true ] 24 | then 25 | echo python scripts/visualize_joint_plots.py \ 26 | --dataset ${datasets[i]} \ 27 | --plot_correlation \ 28 | --gridsize $GRIDSIZE \ 29 | --share_colorbar \ 30 | --wandb_entity_for_report $ENTITY \ 31 | --wandb_run_paths ${paths[i]} \ 32 | --wandb_id_for_report Fig3-$VERSION 33 | else 34 | python scripts/visualize_joint_plots.py \ 35 | --dataset ${datasets[i]} \ 36 | --plot_correlation \ 37 | --gridsize $GRIDSIZE \ 38 | --share_colorbar \ 39 | --wandb_entity_for_report $ENTITY \ 40 | --wandb_run_paths ${paths[i]} \ 41 | --wandb_id_for_report Fig3-$VERSION 42 | fi 43 | done 44 | -------------------------------------------------------------------------------- /scripts/plotting/figure_4_RQ2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | PRINT=false 6 | 7 | CI=0.95_confidence_level_50_bootstrap_samples 8 | 9 | ENTITY=epfl-dlab 10 | ROOT=epfl-dlab/understanding-decoding 11 | VERSION=paper-v1 12 | 13 | if [ $PRINT == true ] 14 | then 15 | echo python -m scripts.visualize_figure4 \ 16 | --confidence_interval_id $CI \ 17 | --wandb_entity_for_report $ENTITY \ 18 | --wandb_run_paths \ 19 | $ROOT/2xogfbel $ROOT/38vrpbjt $ROOT/2gchwiy1 $ROOT/3jtpx0zz $ROOT/3g2q6r6x $ROOT/yqyctpgx \ 20 | $ROOT/2xdt1zgn $ROOT/7y0pr9sw $ROOT/13niq0lx $ROOT/uykq2pgp $ROOT/37wx7n8u $ROOT/2yc65rt8 \ 21 | --wandb_id_for_report Fig4-$VERSION; 22 | 23 | echo python -m 
scripts.visualize_figure4 \ 24 | --confidence_interval_id $CI \ 25 | --wandb_entity_for_report $ENTITY \ 26 | --wandb_run_paths \ 27 | $ROOT/3a5ydpr4 $ROOT/1yc972vw $ROOT/2an1uakn $ROOT/2lveviea \ 28 | $ROOT/ilkuz1w4 $ROOT/2ah4esvq $ROOT/1jl4dhvh $ROOT/2flwrn8g \ 29 | --wandb_id_for_report Fig4-$VERSION 30 | else 31 | python scripts/visualize_figure4.py \ 32 | --confidence_interval_id $CI \ 33 | --wandb_entity_for_report $ENTITY \ 34 | --wandb_run_paths \ 35 | $ROOT/2xogfbel $ROOT/38vrpbjt $ROOT/2gchwiy1 $ROOT/3jtpx0zz $ROOT/3g2q6r6x $ROOT/yqyctpgx \ 36 | $ROOT/2xdt1zgn $ROOT/7y0pr9sw $ROOT/13niq0lx $ROOT/uykq2pgp $ROOT/37wx7n8u $ROOT/2yc65rt8 \ 37 | --wandb_id_for_report Fig4-$VERSION; 38 | 39 | python scripts/visualize_figure4.py \ 40 | --confidence_interval_id $CI \ 41 | --wandb_entity_for_report $ENTITY \ 42 | --wandb_run_paths \ 43 | $ROOT/3a5ydpr4 $ROOT/1yc972vw $ROOT/2an1uakn $ROOT/2lveviea \ 44 | $ROOT/ilkuz1w4 $ROOT/2ah4esvq $ROOT/1jl4dhvh $ROOT/2flwrn8g \ 45 | --wandb_id_for_report Fig4-$VERSION 46 | fi 47 | -------------------------------------------------------------------------------- /scripts/plotting/figure_5_RQ3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=".:transformers/src:mctx" 4 | 5 | PRINT=false 6 | 7 | ENTITY=epfl-dlab 8 | VERSION=paper-v3 9 | 10 | if [ $PRINT == true ] 11 | then 12 | echo python scripts/visualize_figure5.py \ 13 | --wandb_entity_for_report $ENTITY \ 14 | --wandb_id_for_report Fig5-$VERSION 15 | else 16 | python scripts/visualize_figure5.py \ 17 | --wandb_entity_for_report $ENTITY \ 18 | --wandb_id_for_report Fig5-$VERSION 19 | fi 20 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/src/__init__.py -------------------------------------------------------------------------------- /src/constrained_generation/__init__.py: -------------------------------------------------------------------------------- 1 | from .ie_prefix_constraints import get_information_extraction_prefix_allowed_tokens_fn_hf 2 | from .trie import Trie, get_trie_from_strings 3 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .wmt14 import WMT14DataModule, WMT14Dataset, WMT14OutputDataset 2 | from .rebel import RebelDataModule, RebelDataset, RebelOutputDataset 3 | from .rtp import RTPDataModule, RTPDataset, RTPOutputDataset 4 | from .swissprot import ProteinDataModule, ProteinDataset, ProteinOutputDataset 5 | -------------------------------------------------------------------------------- /src/datamodules/helper_data_classes.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class SampleInformation: 5 | pred_score: float 6 | pred_utility: float 7 | tgt_score: Optional[float] 8 | 9 | def __init__(self, pred_score, pred_utility=None, tgt_score=None, pred_correlation=None): 10 | self.pred_score = pred_score 11 | self.pred_utility = pred_utility 12 | self.tgt_score = tgt_score 13 | self.pred_correlation = pred_correlation 14 | 15 | def set_correlation(self, pred_correlation): 16 | self.pred_correlation = pred_correlation 17 | 18 | def 
get_score(self, difference, ratio): 19 | """ 20 | Returns the (normalized or unnormalized) score for the sample. 21 | 22 | Parameters 23 | ---------- 24 | difference : "normalize" by subtracting the target_score 25 | ratio : "normalize" by reporting the ratio w.r.t. to the target_score 26 | 27 | Returns 28 | ------- 29 | score : float 30 | is_normalized : bool 31 | """ 32 | if self.tgt_score is None: 33 | # Normalization cannot be applied without a tgt_score 34 | difference = False 35 | ratio = False 36 | 37 | assert not (difference and ratio), "Difference and ratio cannot be simultaneously selected!" 38 | 39 | if difference: 40 | return self.pred_score - self.tgt_score, difference or ratio 41 | 42 | if ratio: 43 | return self.pred_score / self.tgt_score, difference or ratio 44 | 45 | return self.pred_score, difference or ratio 46 | 47 | def get_summary(self, difference=True, ratio=False): 48 | """ 49 | Returns a summary for the sample that will be used in the visualization. 50 | 51 | Parameters 52 | ---------- 53 | difference : See get_score. 54 | ratio : See get_score. 55 | 56 | Returns 57 | ------- 58 | 59 | """ 60 | summary = {"score": self.get_score(difference, ratio), "utility": self.pred_utility} 61 | if self.pred_correlation is not None: 62 | summary["correlation"] = self.pred_correlation 63 | 64 | return summary 65 | -------------------------------------------------------------------------------- /src/datamodules/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .triplet_utils import TripletUtils 2 | from .resampling_utils import select_data_to_resample 3 | -------------------------------------------------------------------------------- /src/datamodules/utils/resampling_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import src.utils.general as utils 6 | from src.utils.evaluation import select_indices_on_quantiles 7 | 8 | 9 | log = utils.get_logger(__name__) 10 | 11 | 12 | def _filter_data_points_on(data, key): 13 | filtered_data = [dp for dp in data if dp[key][0] != None] 14 | log.info(f"Samples filtered due to invalid `{key}`: {len(data)-len(filtered_data)}") 15 | return filtered_data 16 | 17 | 18 | def select_data_to_resample(output_dataset, num_qs, testing_output_parent_dir=None): 19 | dataset = output_dataset.data 20 | 21 | dataset = _filter_data_points_on(dataset, key="prediction_log_likelihood") 22 | has_targets = "target_log_likelihood" in dataset[0].keys() 23 | if has_targets: 24 | dataset = _filter_data_points_on(dataset, key="target_log_likelihood") 25 | 26 | data = pd.DataFrame(dataset) 27 | if has_targets: 28 | data["tgt_top_pred_likelihood_diff"] = data.apply( 29 | lambda x: np.exp(x.prediction_log_likelihood[0]) - np.exp(x.target_log_likelihood[0]), axis=1 30 | ) 31 | else: 32 | data["tgt_top_pred_likelihood_diff"] = data.apply(lambda x: np.exp(x.prediction_log_likelihood[0]), axis=1) 33 | 34 | indices_to_keep = select_indices_on_quantiles( 35 | data=data["tgt_top_pred_likelihood_diff"].to_numpy(), num_qs=num_qs, is_data_sorted=False 36 | ) 37 | 38 | target_quantiles = np.arange(0, num_qs + 1) / num_qs 39 | data_to_keep = data.iloc[indices_to_keep].copy() 40 | data_to_keep.reset_index(inplace=True, drop=True) 41 | data_to_keep["quantile"] = target_quantiles 42 | data_to_keep["empirical_quantile"] = [ 43 | sum([diff <= tgt_val for diff in data["tgt_top_pred_likelihood_diff"]]) 44 | / 
data["tgt_top_pred_likelihood_diff"].shape[0] 45 | for tgt_val in data_to_keep["tgt_top_pred_likelihood_diff"] 46 | ] 47 | data_to_keep["empirical_quantile"][0] = 0 48 | data_to_keep = data_to_keep[["id", "tgt_top_pred_likelihood_diff", "quantile", "empirical_quantile"]] 49 | print("[sanity check] Target quantiles:", data_to_keep["quantile"].to_list()) 50 | print("[sanity check] Observed quantiles:", data_to_keep["empirical_quantile"].to_list()) 51 | 52 | if testing_output_parent_dir is not None: 53 | data_save_path = "resample_input_stats.jsonl.gz" 54 | data_save_path = os.path.join(testing_output_parent_dir, data_save_path) 55 | data_to_keep.to_json(data_save_path, orient="records", lines=True, compression="gzip") 56 | 57 | ids_to_keep = data_to_keep["id"].to_numpy() 58 | 59 | return ids_to_keep 60 | -------------------------------------------------------------------------------- /src/datamodules/utils/triplet_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class TripletUtils(object): 5 | @staticmethod 6 | def convert_text_sequence_to_text_triples(text, verbose=True, return_set=True): 7 | text_parts = [element.strip() for element in re.split(r"|||", text) if element.strip()] 8 | if verbose and len(text_parts) % 3 != 0: 9 | print(f"Textual sequence: ```{text}``` does not follow the , , , format!") 10 | 11 | text_triples = [tuple(text_parts[i : i + 3]) for i in range(0, len(text_parts) - 2, 3)] 12 | 13 | if not return_set: 14 | return text_triples 15 | 16 | unique_text_triples = set(text_triples) 17 | 18 | if verbose and len(unique_text_triples) != len(text_triples): 19 | print(f"Textual sequence: ```{text}``` has duplicated triplets!") 20 | 21 | return unique_text_triples 22 | 23 | @staticmethod 24 | def triples_to_output_format(triples): 25 | output_triples = [] 26 | 27 | for t in triples: 28 | sub, rel, obj = t 29 | formatted_triple = "{} {}{} {}{} {}{}".format( 30 | " ", sub.strip(), " ", rel.strip(), " ", obj.strip(), " " 31 | ) 32 | output_triples.append(formatted_triple) 33 | 34 | output = "".join(output_triples) 35 | return output 36 | 37 | @staticmethod 38 | def process_triple_of_ids(triple, ent_mapping, rel_mapping, query_wikidata, allow_labels): 39 | """returns match_status, id form, surface form, provenance """ 40 | if len(triple) != 3: 41 | raise Exception("Invalid triple:", triple) 42 | 43 | head_id, rel_id, tail_id = triple 44 | 45 | head_s, head_p = ent_mapping.get_from_wikidata_id( 46 | head_id, return_provenance=True, query_wikidata=query_wikidata, allow_labels=allow_labels 47 | ) 48 | tail_s, tail_p = ent_mapping.get_from_wikidata_id( 49 | tail_id, return_provenance=True, query_wikidata=query_wikidata, allow_labels=allow_labels 50 | ) 51 | 52 | rel_s, rel_p = rel_mapping.get_from_wikidata_id( 53 | rel_id, return_provenance=True, query_wikidata=query_wikidata, allow_labels=allow_labels 54 | ) 55 | 56 | surface_form = [head_s, rel_s, tail_s] 57 | provenance = [head_p, rel_p, tail_p] 58 | 59 | if head_p is None or rel_p is None or tail_p is None: 60 | status = "no_match" 61 | elif head_p == "en_label" or rel_p == "en_label" or tail_p == "en_label": 62 | status = "label" 63 | elif head_p == "en_title" and rel_p == "en_title" and tail_p == "en_title": 64 | status = "title" 65 | else: 66 | raise Exception("Invalid provenance") 67 | 68 | return status, triple, surface_form, provenance 69 | -------------------------------------------------------------------------------- /src/hparam_search/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/src/hparam_search/__init__.py -------------------------------------------------------------------------------- /src/hparam_search/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_launcher import BaseLauncher 2 | from .local_launcher import LocalLauncher -------------------------------------------------------------------------------- /src/hparam_search/launchers/base_launcher.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseLauncher(ABC): 5 | def run_trial(self, **kwargs): 6 | raise NotImplementedError() -------------------------------------------------------------------------------- /src/hparam_search/launchers/local_launcher.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | import src.utils.general as utils 5 | from . import BaseLauncher 6 | 7 | log = utils.get_logger(__name__) 8 | 9 | 10 | class LocalLauncher(BaseLauncher): 11 | def __init__(self, work_dir): 12 | super().__init__() 13 | self.work_dir = work_dir 14 | 15 | def prepare_cmd(self, script, config): 16 | overrides = " ".join([f"{key}={value}" for key, value in config.items() if not key.startswith('--')]) 17 | other = " ".join([f"{key} {value}" for key, value in config.items() if key.startswith('--')]) 18 | if not other: 19 | return f"{script} {overrides}", None 20 | cmd = f"{script} --overrides {overrides} {other}" 21 | return cmd, config['--exp_dir'] 22 | 23 | def run_trial(self, idx, script_1, config_1, script_2=None, config_2=None): 24 | # result_code = os.system(f"python tests/sleep.py --id {idx}") 25 | # result_code = subprocess.run(["python", 'tests/sleep.py', '--id', f'{idx}']) 26 | 27 | # args = " ".join([f"{key}={value}" for key, value in config_1.items()]) 28 | # cmd = f"{script_1} {args}" 29 | cmd, eval_path = self.prepare_cmd(script_1, config_1) 30 | log.info(f"[Trial {idx}] Executing command: {cmd}") 31 | result_code = subprocess.run(cmd.split(" "), cwd=self.work_dir) 32 | 33 | if script_2 is not None: 34 | # note that subprocess.run waits until the command passes to it finishes, so we 35 | # can safely continue the evaluation subprocess. 36 | # args = " ".join([f"{key}={value}" for key, value in config_2.items()]) 37 | # cmd = f"{script_2} {args}" 38 | cmd, eval_path = self.prepare_cmd(script_2, config_2) 39 | log.info(f"[Trial {idx}] Executing command: {cmd}") 40 | result_code = subprocess.run(cmd.split(" "), cwd=self.work_dir) 41 | 42 | # TODO: Verify that best_chkpoint file exists for the training 43 | # If there is evaluation, verify that results.json exists (if run was successful) 44 | 45 | if eval_path: 46 | with open(eval_path, 'r') as f: 47 | results = json.load(f) 48 | 49 | # TODO: Read the score from json file inside evaluation directory and return it 50 | # What to do if no eval? 51 | print(results) 52 | score = 0. 
53 | # score = results['unbiased-small']['corpus'] 54 | # score = results['BLEU-4__tok-13a_sm-_exp_sv_None']['corpus'] 55 | return score -------------------------------------------------------------------------------- /src/hparam_search/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_search import BaseSearch 2 | from .grid_search import GridSearch -------------------------------------------------------------------------------- /src/hparam_search/search/base_search.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseSearch(ABC): 5 | def run_search(self, launcher, **kwargs): 6 | raise NotImplementedError() -------------------------------------------------------------------------------- /src/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | from .reconstruction_logger import ReconstructionLogger 2 | -------------------------------------------------------------------------------- /src/mdp/__init__.py: -------------------------------------------------------------------------------- 1 | from .transition_model import TransitionModel 2 | from .utility_value_model import UtilityValueModel 3 | from .tree import Tree 4 | from .search_tree_factory import SearchTreeFactory 5 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract import EvaluationModel, BatchedEvaluationModel 2 | from .detoxify_evaluation_model import DetoxifyEvaluationModel 3 | from .explicit_evaluation_model import ExplicitEvaluationModel 4 | from .likelihood_evaluation_model import LikelihoodEvaluationModel 5 | from .solubility_evaluation_model import SolubilityEvaluationModel 6 | from .oracle_genie import OracleGenIE 7 | from .noisy_oracle_genie import NoisyOracleGenIE 8 | from .oracle_detoxify import OracleDetoxify 9 | from .noisy_oracle_detoxify import NoisyOracleDetoxify 10 | from .oracle_solubility import OracleSolubility 11 | from .noisy_oracle_solubility import NoisyOracleSolubility 12 | from .oracle_mt import OracleMT 13 | from .noisy_oracle_mt import NoisyOracleMT 14 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/abstract.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | 5 | from .utils import get_np_random_state 6 | from ..trees.abstract import Tree 7 | 8 | 9 | class EvaluationModel(ABC): 10 | def __init__(self, tree: Tree): 11 | self.tree = tree 12 | 13 | def _get_estimated_utility(self, node_id, **kwargs): 14 | """Estimated utility -- goodness -- of an internal node""" 15 | raise NotImplementedError() 16 | 17 | def _get_utility(self, node_id, **kwargs): 18 | """Utility of a terminal node""" 19 | raise NotImplementedError() 20 | 21 | def evaluate(self, node_id, **kwargs): 22 | if self.tree.is_terminal(node_id): 23 | return self._get_utility(node_id, **kwargs) 24 | 25 | return self._get_estimated_utility(node_id, **kwargs) 26 | 27 | def reset_device(self, device: torch.device): 28 | """ 29 | Reset the device on which the evaluation model is and move it to the given device. 30 | If the evaluation model does not use any device (except CPU), this method can be left empty. 
31 | """ 32 | pass 33 | 34 | 35 | class BatchedEvaluationModel(ABC): 36 | def __init__(self, tree: Tree): 37 | self.tree = tree 38 | self.rank = 0 # The rank of the thread if running using multiple threads 39 | 40 | self.np_random_state = None 41 | self.reset_random_state() 42 | 43 | def evaluate(self, input_ids, **kwargs): 44 | raise NotImplementedError() 45 | 46 | def reset_device(self, device: torch.device): 47 | """ 48 | Reset the device on which the evaluation model is and move it to the given device. 49 | If the evaluation model does not use any device (except CPU), this method can be left empty. 50 | """ 51 | pass 52 | 53 | def reset_rank(self, rank): 54 | self.rank = rank 55 | 56 | def reset_random_state(self): 57 | self.np_random_state = get_np_random_state(self.rank) 58 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/detoxify_evaluation_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from detoxify import Detoxify 4 | 5 | from . import EvaluationModel 6 | from ..trees.lm_as_tree import LanguageModelAsTree 7 | 8 | 9 | class DetoxifyEvaluationModel(EvaluationModel): 10 | def __init__(self, tree: LanguageModelAsTree, model_type: str, device="cpu"): 11 | """ 12 | Evaluation model for predicting whether a textual sequence (or a list of textual sequences) is toxic. 13 | Can initialize 5 different model types from model type: 14 | - original: 15 | model trained on data from the Jigsaw Toxic Comment 16 | Classification Challenge 17 | - unbiased: 18 | model trained on data from the Jigsaw Unintended Bias in 19 | Toxicity Classification Challenge 20 | - multilingual: 21 | model trained on data from the Jigsaw Multilingual 22 | Toxic Comment Classification Challenge 23 | - original-small: 24 | lightweight version of the original model 25 | - unbiased-small: 26 | lightweight version of the unbiased model 27 | Parameters 28 | ---------- 29 | tree: instance of `LanguageModelAsTree` 30 | model_type : the name of the model (see above) 31 | device : a string tag (e.g. 
'cpu', 'cuda', 'cuda:0', 'cuda:2, etc) or a torch.device object 32 | """ 33 | super().__init__(tree) 34 | self.tree: LanguageModelAsTree = tree 35 | self.device = device 36 | self.name = self.model_type = model_type 37 | self.model = Detoxify(model_type=model_type, device=device) 38 | 39 | def _get_estimated_utility(self, node_id, **kwargs): 40 | """Estimated utility -- goodness -- of an internal node""" 41 | return self._get_utility(node_id, **kwargs) 42 | 43 | def _get_utility(self, node_id, **kwargs): 44 | """Utility of a terminal node""" 45 | input_ids = node_id 46 | text = self.tree.tokenizer.decode(input_ids, skip_special_tokens=True) 47 | 48 | scores = self.model.predict(text) 49 | return 1 - torch.tensor(scores["toxicity"]) 50 | 51 | def reset_device(self, device: torch.device): 52 | self.device = device 53 | self.model.device = device 54 | self.model.model.to(device) 55 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/explicit_evaluation_model.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | from typing import Dict 4 | 5 | from .abstract import EvaluationModel 6 | from ..trees import Tree 7 | 8 | 9 | class ExplicitEvaluationModel(EvaluationModel): 10 | def __init__(self, tree: Tree, path: str, 11 | missing_value: float = 0.0, 12 | ignore_non_terminal: bool = False, 13 | ignore_terminal: bool = False): 14 | super().__init__(tree) 15 | 16 | self.missing_value = missing_value 17 | self.utilities = {} 18 | self.est_utilities = {} 19 | 20 | values = self.load_values(path) 21 | 22 | for node_id, val in values.items(): 23 | if tree.is_terminal(node_id): 24 | if not ignore_terminal: 25 | self.utilities[node_id] = val 26 | else: 27 | if not ignore_non_terminal: 28 | self.est_utilities[node_id] = val 29 | 30 | def _get_estimated_utility(self, node_id, **kwargs): 31 | return self.est_utilities.get(node_id, self.missing_value) 32 | 33 | def _get_utility(self, node_id, **kwargs): 34 | return self.utilities.get(node_id, self.missing_value) 35 | 36 | @staticmethod 37 | def load_values(path: str) -> Dict[str, float]: 38 | if path.endswith(".txt"): 39 | values = {} 40 | with open(path, "r") as f: 41 | for line in f.readlines(): 42 | parts = line.strip().split(" ") 43 | for i in range(0, len(parts), 2): 44 | values[int(parts[i])] = float(parts[i + 1]) 45 | else: 46 | with gzip.open(path, "rb") as f: 47 | values = json.loads(f.read().decode()) 48 | values = {int(k): v for k, v in values.items()} 49 | 50 | return values 51 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/likelihood_evaluation_model.py: -------------------------------------------------------------------------------- 1 | from .abstract import EvaluationModel 2 | from .. 
import Tree 3 | 4 | 5 | class LikelihoodEvaluationModel(EvaluationModel): 6 | def __init__(self, tree: Tree): 7 | super().__init__(tree) 8 | 9 | def _get_estimated_utility(self, node_id, **kwargs): 10 | """Estimated utility -- goodness -- of an internal node""" 11 | return self._get_utility(node_id, **kwargs) 12 | 13 | def _get_utility(self, node_id, **kwargs): 14 | """Utility of a terminal node""" 15 | assert 'likelihood' in kwargs 16 | return kwargs['likelihood'] 17 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/noisy_oracle_genie.py: -------------------------------------------------------------------------------- 1 | from . import BatchedEvaluationModel, OracleGenIE 2 | from .utils import add_noise_binary 3 | 4 | from .. import Tree 5 | 6 | 7 | class NoisyOracleGenIE(BatchedEvaluationModel): 8 | def __init__(self, noising_function_parameters: dict, tree: Tree = None): 9 | """ 10 | 11 | Parameters 12 | ---------- 13 | noising_function_parameters : contains sigma – standard_deviation of the truncated Normal noise 14 | tree: instance of `Tree` 15 | """ 16 | super().__init__(tree) 17 | self.noising_function_parameters = noising_function_parameters 18 | self.oracle = OracleGenIE(tree) 19 | 20 | def apply_noise(self, values): 21 | return add_noise_binary(values, np_random_state=self.np_random_state, **self.noising_function_parameters) 22 | 23 | def evaluate(self, input_ids, target_ids, next_token_ids, **kwargs): 24 | values = self.oracle.evaluate(input_ids, target_ids, next_token_ids) 25 | 26 | noisy_values = self.apply_noise(values) 27 | return noisy_values.to(input_ids.device) 28 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/noisy_oracle_solubility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import BatchedEvaluationModel, OracleSolubility 4 | from .utils import add_noise_continuous 5 | from ..trees import Tree 6 | 7 | 8 | class NoisyOracleSolubility(BatchedEvaluationModel): 9 | def __init__(self, noising_function_parameters: dict, tree: Tree, oracle_parameters: dict): 10 | """ 11 | See the OracleDetoxify docstring. 
12 | Parameters 13 | ---------- 14 | noising_function_parameters : contains probability_inversion – the probability of swapping the value between two actions 15 | """ 16 | super().__init__(tree) 17 | self.noising_function_parameters = noising_function_parameters 18 | self.oracle = OracleSolubility(tree, **oracle_parameters) 19 | 20 | def apply_noise(self, values): 21 | return torch.cat([add_noise_continuous(v, **self.noising_function_parameters)[None, :] for v in values]) 22 | 23 | def evaluate(self, input_ids, next_token_ids, **kwargs): 24 | """ 25 | 26 | Parameters 27 | ---------- 28 | input_ids : (num_samples, curr_len) 29 | next_token_ids : (num_samples, num_considered_tokens) 30 | kwargs : 31 | 32 | Returns 33 | ------- 34 | values : (num_samples, num_considered_tokens) 35 | """ 36 | values = self.oracle.evaluate(input_ids, next_token_ids) 37 | 38 | noisy_values = self.apply_noise(values) 39 | return noisy_values 40 | 41 | def reset_device(self, device: torch.device): 42 | self.oracle.reset_device(device) 43 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/oracle_detoxify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from detoxify import Detoxify 4 | 5 | from . import BatchedEvaluationModel 6 | from ..trees import Tree 7 | 8 | 9 | class OracleDetoxify(BatchedEvaluationModel): 10 | def __init__(self, tree: Tree, model_type: str, batch_size: int, device="cpu"): 11 | """ 12 | Evaluation model for predicting whether a textual sequence (or a list of textual sequences) is toxic. 13 | Can initialize 5 different model types from model type: 14 | - original: 15 | model trained on data from the Jigsaw Toxic Comment 16 | Classification Challenge 17 | - unbiased: 18 | model trained on data from the Jigsaw Unintended Bias in 19 | Toxicity Classification Challenge 20 | - multilingual: 21 | model trained on data from the Jigsaw Multilingual 22 | Toxic Comment Classification Challenge 23 | - original-small: 24 | lightweight version of the original model 25 | - unbiased-small: 26 | lightweight version of the unbiased model 27 | Parameters 28 | ---------- 29 | tree: instance of `Tree` 30 | model_type : the name of the model (see above) 31 | batch_size : number of elements to run through the value function simultaneously 32 | device : a string tag (e.g. 
'cpu', 'cuda', 'cuda:0', 'cuda:2, etc) or a torch.device object 33 | """ 34 | super().__init__(tree) 35 | self.tree = tree 36 | self.device = device 37 | self.name = self.model_type = model_type 38 | self.model = Detoxify(model_type=model_type, device=device) 39 | self.batch_size = batch_size 40 | 41 | def evaluate(self, input_ids, next_token_ids, **kwargs): 42 | """ 43 | 44 | Parameters 45 | ---------- 46 | input_ids : (num_samples, curr_len) 47 | next_token_ids : (num_samples, num_considered_tokens) 48 | kwargs : 49 | 50 | Returns 51 | ------- 52 | values : (num_samples, num_considered_tokens) 53 | """ 54 | num_considered_tokens = next_token_ids.shape[-1] 55 | next_token_ids_f = torch.flatten(next_token_ids)[:, None] 56 | 57 | input_ids_r = torch.repeat_interleave(input_ids, num_considered_tokens, dim=0) 58 | extended_input_ids = torch.cat([input_ids_r, next_token_ids_f], dim=1) 59 | 60 | text = self.tree.tokenizer.batch_decode(extended_input_ids, skip_special_tokens=True) 61 | 62 | # split sentences on batches 63 | batched_text = [text[i: i + self.batch_size] for i in range(0, len(text), self.batch_size)] 64 | toxicity_scores = [] 65 | for batch in batched_text: 66 | toxicity_scores.extend(self.model.predict(batch)["toxicity"]) 67 | 68 | values = 1 - torch.tensor(toxicity_scores).view_as(next_token_ids) 69 | 70 | return values.to(input_ids.device) 71 | 72 | def reset_device(self, device: torch.device): 73 | self.device = device 74 | self.model.device = device 75 | self.model.model.to(device) 76 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/oracle_genie.py: -------------------------------------------------------------------------------- 1 | from . import BatchedEvaluationModel 2 | 3 | import torch 4 | 5 | from .. import Tree 6 | 7 | 8 | class OracleGenIE(BatchedEvaluationModel): 9 | def __init__(self, tree: Tree = None): 10 | super().__init__(tree) 11 | 12 | def evaluate(self, input_ids, target_ids, next_token_ids, **kwargs): 13 | """ 14 | 15 | Parameters 16 | ---------- 17 | input_ids : (num_samples, curr_len) 18 | target_ids : (num_samples, max_target_len) 19 | next_token_ids : (num_samples, num_considered_tokens) 20 | kwargs : 21 | 22 | Returns 23 | ------- 24 | values : (num_samples, num_considered_tokens) 25 | """ 26 | 27 | # check if already generated tokens match 28 | curr_len = input_ids.shape[-1] 29 | if target_ids.shape[-1] <= curr_len: 30 | return torch.zeros_like(next_token_ids) 31 | prev_match = torch.all(torch.eq(input_ids, target_ids[:, :curr_len]), dim=-1, keepdim=True) 32 | 33 | # check if the next_token_id equals the next token 34 | values = torch.eq(target_ids[:, curr_len][:, None], next_token_ids) 35 | 36 | # check if prev_tokens_match and next_token_matches 37 | values = (prev_match & values).long() 38 | 39 | return values 40 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/oracle_mt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from . import BatchedEvaluationModel 5 | 6 | from .. 
import Tree 7 | from ...metrics import BLEUScore 8 | 9 | 10 | class OracleMT(BatchedEvaluationModel): 11 | def __init__(self, bleu_score: BLEUScore, tree: Tree = None, operate_on: str = "ids"): 12 | """ 13 | Parameters 14 | ---------- 15 | tree: instance of `Tree` 16 | operate_on: str, either "ids" or "text" 17 | """ 18 | super().__init__(None) 19 | self.bleu_score = bleu_score 20 | self.tree = tree 21 | self.operate_on = operate_on 22 | 23 | def evaluate(self, input_ids, next_token_ids, target_txt, full_target_ids, num_beams=1, **kwargs): 24 | """ 25 | 26 | Parameters 27 | ---------- 28 | input_ids : (num_samples, curr_len) 29 | next_token_ids : (num_samples, num_considered_tokens) 30 | target_txt : (num_samples) target text per sample 31 | full_target_ids : (num_samples) target ids per sample 32 | num_beams : the number of beams used by the decoding algorithm 33 | kwargs : 34 | 35 | Returns 36 | ------- 37 | values : (num_samples, num_considered_tokens) 38 | """ 39 | # ToDo: Can be speed up if necessary, by computing the BLEU per prefix and incorporating the value 40 | # contributed by the next token 41 | num_considered_tokens = next_token_ids.shape[-1] 42 | next_token_ids_f = torch.flatten(next_token_ids)[:, None] 43 | 44 | # Get input texts 45 | input_ids_r = torch.repeat_interleave(input_ids, num_considered_tokens, dim=0) 46 | extended_input_ids = torch.cat([input_ids_r, next_token_ids_f], dim=1) 47 | 48 | if self.operate_on == "text": 49 | # Get current and target text 50 | text = self.tree.tokenizer.batch_decode(extended_input_ids, skip_special_tokens=True) 51 | target_txts = np.repeat(target_txt, num_considered_tokens * num_beams) 52 | values = np.array([self.bleu_score.compute([h], [[r[:len(h)]]], return_score_only=True) if h != "" else 0 53 | for h, r in zip(text, target_txts)]) 54 | 55 | elif self.operate_on == "ids": 56 | target_ids_r = torch.repeat_interleave(full_target_ids, num_considered_tokens * num_beams, dim=0) 57 | values = [] 58 | for h, r in zip(extended_input_ids, target_ids_r): 59 | h = h.tolist() # remove padding 60 | r = r[r != 1][:len(h)].tolist() # Remove padding and get only the first len(h) tokens 61 | values.append(self.bleu_score.compute([" ".join(map(str, h))], 62 | [[" ".join(map(str, r))]], 63 | return_score_only=True)) 64 | else: 65 | raise ValueError("Unknown type to operate on: {}".format(self.operate_on)) 66 | 67 | 68 | values = torch.tensor(values).view_as(next_token_ids) 69 | 70 | return values.to(input_ids.device) 71 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/oracle_solubility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline 4 | 5 | from . import BatchedEvaluationModel 6 | from ..trees import Tree 7 | 8 | 9 | class OracleSolubility(BatchedEvaluationModel): 10 | def __init__(self, tree: Tree, batch_size: int, device="cpu"): 11 | """ 12 | Class for scoring a protein sequence (or a list of textual sequences) with its solubility. 13 | Parameters 14 | ---------- 15 | tree: instance of `Tree` 16 | batch_size : number of elements to run through the value function simultaneously 17 | device : a string tag (e.g. 
'cpu', 'cuda', 'cuda:0', 'cuda:2', etc.) or a torch.device object 18 | """ 19 | super().__init__(tree) 20 | self.tree = tree 21 | self.device = device 22 | self.model = TextClassificationPipeline( 23 | model=AutoModelForSequenceClassification.from_pretrained("Rostlab/prot_bert_bfd_membrane"), 24 | tokenizer=AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd_membrane"), device=self.device) 25 | self.batch_size = batch_size 26 | 27 | def evaluate(self, input_ids, next_token_ids, **kwargs): 28 | """ 29 | Parameters 30 | ---------- 31 | input_ids : (num_samples, curr_len) 32 | next_token_ids : (num_samples, num_considered_tokens) 33 | kwargs : 34 | 35 | Returns 36 | ------- 37 | values : (num_samples, num_considered_tokens) 38 | """ 39 | num_considered_tokens = next_token_ids.shape[-1] 40 | next_token_ids_f = torch.flatten(next_token_ids)[:, None] 41 | input_ids_r = torch.repeat_interleave(input_ids, num_considered_tokens, dim=0) 42 | extended_input_ids = torch.cat([input_ids_r, next_token_ids_f], dim=1) 43 | text = self.tree.tokenizer.batch_decode(extended_input_ids, skip_special_tokens=True) 44 | 45 | # split sentences on batches 46 | batched_text = [text[i : i + self.batch_size] for i in range(0, len(text), self.batch_size)] 47 | solubility_scores = [] 48 | for batch in batched_text: 49 | batch_scores = self.model(batch) 50 | solubility_scores.extend([s["score"] for s in batch_scores]) 51 | 52 | values = 1 - torch.tensor(solubility_scores).view_as(next_token_ids) 53 | return values.to(input_ids.device) 54 | 55 | def reset_device(self, device: torch.device): 56 | self.device = device 57 | self.model.device = device 58 | self.model.model.to(device) 59 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/solubility_evaluation_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline 2 | import numpy as np 3 | import torch 4 | 5 | from . import EvaluationModel 6 | from ..trees.lm_as_tree import LanguageModelAsTree 7 | 8 | 9 | class SolubilityEvaluationModel(EvaluationModel): 10 | def __init__(self, tree: LanguageModelAsTree, device="cpu"): 11 | """ 12 | Class for scoring a protein sequence (or a list of textual sequences) with its solubility. 13 | Parameters 14 | ---------- 15 | device : number of cuda device integer (e.g. '0', '1', '2' etc) 16 | 17 | Parameters 18 | ---------- 19 | tree: instance of `LanguageModelAsTree` 20 | device : a string tag (e.g.
'cpu', 'cuda', 'cuda:0', 'cuda:2, etc) or a torch.device object 21 | """ 22 | 23 | super().__init__(tree) 24 | self.tree: LanguageModelAsTree = tree 25 | self.device = device 26 | self.model = TextClassificationPipeline( 27 | model=AutoModelForSequenceClassification.from_pretrained("Rostlab/prot_bert_bfd_membrane"), 28 | tokenizer=AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd_membrane"), device=self.device) 29 | 30 | def _get_estimated_utility(self, node_id, **kwargs): 31 | """Estimated utility -- goodness -- of an internal node""" 32 | return self._get_utility(node_id, **kwargs) 33 | 34 | def _get_utility(self, node_id, **kwargs): 35 | """Utility of a terminal node""" 36 | input_ids = node_id 37 | text = self.tree.tokenizer.batch_decode(input_ids, skip_special_tokens=True) 38 | #text = ' '.join([re.sub(r"[UZOB]", "X", sequence) for sequence in text[i]]) 39 | 40 | score = self.model(text) 41 | soluble_score = score[0]['score'] 42 | return 1 - soluble_score # non_solubility 43 | 44 | def reset_device(self, device: torch.device): 45 | self.device = device 46 | self.model.device = device 47 | self.model.model.to(device) 48 | -------------------------------------------------------------------------------- /src/mdp/evaluation_models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.stats import truncnorm, bernoulli 6 | 7 | import src.utils.general as general_utils 8 | 9 | log = general_utils.get_logger(__name__) 10 | 11 | 12 | def get_np_random_state(rank, debug_info=True): 13 | seed = 123 14 | seed += rank 15 | 16 | if debug_info: 17 | msg = f"Random state: rank -> {rank}, seed -> {seed}" 18 | log.info(msg) 19 | print(msg) 20 | 21 | rng_state = np.random.RandomState(seed) 22 | return rng_state 23 | 24 | 25 | def truncated_normal(min_v, max_v, mean, sigma, random_state: Union[np.random.RandomState, None]): 26 | return ( 27 | truncnorm(a=(min_v - mean) / sigma, b=(max_v - mean) / sigma, loc=mean, scale=sigma) 28 | .rvs(mean.shape[-1], random_state=random_state) 29 | .squeeze() 30 | ) 31 | 32 | 33 | def add_noise_binary(value_scores, sigma: float, np_random_state: Union[np.random.RandomState, None]): 34 | assert value_scores.ndim == 2 # Input must be a batch 35 | 36 | if sigma == 0: 37 | return value_scores 38 | 39 | dtype = torch.float32 40 | device = "cpu" 41 | if isinstance(value_scores, torch.Tensor): 42 | device = value_scores.device 43 | dtype = value_scores.dtype 44 | value_scores = value_scores.cpu() 45 | 46 | min_v = 0.0 47 | max_v = 1.0 48 | new_vector = [] 49 | for v in value_scores: 50 | new_vector.append(truncated_normal(min_v, max_v, v, sigma, random_state=np_random_state)) 51 | new_vector = np.array(new_vector) 52 | new_vector = new_vector.reshape(value_scores.shape) 53 | return torch.tensor(new_vector, dtype=dtype, device=device) 54 | 55 | 56 | def _biased_coin_flip(p: float, random_state: Union[np.random.RandomState, None]): 57 | return bernoulli(p).rvs(1, random_state=random_state)[0] 58 | 59 | 60 | def add_noise_continuous( 61 | value_scores, probability_inversion: float, np_random_state: Union[np.random.RandomState, None] 62 | ): 63 | if probability_inversion == 0: 64 | return value_scores 65 | 66 | new_vector = np.zeros(value_scores.shape[0]) 67 | indices = list(range(value_scores.shape[0])) 68 | swapped_indices = [] 69 | for i in indices: 70 | if i in swapped_indices: 71 | continue 72 | if _biased_coin_flip(probability_inversion, 
random_state=np_random_state) and i != indices[-1]: 73 | available_indices = [idx for idx in indices[i + 1 :] if idx not in swapped_indices] 74 | if len(available_indices) == 0: 75 | continue 76 | if np_random_state is not None: 77 | idx_to_swap = np_random_state.choice(available_indices) 78 | else: 79 | idx_to_swap = np.random.choice(available_indices) 80 | new_vector[i] = value_scores[idx_to_swap] 81 | new_vector[idx_to_swap] = value_scores[i] 82 | swapped_indices.append(i) 83 | swapped_indices.append(idx_to_swap) 84 | else: 85 | new_vector[i] = value_scores[i] 86 | return torch.FloatTensor(new_vector) 87 | -------------------------------------------------------------------------------- /src/mdp/oracles/__init__.py: -------------------------------------------------------------------------------- 1 | from .perfect_oracle import PerfectOracle 2 | -------------------------------------------------------------------------------- /src/mdp/oracles/perfect_oracle.py: -------------------------------------------------------------------------------- 1 | from ..evaluation_models import ExplicitEvaluationModel 2 | from ..trees import Tree 3 | 4 | 5 | class PerfectOracle(ExplicitEvaluationModel): 6 | def __init__(self, tree: Tree, path: str): 7 | super().__init__(tree=tree, path=path, ignore_non_terminal=True, ignore_terminal=False) 8 | self.propagate_utilities() 9 | 10 | def propagate_utility_from_node(self, node_id): 11 | utility = self.utilities[node_id] 12 | curr_node_id = self.tree.get_parent(node_id) 13 | 14 | while curr_node_id != self.tree.NO_PARENT: 15 | self.est_utilities[curr_node_id] = utility 16 | curr_node_id = self.tree.get_parent(curr_node_id) 17 | 18 | def propagate_utilities(self): 19 | for node_id in self.utilities: 20 | self.propagate_utility_from_node(node_id) 21 | -------------------------------------------------------------------------------- /src/mdp/tree.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | # class NodeTree(object): 5 | # def __init__(self, id: int = None, data: dict = None, is_leaf: bool = None): 6 | # self.id = id 7 | # self.data = data 8 | # self.is_leaf = is_leaf 9 | # 10 | # def __hash__(self): 11 | # return hash(str(self.id)) 12 | # 13 | # def __eq__(self, other): 14 | # return self.id == other.id 15 | # 16 | # def __repr__(self): 17 | # node_repr = f"{self.id}" 18 | # return node_repr 19 | 20 | 21 | class Tree: 22 | def __init__(self, structure: dict = None): 23 | if structure: 24 | self.structure = structure 25 | else: 26 | self.structure = defaultdict(lambda: defaultdict(float)) 27 | 28 | non_root_states = set([nbs_ids for node_id, nbs in self.structure.items() for nbs_ids in nbs.keys()]) 29 | self.non_terminal_states = set(self.structure.keys()) 30 | self.terminal_states = non_root_states - self.non_terminal_states 31 | # self.terminal_states = [self.get_node_from_id(str(i)) for i in self.terminal_states] 32 | 33 | def get_root(self): 34 | return list(self.structure.keys())[0] 35 | 36 | # def get_node_from_id(self, node_id): 37 | # if hasattr(self, 'node_dict'): 38 | # return self.node_dict[node_id] 39 | # non_root_states = set([nbs_ids for node_id, nbs in self.structure.items() for nbs_ids in nbs.keys()]) 40 | # self.node_dict = {k.id: k for k in non_root_states} 41 | # root = self.get_root() 42 | # self.node_dict[root.id] = root 43 | # return self.node_dict[node_id] 44 | 45 | def terminals_from_node(self, node): 46 | if node not in self.structure: 47 | return [node] 
48 | else: 49 | terminals = [] 50 | for child in self.structure[node]: 51 | terminals.extend(self.terminals_from_node(child)) 52 | return terminals 53 | 54 | def get_neighbours(self, state_id): 55 | return self.structure.get(state_id, {}) 56 | 57 | def nb_nodes(self): 58 | non_root_states = set([nbs_ids for node_id, nbs in self.structure.items() for nbs_ids in nbs.keys()]) 59 | return len(non_root_states) + 1 60 | 61 | # def visualize(self): 62 | # def _edge_to_str(e_id, **kwargs): 63 | # desc = "\n".join([f"{key}: {value:.3f}" if isinstance(value, float) else f"{key}: {value}" for key, value in 64 | # kwargs.items()]) 65 | # desc = desc.strip('\n') 66 | # return (f"{e_id}\n{desc}") 67 | # 68 | # graph = pygraphviz.AGraph(directed=True) 69 | # 70 | # # add root 71 | # root = self.get_root() 72 | # graph.add_node(root.id, label=str(root), color="green") 73 | # 74 | # for node in sorted(self.non_terminal_states): 75 | # children = self.get_neighbours(node) 76 | # for child, tr_p in children.items(): 77 | # graph.add_node(child.id, 78 | # label=str(child), 79 | # color="red") 80 | # 81 | # # p = tr_m.get_transition_prob(node_id, child_id) 82 | # # r = r_m.get_reward(node_id, child_id) 83 | # graph.add_edge(node.id, child.id, label=edge_to_str(e_id=child.id)) 84 | # 85 | # return graph 86 | -------------------------------------------------------------------------------- /src/mdp/trees/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract import Tree 2 | from .explicit_tree import ExplicitTree 3 | from .lm_as_tree import LanguageModelAsTree 4 | from .analysis_tree import AnalysisTree 5 | -------------------------------------------------------------------------------- /src/mdp/trees/lm_as_tree.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import pytorch_lightning as pl 4 | 5 | from transformers import PreTrainedModel, PreTrainedTokenizer 6 | from ..trees import Tree 7 | 8 | 9 | class LanguageModelAsTree(Tree): 10 | def __init__(self, lm_model: Union[PreTrainedModel, pl.LightningModule], tokenizer: PreTrainedTokenizer): 11 | super().__init__() 12 | 13 | if isinstance(lm_model, pl.LightningModule): 14 | lm_model = lm_model.model 15 | 16 | self.lm_model = lm_model 17 | self.tokenizer = tokenizer 18 | self.root_id = tuple() 19 | 20 | def get_children(self, node_id: tuple) -> Dict[tuple, float]: 21 | raise NotImplementedError() # Not necessary atm 22 | 23 | def get_transition_prob(self, src_id: str, trg_id: str) -> float: 24 | raise NotImplementedError() # Not necessary atm 25 | 26 | def get_parent(self, node_id: tuple) -> tuple: 27 | if node_id == self.root_id: 28 | return self.NO_PARENT 29 | 30 | return node_id[:-1] 31 | 32 | def is_terminal(self, node_id: tuple) -> bool: 33 | """ToDo: Verify that this assumption is correct for all decoding algorithms.
Update it if necessary""" 34 | return node_id[-1] == self.tokenizer.eos_token_id 35 | -------------------------------------------------------------------------------- /src/mdp/utility_value_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import gzip 3 | import json 4 | 5 | from src.mdp.transition_model import TransitionModel 6 | 7 | 8 | class UtilityValueModel: 9 | def __init__(self, transition_model: TransitionModel, utilities: dict = None, path: str = None): 10 | self.transition_model = transition_model 11 | 12 | if path: 13 | utilities = UtilityValueModel.load_from_file(path) 14 | 15 | self.utilities = utilities 16 | self.expected_utilities = {} 17 | self._compute_expected_utilities(self.transition_model.get_root(), self.expected_utilities) 18 | 19 | def _compute_expected_utilities(self, state_id, results): 20 | r = 0 21 | for child_id, transition_prob in self.transition_model.get_neighbours(state_id).items(): 22 | branch_reward = self.get_utility(child_id) + self._compute_expected_utilities(child_id, results) 23 | r += transition_prob * branch_reward 24 | 25 | results[state_id] = r 26 | return r 27 | 28 | def get_utility(self, node_id): 29 | """ 30 | Reward only when reaching a terminal node. 31 | Inner node do not have reward, they have expected reward. 32 | """ 33 | return self.utilities.get(node_id, 0.0) 34 | 35 | def get_expected_utility(self, node_id): 36 | """ 37 | For terminal node return the utility 38 | For inner node return the expected utility, 39 | """ 40 | return self.expected_utilities.get(node_id, 0.0) 41 | 42 | def save_to_file(self, path): 43 | # to_save = {k: v for k, v in self.utilities.items()} 44 | with gzip.open(path, "wb") as f: 45 | f.write(json.dumps(self.utilities).encode()) 46 | 47 | @staticmethod 48 | def load_from_file(path): 49 | if path.endswith(".txt"): 50 | utilities = {} 51 | with open(path, "r") as f: 52 | for line in f.readlines(): 53 | parts = line.strip().split(" ") 54 | for i in range(0, len(parts), 2): 55 | utilities[int(parts[i])] = float(parts[i + 1]) 56 | else: 57 | with gzip.open(path, "rb") as f: 58 | utilities = json.loads(f.read().decode()) 59 | utilities = {int(k): v for k, v in utilities.items()} 60 | 61 | return utilities 62 | -------------------------------------------------------------------------------- /src/mdp/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | import numpy as np 5 | import scipy.stats as stats 6 | 7 | from IPython.display import Image 8 | 9 | 10 | def max_depth(struct, node, depth=0): 11 | if node not in struct: 12 | return depth 13 | left_first_child = list(struct[node].keys())[0] 14 | return max_depth( 15 | struct, left_first_child, depth + 1 16 | ) # we know that the left_first_child will be instantiated as long as we are not to the bottom 17 | 18 | 19 | def swap_position(li, pos1, pos2): 20 | li[pos1], li[pos2] = li[pos2], li[pos1] 21 | return li 22 | 23 | 24 | def get_list_with_tau(la, tau, fixed_pos=None, eps=0.1, max_iter=1000): 25 | if tau < 0: # require less swapping to work on positive tau: use positive tau on reversed sequence 26 | la = la[::-1] 27 | tau = -tau 28 | 29 | lb = copy.deepcopy(la) 30 | nb_iter = 0 31 | while np.abs(stats.kendalltau(la, lb)[0] - tau) > eps: 32 | nb_iter += 1 33 | if nb_iter > max_iter: 34 | break 35 | 36 | decreased_range = 1 if fixed_pos else 0 37 | r = range(len(la) - decreased_range) 38 | id_a = random.choice(r) 39 | id_b = 
random.choice(r) 40 | if id_a == fixed_pos or id_b == fixed_pos: 41 | continue 42 | lb = swap_position(lb, id_a, id_b) 43 | 44 | return lb 45 | 46 | 47 | # Tree Visualization 48 | 49 | def draw_graph(graph): 50 | return Image(graph.draw(format='png', prog='dot')) 51 | 52 | 53 | def save_graph(graph, output_file): 54 | return graph.draw(output_file, prog="dot") 55 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import BLEUScore 2 | from .triplet_set_f1 import TSF1 3 | from .triplet_set_precision import TSPrecision 4 | from .triplet_set_recall import TSRecall 5 | from .detoxify import DetoxifyScore 6 | from .protbert import SolubleScore 7 | -------------------------------------------------------------------------------- /src/metrics/detoxify.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import numpy as np 4 | import tqdm 5 | from detoxify import Detoxify 6 | 7 | 8 | class DetoxifyScore: 9 | def __init__(self, model_type, device="cpu"): 10 | """ 11 | Class for predicting whether a textual sequence (or a list of textual sequences) is toxic. 12 | Can initialize 5 different model types from model type: 13 | - original: 14 | model trained on data from the Jigsaw Toxic Comment 15 | Classification Challenge 16 | - unbiased: 17 | model trained on data from the Jigsaw Unintended Bias in 18 | Toxicity Classification Challenge 19 | - multilingual: 20 | model trained on data from the Jigsaw Multilingual 21 | Toxic Comment Classification Challenge 22 | - original-small: 23 | lightweight version of the original model 24 | - unbiased-small: 25 | lightweight version of the unbiased model 26 | Parameters 27 | ---------- 28 | model_type : the name of the model (see above) 29 | device : a string tag (e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:2, etc) or a torch.device object 30 | """ 31 | self.device = device 32 | self.name = self.model_type = model_type 33 | self.model = Detoxify(model_type=model_type, device=device) 34 | 35 | # used as cache 36 | self.scores_per_datapoint = None 37 | 38 | def compute(self, text: Union[str, List[str]]) -> List[float]: 39 | """ 40 | 41 | Parameters 42 | ---------- 43 | text : a string or a list of strings to score 44 | 45 | Returns 46 | ------- 47 | scores : a list of toxicity scores for each of the textual sequences 48 | """ 49 | if isinstance(text, str): 50 | text = [text] 51 | 52 | # TODO batches? 
53 | scores = [] 54 | for t in tqdm.tqdm(text): 55 | scores.append(self.model.predict([t])["toxicity"][0]) 56 | return scores 57 | 58 | def compute_from_dataset(self, dataset, use_if_cached=False, per_datapoint=False, per_beam=False): 59 | assert sum([per_datapoint, per_beam]) <= 1 60 | 61 | if per_beam: 62 | hypotheses = [dataset.get_predictions(item, top_pred_only=False) for item in dataset] 63 | return [self.compute(hyp) for hyp in hypotheses] 64 | 65 | if not (use_if_cached and self.scores_per_datapoint): 66 | hypotheses = [dataset.get_predictions(item, top_pred_only=True) for item in dataset] 67 | self.scores_per_datapoint = self.compute(hypotheses) 68 | 69 | if per_datapoint: 70 | return self.scores_per_datapoint 71 | 72 | return np.mean(self.scores_per_datapoint) 73 | 74 | # def compute_from_dataset_resample_exp(self, dataset): 75 | # hypotheses = [dataset.get_predictions(item, top_pred_only=False) for item in dataset] 76 | # return [self.compute(hyp) for hyp in hypotheses] 77 | -------------------------------------------------------------------------------- /src/metrics/protbert.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline 2 | import re 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | 7 | class SolubleScore: 8 | def __init__(self, name, device=0): 9 | """ 10 | Class for scoring a protein sequence (or a list of textual sequences) with its solubility. 11 | Parameters 12 | ---------- 13 | device : number of cuda device integer (e.g. '0', '1', '2' etc) 14 | """ 15 | self.device = device 16 | self.name = name 17 | self.model = TextClassificationPipeline( 18 | model=AutoModelForSequenceClassification.from_pretrained("Rostlab/prot_bert_bfd_membrane"), 19 | tokenizer=AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd_membrane"), device=self.device) 20 | 21 | # used as cache 22 | self.scores_per_datapoint = None 23 | 24 | def compute(self, text: Union[str, List[str]]) -> List[float]: 25 | """ 26 | Parameters 27 | ---------- 28 | text : a string or a list of strings to score 29 | 30 | Returns 31 | ------- 32 | scores : a list of solubility scores for each of the textual sequences 33 | """ 34 | if isinstance(text, str): 35 | text = [text] 36 | 37 | scores = [] 38 | for i in range(len(text)): 39 | # Create or load sequences and map rarely occurred amino acids 40 | sequences = ' '.join([re.sub(r"[UZOB]", "X", sequence) for sequence in text[i]]) 41 | score = self.model(sequences) 42 | soluble_score = score[0]['score'] 43 | scores.append(soluble_score) 44 | return scores 45 | 46 | def compute_from_dataset(self, dataset, use_if_cached=False, per_datapoint=False, per_beam=False): 47 | assert sum([per_datapoint, per_beam]) <= 1 48 | 49 | if per_beam: 50 | hypotheses = [dataset.get_predictions(item, top_pred_only=False) for item in dataset] 51 | return [self.compute(hyp) for hyp in hypotheses] 52 | 53 | if not (use_if_cached and self.scores_per_datapoint): 54 | hypotheses = [dataset.get_predictions(item, top_pred_only=True) for item in dataset] 55 | self.scores_per_datapoint = self.compute(hypotheses) 56 | 57 | if per_datapoint: 58 | return self.scores_per_datapoint 59 | 60 | return np.mean(self.scores_per_datapoint) 61 | -------------------------------------------------------------------------------- /src/metrics/triplet_set_precision.py: -------------------------------------------------------------------------------- 1 | import torch
2 | from torchmetrics import Metric 3 | from typing import List, Union 4 | 5 | 6 | class TSPrecision(Metric): 7 | def __init__(self, dist_sync_on_step=False): 8 | super().__init__(dist_sync_on_step=dist_sync_on_step) 9 | 10 | self.add_state("total_correct", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") 11 | self.add_state("total_predicted", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") 12 | self.name = "triplet_set_precision" 13 | 14 | @staticmethod 15 | def _process_test_sample(target_triples: set, pred_triples: set): 16 | num_matched = len(target_triples.intersection(pred_triples)) 17 | num_predicted = len(pred_triples) 18 | num_target = len(target_triples) 19 | 20 | return num_matched, num_predicted, num_target 21 | 22 | def update(self, preds: list, targets: list): 23 | assert len(preds) == len(targets) 24 | 25 | num_correct = [] 26 | num_predicted = [] 27 | 28 | for t, p in zip(targets, preds): 29 | n_matched, n_predicted, n_target = TSPrecision._process_test_sample(t, p) 30 | 31 | num_correct.append(n_matched) 32 | num_predicted.append(n_predicted) 33 | 34 | num_correct = torch.tensor(num_correct).long() 35 | num_predicted = torch.tensor(num_predicted).long() 36 | 37 | self.total_correct += torch.sum(num_correct) 38 | self.total_predicted += torch.sum(num_predicted) 39 | 40 | @staticmethod 41 | def _compute(correct, predicted, use_tensor=False) -> Union[float, torch.Tensor]: 42 | if predicted == 0: 43 | return torch.tensor(0).float() if use_tensor else 0.0 44 | 45 | correct = correct.float() if use_tensor else float(correct) 46 | precision = correct / predicted 47 | 48 | return precision 49 | 50 | def compute(self): 51 | if self.total_predicted == 0: 52 | return torch.tensor(0).float() 53 | 54 | return self.total_correct.float() / self.total_predicted 55 | 56 | @staticmethod 57 | def compute_from_datapoint(target_triples: set, pred_triples: set) -> float: 58 | n_matched, n_predicted, _ = TSPrecision._process_test_sample(target_triples, pred_triples) 59 | return TSPrecision._compute(n_matched, n_predicted) 60 | 61 | def compute_from_dataset(self, dataset, per_datapoint=False) -> Union[float, List[float]]: 62 | targets = [dataset.get_text_triples(dataset.get_targets(item)) for item in dataset] 63 | preds = [dataset.get_text_triples(dataset.get_predictions(item)) for item in dataset] 64 | 65 | if per_datapoint: 66 | precision = [TSPrecision.compute_from_datapoint(t, p) for t, p in zip(targets, preds)] 67 | else: 68 | num_correct = 0 69 | num_predicted = 0 70 | 71 | for t, p in zip(targets, preds): 72 | n_matched, n_predicted, _ = TSPrecision._process_test_sample(t, p) 73 | 74 | num_correct += n_matched 75 | num_predicted += n_predicted 76 | 77 | precision = TSPrecision._compute(num_correct, num_predicted) 78 | 79 | return precision 80 | -------------------------------------------------------------------------------- /src/metrics/triplet_set_recall.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | from typing import List, Union 4 | 5 | 6 | class TSRecall(Metric): 7 | def __init__(self, dist_sync_on_step=False): 8 | super().__init__(dist_sync_on_step=dist_sync_on_step) 9 | 10 | self.add_state("total_correct", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") 11 | self.add_state("total_target", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") 12 | self.name = "triplet_set_recall" 13 | 14 | @staticmethod 15 | def 
_process_test_sample(target_triples: set, pred_triples: set): 16 | num_matched = len(target_triples.intersection(pred_triples)) 17 | num_predicted = len(pred_triples) 18 | num_target = len(target_triples) 19 | 20 | return num_matched, num_predicted, num_target 21 | 22 | def update(self, preds: list, targets: list): 23 | assert len(preds) == len(targets) 24 | 25 | num_correct = [] 26 | num_target = [] 27 | 28 | for t, p in zip(targets, preds): 29 | n_matched, n_predicted, n_target = TSRecall._process_test_sample(t, p) 30 | 31 | num_correct.append(n_matched) 32 | # num_predicted.append(n_predicted) 33 | num_target.append(n_target) 34 | 35 | num_correct = torch.tensor(num_correct).long() 36 | num_target = torch.tensor(num_target).long() 37 | 38 | self.total_correct += torch.sum(num_correct) 39 | self.total_target += torch.sum(num_target) 40 | 41 | @staticmethod 42 | def _compute(correct, target, use_tensor=False) -> Union[float, torch.Tensor]: 43 | if target == 0: 44 | return torch.tensor(0).float() if use_tensor else 0.0 45 | 46 | correct = correct.float() if use_tensor else float(correct) 47 | recall = correct / target 48 | 49 | return recall 50 | 51 | def compute(self): 52 | if self.total_target == 0: 53 | return torch.tensor(0).float() 54 | 55 | return self.total_correct.float() / self.total_target 56 | 57 | @staticmethod 58 | def compute_from_datapoint(target_triples: set, pred_triples: set) -> float: 59 | n_matched, _, n_target = TSRecall._process_test_sample(target_triples, pred_triples) 60 | return TSRecall._compute(n_matched, n_target) 61 | 62 | def compute_from_dataset(self, dataset, per_datapoint=False) -> Union[float, List[float]]: 63 | targets = [dataset.get_text_triples(dataset.get_targets(item)) for item in dataset] 64 | preds = [dataset.get_text_triples(dataset.get_predictions(item)) for item in dataset] 65 | 66 | if per_datapoint: 67 | recall = [TSRecall.compute_from_datapoint(t, p) for t, p in zip(targets, preds)] 68 | else: 69 | num_correct = 0 70 | num_target = 0 71 | 72 | for t, p in zip(targets, preds): 73 | n_matched, _, n_target = TSRecall._process_test_sample(t, p) 74 | 75 | num_correct += n_matched 76 | num_target += n_target 77 | 78 | recall = TSRecall._compute(num_correct, num_target) 79 | 80 | return recall 81 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mbart_for_conditional_generation import MBartForConditionalGeneration 2 | from .genie import GeniePL 3 | from .gpt2_for_generation import GPT2ForGeneration 4 | from .protgpt2_for_generation import ProtGPT2ForGeneration 5 | -------------------------------------------------------------------------------- /src/models/collators/__init__.py: -------------------------------------------------------------------------------- 1 | from .rebel_collator import RebelCollator 2 | -------------------------------------------------------------------------------- /src/models/collators/rebel_collator.py: -------------------------------------------------------------------------------- 1 | class RebelCollator: 2 | def __init__(self, tokenizer, **kwargs): 3 | self.tokenizer = tokenizer 4 | self.params = kwargs 5 | 6 | def collate_fn(self, batch): 7 | collated_batch = {} 8 | 9 | for attr_name in "src", "tgt": 10 | if attr_name == "src": 11 | max_length = self.params["max_input_length"] 12 | 13 | elif attr_name == "tgt": 14 | max_length = self.params["max_output_length"] 15 | else: 16 | 
raise Exception(f"Unexpected attribute name `{attr_name}`!") 17 | 18 | tokenizer_output = self.tokenizer( 19 | [sample[attr_name] for sample in batch], 20 | return_tensors="pt", 21 | return_attention_mask=True, 22 | padding=self.params["padding"], 23 | max_length=max_length, 24 | truncation=self.params["truncation"], 25 | ) 26 | 27 | for k, v in tokenizer_output.items(): 28 | collated_batch["{}_{}".format(attr_name, k)] = v 29 | 30 | if self.params.get("target_padding_token_id", None) is not None: 31 | tgt_input_ids = collated_batch["tgt_input_ids"] 32 | tgt_input_ids.masked_fill_( 33 | tgt_input_ids == self.tokenizer.pad_token_id, self.params["target_padding_token_id"] 34 | ) 35 | 36 | collated_batch["raw"] = batch 37 | 38 | return collated_batch 39 | -------------------------------------------------------------------------------- /src/models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/src/models/utils/__init__.py -------------------------------------------------------------------------------- /src/models/utils/general.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pickle 4 | 5 | 6 | def label_smoothed_nll_loss( 7 | lprobs: torch.Tensor, 8 | target: torch.Tensor, 9 | target_attention_mask: torch.Tensor, 10 | epsilon: float, 11 | ignore_index: int = None, 12 | reduce: bool = True, 13 | ): 14 | # target.shape -> batch_size x tgt_seq_length; lprobs.shape -> batch_size x tgt_seq_length x vocabulary_size 15 | if target.dim() == lprobs.dim() - 1: 16 | target = target.unsqueeze(-1) # target.shape -> batch_size x tgt_seq_length x 1 17 | 18 | if ignore_index is not None: 19 | pad_mask = target.eq(ignore_index) 20 | target.clamp_min_(0) 21 | 22 | nll_loss = -lprobs.gather(dim=-1, index=target) # get the log prob terms corresponding to the target indices 23 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) # calculations needed for the smoothed loss 24 | 25 | nll_loss.masked_fill_(pad_mask, 0.0) 26 | smooth_loss.masked_fill_(pad_mask, 0.0) 27 | else: 28 | # TODO: make sure that the dimensions here match i.e. do you actually want the squeeze?? 29 | # TODO: If anywhere, shouldn't it be outside? (see below) 30 | nll_loss = -lprobs.gather(dim=-1, index=target) 31 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 32 | 33 | # nll_loss = nll_loss.squeeze(-1) 34 | # smooth_loss = smooth_loss.squeeze(-1) 35 | 36 | nll_loss = nll_loss.squeeze(-1) 37 | smooth_loss = smooth_loss.squeeze(-1) 38 | 39 | if reduce: 40 | nll_loss = nll_loss.sum() 41 | smooth_loss = smooth_loss.sum() 42 | 43 | # TODO: Why is "the number of classes" (lprobs.size(-1) - 1)? What is the -1 for? Is it padding? 44 | eps_i = epsilon / (lprobs.size(-1) - 1) 45 | # TODO: Check correctness. 
shouldn't it be loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 46 | # loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss 47 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 48 | 49 | # Normalize the loss by diving with the number of non-padded 50 | num_tokens = target_attention_mask.sum() 51 | loss, nll_loss = loss / num_tokens, nll_loss / num_tokens 52 | 53 | return loss, nll_loss 54 | 55 | 56 | def set_seed(args): 57 | np.random.seed(args.seed) 58 | torch.manual_seed(args.seed) 59 | if args.n_gpu > 0: 60 | torch.cuda.manual_seed_all(args.seed) 61 | 62 | 63 | def chunk_it(seq, num=1): 64 | assert num > 0 65 | chunk_len = len(seq) // num 66 | chunks = [seq[i * chunk_len : i * chunk_len + chunk_len] for i in range(num)] 67 | 68 | diff = len(seq) - chunk_len * num 69 | for i in range(diff): 70 | chunks[i].append(seq[chunk_len * num + i]) 71 | 72 | return chunks 73 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/understanding-decoding/f0976e67964e5d083e90f2ff0b2a1662aa240ba3/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/additional_loggers.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from einops import rearrange 3 | 4 | 5 | def get_img_rec_table_data(ids, labels, imgs, reconstructions, step, num_samples_to_log, num_samples_per_z_to_consider): 6 | if reconstructions.dim() != imgs.dim(): 7 | reconstructions = rearrange( 8 | reconstructions, 9 | "(bs z_ns) (c h w) -> bs z_ns c h w", 10 | bs=imgs.shape[0], 11 | c=imgs.shape[1], 12 | h=imgs.shape[2], 13 | w=imgs.shape[3], 14 | ) 15 | 16 | imgs = [wandb.Image(x) for x in imgs[:num_samples_to_log]] 17 | 18 | if reconstructions.shape[1] != 1 and num_samples_per_z_to_consider != 1: 19 | reconstructions = [ 20 | [wandb.Image(x_hat) for x_hat in rec_set[:num_samples_per_z_to_consider]] 21 | for rec_set in reconstructions[:num_samples_to_log] 22 | ] 23 | else: 24 | reconstructions = [wandb.Image(x_hat) for x_hat in reconstructions[:num_samples_to_log]] 25 | 26 | columns = ["step", "ids", "label", "x", "x_hat"] 27 | data = [ 28 | [step, _id, label, x, x_hat] 29 | for _id, label, x, x_hat in zip( 30 | ids[:num_samples_to_log], 31 | labels[:num_samples_to_log], 32 | imgs, 33 | reconstructions, 34 | ) 35 | ] 36 | return columns, data 37 | 38 | 39 | def get_text_rec_table_data( 40 | ids, input_ids, input_text, reconstructions, step, num_samples_to_log, num_samples_per_z_to_consider 41 | ): 42 | if len(reconstructions) != input_ids.shape[0]: 43 | batch_size = input_ids.shape[0] 44 | num_samples_per_input = len(reconstructions) // batch_size 45 | reconstructions = [ 46 | reconstructions[i : i + num_samples_per_input] for i in range(0, len(reconstructions), batch_size) 47 | ] 48 | 49 | input_ids = input_ids[:num_samples_to_log] 50 | input_text = input_text[:num_samples_to_log] 51 | 52 | if isinstance(reconstructions[0], list): 53 | reconstructions = [rec_set[:num_samples_per_z_to_consider] for rec_set in reconstructions[:num_samples_to_log]] 54 | else: 55 | reconstructions = reconstructions[:num_samples_to_log] 56 | 57 | columns = ["step", "id", "input_ids", "input_text", "output_text"] 58 | data = [ 59 | [step, _id, tokenized_x, textual_x, output_text] 60 | for _id, tokenized_x, textual_x, output_text in zip( 61 | 


def get_text_rec_table_data(
    ids, input_ids, input_text, reconstructions, step, num_samples_to_log, num_samples_per_z_to_consider
):
    # Regroup a flat list of reconstructions into one sub-list per input sample.
    if len(reconstructions) != input_ids.shape[0]:
        batch_size = input_ids.shape[0]
        num_samples_per_input = len(reconstructions) // batch_size
        reconstructions = [
            reconstructions[i : i + num_samples_per_input] for i in range(0, len(reconstructions), batch_size)
        ]

    input_ids = input_ids[:num_samples_to_log]
    input_text = input_text[:num_samples_to_log]

    if isinstance(reconstructions[0], list):
        reconstructions = [rec_set[:num_samples_per_z_to_consider] for rec_set in reconstructions[:num_samples_to_log]]
    else:
        reconstructions = reconstructions[:num_samples_to_log]

    columns = ["step", "id", "input_ids", "input_text", "output_text"]
    data = [
        [step, _id, tokenized_x, textual_x, output_text]
        for _id, tokenized_x, textual_x, output_text in zip(
            ids[:num_samples_to_log],
            input_ids,
            input_text,
            reconstructions,
        )
    ]
    return columns, data


def add_column_to_table_data(columns, data, new_col_name, new_col_data):
    # Append `new_col_name` to the header and the corresponding value to every existing row.
    columns.append(new_col_name)
    for row, new_cell_value in zip(data, new_col_data):
        row.append(new_cell_value)
--------------------------------------------------------------------------------
/src/utils/hydra_ad_hoc.py:
--------------------------------------------------------------------------------
import hydra
from omegaconf import OmegaConf


def get_config(overrides=[], config_name="train_root.yaml", work_dir="../../", data_dir="../../data/"):
    configs_folder = "../../configs"
    default_overrides = [f"work_dir={work_dir}", f"data_dir={data_dir}"]
    # Build a new list instead of extending `overrides` in place, so that neither the caller's list
    # nor the mutable default argument is modified across calls.
    overrides = overrides + default_overrides

    with hydra.initialize(version_base="1.2", config_path=configs_folder):
        config = hydra.compose(config_name=config_name, overrides=overrides)

    return config


def compare_configs(config, parameters_to_comapre_keys, parameters_to_comapre_values, tag):
    comparison = True

    for i, parameter in enumerate(parameters_to_comapre_keys):
        if tag == "train" or tag == "run":
            par_config = OmegaConf.select(config, parameter)
            comparison = comparison and (par_config == parameters_to_comapre_values[i])
        else:
            if "model" in parameter:
                par_config = OmegaConf.select(config, parameter.replace("model.", "model.hparams_overrides."))
                comparison = comparison and (par_config == parameters_to_comapre_values[i])

    return comparison
--------------------------------------------------------------------------------
/src/utils/hydra_custom_resolvers.py:
--------------------------------------------------------------------------------
import math
import os

from omegaconf import OmegaConf


def add_args(*args):
    return sum(float(x) for x in args)


def add_args_int(*args):
    return int(sum(float(x) for x in args))


def multiply_args(*args):
    return math.prod(float(x) for x in args)


def multiply_args_int(*args):
    return int(math.prod(float(x) for x in args))


def floor_division(dividend, divisor):
    return dividend // divisor


def num_files_in_directory(path):
    if not os.path.exists(path):
        return 0

    if not os.path.isdir(path):
        raise Exception(f"Path `{path}` does not correspond to a directory!")

    return len(os.listdir(path))


def path_to_python_executable():
    import sys

    return sys.executable


def max_args(*args):
    return max(x for x in args)


# Register the resolvers so that configs can reference them, e.g. ${add:1,2} or ${floor_div:10,3}.
OmegaConf.register_new_resolver("add", add_args)
OmegaConf.register_new_resolver("mult", multiply_args)

OmegaConf.register_new_resolver("add_int", add_args_int)
OmegaConf.register_new_resolver("mult_int", multiply_args_int)

OmegaConf.register_new_resolver("num_files", num_files_in_directory)

OmegaConf.register_new_resolver("path_to_python_executable", path_to_python_executable)
OmegaConf.register_new_resolver("floor_div", floor_division)

OmegaConf.register_new_resolver("max", max_args)
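
# Illustrative usage sketch (added for documentation; not part of the original file): once this module
# has been imported, the resolvers can be referenced from any Hydra/OmegaConf config. The YAML keys
# below are hypothetical and only demonstrate the syntax:
#
#   effective_batch_size: ${mult_int:${datamodule.batch_size},${trainer.accumulate_grad_batches}}
#   total_steps: ${add_int:${warmup_steps},${decay_steps}}
#   num_run_files: ${num_files:${output_dir}}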

--------------------------------------------------------------------------------
/src/utils/oracles.py:
--------------------------------------------------------------------------------
import scipy.stats as stats
import scipy.integrate as integrate
import math
import torch


def bart_get_evaluation_model_target_ids(model, tokenizer, target_ids):
    # decoder_input_ids = model.prepare_decoder_input_ids_from_labels(
    #     target_ids
    # )
    # Prepend an EOS token to every target sequence (BART-style decoder inputs start with EOS).
    eos_token = torch.ones(
        target_ids.shape[0], 1, dtype=target_ids.dtype, device=target_ids.device
    ) * tokenizer.eos_token_id
    # last_token = torch.where(target_ids[:, -1] == tokenizer.pad_token_id, tokenizer.pad_token_id, eos_token)

    full_target_ids = torch.cat([eos_token, target_ids], dim=-1)
    return full_target_ids


def compute_flip_probability(sigma, k):
    """
    Computes the probability that, after adding noise drawn from a truncated normal distribution, at least
    one utility value that was previously 0 exceeds the utility value that was previously 1.

    Variable names follow the notation from https://en.wikipedia.org/wiki/Truncated_normal_distribution

    Parameters:
        sigma (float): standard deviation of the underlying normal distribution
        k (int): number of 0-entries

    Returns:
        prob (float): the probability that at least one former 0-entry is now larger than the former 1-entry
        error (float): the estimated error due to the numerical integration
    """
    normal_rv = stats.norm()

    def pdf_X_1(c):
        ksi = (c - 1) / sigma
        alpha = -1 / sigma
        beta = 0
        Z = normal_rv.cdf(beta) - normal_rv.cdf(alpha)
        return normal_rv.pdf(ksi) / (sigma * Z)

    def cdf_X_0(c):
        ksi = c / sigma
        alpha = 0
        beta = 1 / sigma
        Z = normal_rv.cdf(beta) - normal_rv.cdf(alpha)
        return (normal_rv.cdf(ksi) - normal_rv.cdf(alpha)) / Z

    integrand = lambda c: pdf_X_1(c) * math.pow(cdf_X_0(c), k)

    one_minus_flip_prob, error = integrate.quad(integrand, 0, 1)

    return 1 - one_minus_flip_prob, error
--------------------------------------------------------------------------------
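
A minimal illustrative sketch (not part of the repository, with hypothetical numbers) of how
compute_flip_probability from src/utils/oracles.py could be called to gauge how likely a noisy oracle
is to rank a zero-utility candidate above the single one-utility candidate:

    from src.utils.oracles import compute_flip_probability

    # Probability that at least one of 10 zero-utility candidates overtakes the one-utility candidate
    # when truncated-normal noise with standard deviation 0.2 is added to every utility value.
    flip_prob, quad_error = compute_flip_probability(sigma=0.2, k=10)
    print(f"flip probability: {flip_prob:.4f} (integration error: {quad_error:.1e})")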