├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGEMENTS.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── act ├── __init__.py ├── configs │ ├── intervention_params │ │ ├── aura.yaml │ │ ├── gaussian_ot.yaml │ │ ├── linear_ot.yaml │ │ ├── mean_ot.yaml │ │ └── none.yaml │ ├── model │ │ ├── FLUX.1-schnell.yaml │ │ ├── Meta-Llama-3-8B.yaml │ │ ├── SDXL-Lightning.yaml │ │ ├── gemma-2-2b.yaml │ │ └── gpt2.yaml │ ├── task_params │ │ ├── coco_concepts.yaml │ │ ├── coco_styles.yaml │ │ ├── diffusion_prompts.yaml │ │ ├── giraffes.yaml │ │ └── toxicity.yaml │ ├── text_generation.yaml │ ├── text_to_image_generation.yaml │ └── wandb │ │ └── act.yaml ├── datasets │ ├── __init__.py │ ├── coco_captions.py │ ├── collators.py │ ├── dummy_text2image_prompts.py │ ├── jigsaw_dataset.py │ ├── json_prompts.py │ ├── json_subsets_dataset.py │ ├── onesec_dataset.py │ ├── responses_io.py │ └── samplers.py ├── evaluations │ ├── __init__.py │ ├── calculate_clip_score.py │ ├── evaluate_0shot.py │ ├── evaluate_eleuther.py │ ├── evaluate_perplexity.py │ └── evaluate_toxicity.py ├── hooks │ ├── __init__.py │ ├── aura_hook.py │ ├── custom_exceptions.py │ ├── identity.py │ ├── intervention_hook.py │ ├── pooling_ops.py │ ├── postprocess_and_save_hook.py │ ├── responses_hook.py │ ├── return_outputs_hook.py │ └── transport.py ├── models │ ├── __init__.py │ ├── get_model.py │ └── model_with_hooks.py ├── optimal_transport │ ├── __init__.py │ ├── archs.py │ └── ot_maps.py ├── scripts │ ├── __init__.py │ ├── download_external_data.py │ ├── generate_with_hooks.py │ ├── generate_with_hooks_diffusion.py │ ├── learn_intervention.py │ └── pipeline.py └── utils │ ├── __init__.py │ ├── auroc.py │ ├── get_module_names.py │ ├── perplexity.py │ ├── quantiles.py │ └── utils.py ├── assets ├── main_figure.png ├── main_figure.svg └── main_figure.webp ├── data ├── diffusion_concept_prompts.json ├── diffusion_prompts.json ├── giraffes.json └── style_prompts.json ├── pyproject.toml ├── requirements.txt └── tests ├── __init__.py ├── configs ├── conf_test_interventions.yaml ├── hook_config.yaml ├── pipeline_test.yaml ├── responses_incremental_test.yaml └── responses_test.yaml ├── data ├── aura-toxicity-max │ └── tiny-gpt2 │ │ ├── transformer.h.0.mlp.c_proj.statedict │ │ └── transformer.h.1.mlp.c_proj.statedict ├── coco_captions_2017 │ ├── captions_train2017.json │ └── captions_val2017.json ├── jigsaw │ └── train.csv ├── prompted_gens_gpt2.jsonl ├── toxicity-responses-actadd │ └── tiny-gpt2 │ │ └── act_add │ │ ├── non-toxic │ │ ├── transformer.h.0.mlp.c_proj:0 │ │ │ └── mean │ │ │ │ └── 1.pt │ │ └── transformer.h.1.mlp.c_proj:0 │ │ │ └── mean │ │ │ └── 1.pt │ │ └── toxic │ │ ├── transformer.h.0.mlp.c_proj:0 │ │ └── mean │ │ │ └── 0.pt │ │ └── transformer.h.1.mlp.c_proj:0 │ │ └── mean │ │ └── 0.pt ├── toxicity-responses │ └── tiny-gpt2 │ │ └── jigsaw │ │ ├── non-toxic │ │ ├── transformer.h.0.mlp.c_proj:0 │ │ │ └── mean │ │ │ │ ├── 0000997932d777bf.pt │ │ │ │ ├── 000bfd0867774845.pt │ │ │ │ ├── 000eefc67a2c930f.pt │ │ │ │ ├── 000ffab30195c5e1.pt │ │ │ │ ├── 0010833a96e1f886.pt │ │ │ │ ├── 00128363e367d703.pt │ │ │ │ ├── 0015f4aa35ebe9b5.pt │ │ │ │ └── 001735f961a23fc4.pt │ │ └── transformer.h.1.mlp.c_proj:0 │ │ │ └── mean │ │ │ ├── 0000997932d777bf.pt │ │ │ ├── 000bfd0867774845.pt │ │ │ ├── 000eefc67a2c930f.pt │ │ │ ├── 000ffab30195c5e1.pt │ │ │ ├── 0010833a96e1f886.pt │ │ │ ├── 00128363e367d703.pt │ │ │ ├── 0015f4aa35ebe9b5.pt │ │ │ └── 001735f961a23fc4.pt │ │ └── toxic │ │ ├── transformer.h.0.mlp.c_proj:0 │ │ └── 
mean │ │ │ ├── 0002bcb3da6cb337.pt │ │ │ ├── 0005c987bdfc9d4b.pt │ │ │ ├── 0007e25b2121310b.pt │ │ │ ├── 0020fd96ed3b8c8b.pt │ │ │ ├── 0028d62e8a5629aa.pt │ │ │ ├── 003217c3eb469ba9.pt │ │ │ ├── 0036621e4c7e10b5.pt │ │ │ └── 00472b8e2d38d1ea.pt │ │ └── transformer.h.1.mlp.c_proj:0 │ │ └── mean │ │ ├── 0002bcb3da6cb337.pt │ │ ├── 0005c987bdfc9d4b.pt │ │ ├── 0007e25b2121310b.pt │ │ ├── 0020fd96ed3b8c8b.pt │ │ ├── 0028d62e8a5629aa.pt │ │ ├── 003217c3eb469ba9.pt │ │ ├── 0036621e4c7e10b5.pt │ │ └── 00472b8e2d38d1ea.pt ├── whispx-test-max │ └── tiny-gpt2 │ │ ├── transformer.h.0.mlp.c_proj.statedict │ │ └── transformer.h.1.mlp.c_proj.statedict └── wikipedia_sentences.csv ├── test_0shot.py ├── test_datasets.py ├── test_interventions.py ├── test_model.py ├── test_perplexity.py ├── test_pipeline.py ├── test_responses.py └── test_responses_io.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | .dmypy.json 114 | dmypy.json 115 | 116 | # Pyre type checker 117 | .pyre/ 118 | 119 | # VSCODE IDE 120 | .vscode 121 | 122 | #Pycharm IDE 123 | .idea 124 | 125 | #Hydra logs 126 | outputs/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
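# NOTE: a minimal setup sketch for this config (assumes pre-commit is
# installed from PyPI). black runs at the default commit stage, while the
# pytest-check hook below is bound to the push stage, so both hook types
# need to be installed:
#
#   pip install pre-commit
#   pre-commit install                       # commit-stage hooks (black)
#   pre-commit install --hook-type pre-push  # push-stage hooks (pytest-check)
#   pre-commit run --all-files               # optional one-off run on the repo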
3 | 4 | repos: 5 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 6 | - repo: https://github.com/psf/black-pre-commit-mirror 7 | rev: 24.2.0 8 | hooks: 9 | - id: black 10 | # It is recommended to specify the latest version of Python 11 | # supported by your project here, or alternatively use 12 | # pre-commit's default_language_version, see 13 | # https://pre-commit.com/#top_level-default_language_version 14 | language_version: python3 15 | - repo: local 16 | hooks: 17 | - id: pytest-check 18 | stages: [push] 19 | name: pytest-check 20 | entry: pytest 21 | language: system 22 | pass_filenames: false 23 | always_run: true -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS.md: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this `ml-whispx` Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | 7 | ## Frameworks 8 | 9 | ### [Pytorch](https://pytorch.org/) 10 | 11 | ``` 12 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 13 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 14 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 15 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 16 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 17 | Copyright (c) 2011-2013 NYU (Clement Farabet) 18 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 19 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 20 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 21 | ``` 22 | 23 | ### [Huggingface](https://github.com/huggingface) 24 | 25 | ```Copyright 2018- The Hugging Face team. All rights reserved. 26 | 27 | Apache License 28 | Version 2.0, January 2004 29 | http://www.apache.org/licenses/ 30 | ``` 31 | 32 | ## Pre-trained models 33 | 34 | ### [Gpt-2 (openai-community)](https://huggingface.co/openai-community/gpt2) 35 | 36 | ``` 37 | MIT License 38 | 39 | Copyright (c) [year] [fullname] 40 | 41 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 42 | 43 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 44 | 45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
46 | ``` 47 | 48 | 49 | ### [Mistral-7b](https://huggingface.co/mistralai/Mistral-7B-v0.1) 50 | 51 | ``` 52 | Apache License 53 | Version 2.0, January 2004 54 | http://www.apache.org/licenses/ 55 | ``` 56 | 57 | ### [Llama3](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) 58 | 59 | ``` 60 | META LLAMA 3 COMMUNITY LICENSE AGREEMENT 61 | Meta Llama 3 Version Release Date: April 18, 2024 62 | 63 | “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the 64 | Llama Materials set forth herein. 65 | ``` 66 | 67 | 68 | ### [Gemma2](https://ai.google.dev/gemma/terms) 69 | 70 | ``` 71 | You may use, reproduce, modify, Distribute, perform or display any of the Gemma Services only in accordance with the terms of this Agreement, and must not violate (or encourage or permit anyone else to violate) any term of this Agreement. 72 | ``` 73 | 74 | ### [StableDiffusion XL lightning](https://huggingface.co/ByteDance/SDXL-Lightning/blob/main/LICENSE.md) 75 | 76 | ``` 77 | Copyright (c) 2024 Bytedance Inc. Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023 78 | 79 | ``` 80 | 81 | ### [FLUX.1 schnell](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-schnell) 82 | 83 | ``` 84 | Apache License 85 | Version 2.0, January 2004 86 | http://www.apache.org/licenses/ 87 | ``` 88 | 89 | 90 | ### [WandB](https://github.com/wandb/wandb) 91 | 92 | ``` 93 | MIT License 94 | 95 | Copyright (c) 2021 Weights and Biases, Inc. 96 | 97 | Permission is hereby granted, free of charge, to any person obtaining a copy 98 | of this software and associated documentation files (the "Software"), to deal 99 | in the Software without restriction, including without limitation the rights 100 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 101 | copies of the Software, and to permit persons to whom the Software is 102 | furnished to do so, subject to the following conditions: 103 | ``` 104 | 105 | ## Datasets 106 | 107 | ### [Jigsaw toxic comment classification](https://github.com/praj2408/Jigsaw-Toxic-Comment-Classification/blob/main/LICENSE) 108 | 109 | ``` 110 | MIT License 111 | 112 | Copyright (c) 2023 Prajwal Krishna 113 | 114 | Permission is hereby granted, free of charge, to any person obtaining a copy 115 | of this software and associated documentation files (the "Software"), to deal 116 | in the Software without restriction, including without limitation the rights 117 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 118 | copies of the Software, and to permit persons to whom the Software is 119 | furnished to do so, subject to the following conditions: 120 | 121 | The above copyright notice and this permission notice shall be included in all 122 | copies or substantial portions of the Software. 123 | 124 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 125 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 126 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 127 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 128 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 129 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 130 | SOFTWARE.
131 | ``` 132 | 133 | ### [Real Toxicity Prompts](https://huggingface.co/datasets/allenai/real-toxicity-prompts) 134 | 135 | ``` 136 | Apache License 137 | Version 2.0, January 2004 138 | http://www.apache.org/licenses/ 139 | ``` -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 
62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. 
APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /act/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /act/configs/intervention_params/aura.yaml: -------------------------------------------------------------------------------- 1 | name: aura 2 | pooling_op: max 3 | incremental: atonce 4 | state_path: null 5 | hook_params: 6 | intervention_position: all 7 | strength: 1.0 8 | device: ${device} 9 | dtype: ${dtype:torch.float32} 10 | quantiles_src: q_all -------------------------------------------------------------------------------- /act/configs/intervention_params/gaussian_ot.yaml: -------------------------------------------------------------------------------- 1 | name: gaussian_ot 2 | pooling_op: mean 3 | incremental: incr 4 | state_path: null 5 | hook_params: 6 | intervention_position: all 7 | strength: 1.0 8 | device: ${device} 9 | dtype: ${dtype:torch.float32} 10 | quantiles_src: q_all -------------------------------------------------------------------------------- /act/configs/intervention_params/linear_ot.yaml: -------------------------------------------------------------------------------- 1 | name: linear_ot 2 | pooling_op: mean 3 | incremental: incr 4 | state_path: null 5 | hook_params: 6 | intervention_position: all 7 | strength: 1.0 8 | device: ${device} 9 | dtype: ${dtype:torch.float32} 10 | quantiles_src: q_0_100 -------------------------------------------------------------------------------- /act/configs/intervention_params/mean_ot.yaml: -------------------------------------------------------------------------------- 1 | name: mean_ot 2 | pooling_op: mean 3 | incremental: incr 4 | state_path: null 5 | hook_params: 6 | intervention_position: all 7 | strength: 1.0 8 | device: ${device} 9 | dtype: ${dtype:torch.float32} 10 | quantiles_src: q_all -------------------------------------------------------------------------------- /act/configs/intervention_params/none.yaml: -------------------------------------------------------------------------------- 1 | name: none 2 | pooling_op: mean 3 | incremental: atonce 4 | state_path: null 5 | hook_params: 6 | intervention_position: all 7 | strength: 0.0 8 | device: ${device} 9 | dtype: ${dtype:torch.float32} 10 | quantiles_src: q_all -------------------------------------------------------------------------------- /act/configs/model/FLUX.1-schnell.yaml: -------------------------------------------------------------------------------- 1 | 
model_path: "black-forest-labs/FLUX.1-schnell" 2 | default_batch_size: 4 3 | guidance_scale: 0 4 | inference_steps: 4 5 | default_module_names: 6 | default: ${.transformer_blocks_12} 7 | fast: ["transformer.transformer_blocks.0:0"] 8 | transformer_blocks_12: 9 | - "transformer.transformer_blocks.[0-9]+:[0-1]" 10 | - "transformer.single_transformer_blocks.[0-9]:0" 11 | - "transformer.single_transformer_blocks.11:0" -------------------------------------------------------------------------------- /act/configs/model/Meta-Llama-3-8B.yaml: -------------------------------------------------------------------------------- 1 | model_path: meta-llama/Meta-Llama-3-8B 2 | default_batch_size: 32 3 | seq_len: 128 4 | dtype: ${dtype:torch.bfloat16} 5 | default_module_names: 6 | icml24: 7 | - "model.layers.*.mlp.up_proj" 8 | - "model.layers.*.mlp.down_proj" 9 | - "model.layers.*.mlp.gate_proj" 10 | layernorm: ['.+layernorm'] 11 | fast: ['model.layers.0.mlp.down_proj'] 12 | default: ${.icml24} -------------------------------------------------------------------------------- /act/configs/model/SDXL-Lightning.yaml: -------------------------------------------------------------------------------- 1 | model_path: ByteDance/SDXL-Lightning 2 | default_batch_size: 16 3 | guidance_scale: 0 4 | inference_steps: 4 5 | default_module_names: 6 | default: ${.layernorm} 7 | layernorm: ['unet.*norm.*'] 8 | fast: ['text_encoder.text_model.encoder.layers.0.mlp:0'] -------------------------------------------------------------------------------- /act/configs/model/gemma-2-2b.yaml: -------------------------------------------------------------------------------- 1 | model_path: google/gemma-2-2b 2 | default_batch_size: 32 3 | seq_len: 128 4 | dtype: ${dtype:torch.bfloat16} 5 | default_module_names: 6 | icml24: 7 | - "model.layers.*.mlp.up_proj" 8 | - "model.layers.*.mlp.down_proj" 9 | - "model.layers.*.mlp.gate_proj" 10 | layernorm: ['.+layernorm'] 11 | post_layernorm: [".*post_attention_layernorm", ".*post_feedforward_layernorm"] 12 | fast: ['model.layers.0.mlp.down_proj'] 13 | default: ${.post_layernorm} -------------------------------------------------------------------------------- /act/configs/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | model_path: openai-community/gpt2 2 | default_batch_size: 32 3 | seq_len: 128 4 | dtype: ${dtype:torch.float32} 5 | default_module_names: 6 | layernorm: ['.+ln_.+'] 7 | fast: ['.*.0\..+ln_.+'] 8 | default: ${.layernorm} -------------------------------------------------------------------------------- /act/configs/task_params/coco_concepts.yaml: -------------------------------------------------------------------------------- 1 | model_task: "text-to-image-generation" 2 | dataset: "coco-captions-concepts" 3 | src_subsets: ["pink_elephant"] 4 | dst_subsets: ["none"] 5 | prompt_subset: ["pink_elephant"] 6 | default_evaluation: ['text-to-image-generation', 'clip_score'] -------------------------------------------------------------------------------- /act/configs/task_params/coco_styles.yaml: -------------------------------------------------------------------------------- 1 | model_task: "text-to-image-generation" 2 | dataset: "coco-captions-styles" 3 | src_subsets: ["none"] 4 | dst_subsets: ["art_nouveau"] 5 | prompt_subset: [none] 6 | default_evaluation: ['text-to-image-generation', 'clip_score'] -------------------------------------------------------------------------------- /act/configs/task_params/diffusion_prompts.yaml: 
-------------------------------------------------------------------------------- 1 | model_task: "text-to-image-generation" 2 | dataset: "diffusion-prompts" 3 | src_subsets: ["no_trees"] 4 | dst_subsets: ["trees"] 5 | prompt_subset: [none] 6 | default_evaluation: ['text-to-image-generation'] 7 | dataset_params: 8 | json_path: "data/diffusion_prompts.json" -------------------------------------------------------------------------------- /act/configs/task_params/giraffes.yaml: -------------------------------------------------------------------------------- 1 | model_task: "text-generation" 2 | dataset: "json-subsets" 3 | src_subsets: ["none"] 4 | dst_subsets: ["giraffe"] 5 | default_evaluation: ['text-generation'] 6 | dataset_params: 7 | json_path: "data/giraffes.json" -------------------------------------------------------------------------------- /act/configs/task_params/toxicity.yaml: -------------------------------------------------------------------------------- 1 | model_task: "text-generation" 2 | dataset: "jigsaw" 3 | src_subsets: ["toxic"] 4 | dst_subsets: ["non-toxic"] 5 | default_evaluation: ['text-generation', 'model_perplexity', 'mmlu', 'zero_shot', 'rtp'] -------------------------------------------------------------------------------- /act/configs/text_to_image_generation.yaml: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | hydra: 5 | job: 6 | chdir: false 7 | 8 | defaults: 9 | - _self_ 10 | #################### 11 | # Wandb config # 12 | #################### 13 | - wandb/act.yaml 14 | #################### 15 | # Intervention # 16 | #################### 17 | - intervention_params: linear_ot 18 | #################### 19 | # Diffusion # 20 | #################### 21 | # Tasks 22 | - task_params: coco_styles.yaml 23 | # Models 24 | - model: SDXL-Lightning.yaml 25 | 26 | intervention_params: null 27 | task_params: null 28 | 29 | # Data and Model Loading Settings 30 | data_dir: ${oc.env:DATA_DIR,/mnt/data} 31 | cache_dir: ${oc.env:CACHE_DIR,${data_dir}} 32 | results_dir: ${oc.env:OUTPUT_DIR,/tmp/results} 33 | 34 | # Some globals 35 | device: 'cuda' 36 | fast: false 37 | seed: 42 38 | # By default we let model config specify batch_size 39 | batch_size: ${model.default_batch_size} 40 | 41 | # Decides which scripts to run 42 | compute_responses: true 43 | compute_interventions: true 44 | 45 | # evaluation: e.g. 
['clip_score'] 46 | evaluation: ${task_params.default_evaluation} 47 | 48 | model: 49 | module_names: ${.default_module_names.default} 50 | 51 | responses: 52 | # Response Generation Settings 53 | batch_size: ${batch_size} 54 | balanced_data: true 55 | save_fields: [] # Extra fields to save with responses 56 | shuffle: true 57 | device: "${device}" 58 | dtype: ${dtype:torch.float32} 59 | max_batches: null 60 | num_workers: 1 61 | seed: ${seed} 62 | resume: true 63 | data_dir: ${data_dir} 64 | cache_dir: ${cache_dir} 65 | tag: "responses" 66 | # Params for response-saving hooks 67 | intervention_params: 68 | name: "postprocess_and_save" 69 | pooling_op: ${interventions.intervention_params.pooling_op} 70 | hook_params: 71 | raise_exception: false 72 | # see configs/task_params 73 | task_params: ${task_params} 74 | model_params: ${model} 75 | 76 | interventions: 77 | # Response Generation Settings 78 | batch_size: 2 79 | max_batches: null 80 | shuffle: true 81 | device: "cpu" 82 | dtype: ${dtype:torch.float32} 83 | load_fields: [] 84 | num_workers: 1 85 | seed: ${seed} 86 | resume: false 87 | cache_dir: ${cache_dir} 88 | tag: "interventions" 89 | # see configs/task_params 90 | task_params: ${task_params} 91 | # see configs/intervention_params 92 | intervention_params: ${intervention_params} 93 | # see configs/model 94 | model_params: ${model} 95 | 96 | text_to_image_generation: 97 | max_batches: 4 98 | num_workers: 2 99 | device: "${device}" 100 | seed: ${seed} 101 | data_dir: ${data_dir} 102 | cache_dir: ${cache_dir} 103 | # Text generation batch size. 104 | batch_size: ${batch_size} 105 | # Will sweep over these numbers 106 | diffusion_guidance_scale: [] 107 | # Base number for guidance scale 108 | guidance_scale: ${.model_params.guidance_scale} 109 | # Base number of diffusion inference steps 110 | num_inference_steps: ${.model_params.inference_steps} 111 | # Diffusion image resolution 112 | generation_resolution: 224 113 | # Where results are saved 114 | results_dir: ${results_dir} 115 | # If true, also save a gif animation 116 | create_gif: true 117 | # see configs/intervention_params 118 | intervention_params: ${intervention_params} 119 | # Strengths for which to generate images 120 | min_strength: 0.0 121 | max_strength: 1.0 122 | # Will execute np.linspace over this number of steps 123 | strength_steps: 11 124 | # If true, run script in fast mode (small batches, small data, not useful for true results) 125 | fast: ${fast} 126 | verbose: 1 127 | # See defaults on top of this file 128 | task_params: ${task_params} 129 | # see configs/model 130 | model_params: ${model} 131 | # wandb params 132 | wandb: ${wandb} 133 | # Use these prompts to generate instead of reading from a dataset 134 | prompt_override: null 135 | 136 | clip_score: 137 | input_folder: ${results_dir}/generate_with_hooks_diffusion 138 | results_dir: ${results_dir} 139 | device: ${device} 140 | # wandb params 141 | wandb: ${wandb} -------------------------------------------------------------------------------- /act/configs/wandb/act.yaml: -------------------------------------------------------------------------------- 1 | project: AcT 2 | group: null 3 | tags: 4 | - ${intervention_params.name} 5 | - ${task_params.dataset} 6 | - ${model.model_path} 7 | dir: ${cache_dir} 8 | entity: null 9 | mode: online # options: offline, online, disabled -------------------------------------------------------------------------------- /act/datasets/__init__.py:
-------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import typing as t 5 | from pathlib import Path 6 | 7 | import torch 8 | import transformers 9 | 10 | from .coco_captions import ( 11 | get_coco_captions_dataset, 12 | get_coco_concepts_dataset, 13 | get_coco_styles_dataset, 14 | ) 15 | from .jigsaw_dataset import get_jigsaw_dataset 16 | from .json_prompts import get_example_text2img_prompts_dataset 17 | from .json_subsets_dataset import get_json_subsets_dataset 18 | from .onesec_dataset import get_onesec_dataset 19 | from .samplers import StratifiedSampler 20 | 21 | DATASET_LOADERS_REGISTRY = { 22 | "jigsaw": get_jigsaw_dataset, 23 | "OneSecConcepts-100_1.5.0": get_onesec_dataset, 24 | "example-text2img-prompts": get_example_text2img_prompts_dataset, 25 | "diffusion-prompts": get_example_text2img_prompts_dataset, 26 | "coco-captions-2017": get_coco_captions_dataset, 27 | "coco-captions-styles": get_coco_styles_dataset, 28 | "coco-captions-concepts": get_coco_concepts_dataset, 29 | "json-subsets": get_json_subsets_dataset, 30 | } 31 | 32 | 33 | def get_dataset( 34 | name: str, 35 | datasets_folder: Path, 36 | split: str, 37 | subsets: t.Set[str], 38 | tokenizer: t.Optional[transformers.PreTrainedTokenizer] = None, 39 | **kwargs, 40 | ) -> t.Tuple[torch.utils.data.Dataset, t.Callable]: 41 | """Loads and returns a dataset split given its name. It also returns a collator function for the dataloader 42 | 43 | Args: 44 | name (str): dataset name 45 | datasets_folder (Path): path where dataset is located 46 | split (str): train, val, test 47 | tokenizer (t.Optional[transformers.PreTrainedTokenizer], optional): a huggingface tokenizer in case it is a text dataset. Defaults to None. 48 | 49 | Returns: 50 | t.Tuple[torch.utils.data.Dataset, t.Callable]: pytorch Dataset instance and collator function 51 | """ 52 | assert ( 53 | name in DATASET_LOADERS_REGISTRY 54 | ), f"{name} not in DATASET_LOADERS_REGISTRY ({DATASET_LOADERS_REGISTRY.keys()})" 55 | data_loader = DATASET_LOADERS_REGISTRY[name] 56 | return data_loader( 57 | datasets_folder, 58 | split=split, 59 | subsets=subsets, 60 | tokenizer=tokenizer, 61 | **kwargs, 62 | ) 63 | 64 | 65 | def get_dataloader( 66 | dataset: torch.utils.data.Dataset, 67 | batch_size: int, 68 | num_workers: int, 69 | collate_fn: t.Callable, 70 | drop_last: bool, 71 | shuffle: bool, 72 | balanced: bool = False, 73 | seed: int = 0, 74 | **kwargs: dict, 75 | ) -> torch.utils.data.DataLoader: 76 | if balanced: 77 | sampler = StratifiedSampler(dataset.subsets, seed=seed) 78 | shuffle = False 79 | else: 80 | sampler = None 81 | shuffle = shuffle 82 | return torch.utils.data.DataLoader( 83 | dataset, 84 | batch_size=batch_size, 85 | num_workers=num_workers, 86 | collate_fn=collate_fn, 87 | drop_last=drop_last, 88 | shuffle=shuffle, 89 | sampler=sampler, 90 | **kwargs, 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | pass 96 | -------------------------------------------------------------------------------- /act/datasets/collators.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
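# NOTE: a minimal usage sketch for the dataset registry API defined in
# act/datasets/__init__.py above. The "jigsaw" loader name comes from
# DATASET_LOADERS_REGISTRY; the GPT-2 tokenizer, the pad-token workaround and
# the /mnt/data root are assumptions for illustration only.
#
#   from pathlib import Path
#   from transformers import AutoTokenizer
#   from act.datasets import get_dataset, get_dataloader
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
#   dataset, collate_fn = get_dataset(
#       "jigsaw", Path("/mnt/data"), split="train",
#       subsets={"toxic", "non-toxic"}, tokenizer=tokenizer,
#   )
#   loader = get_dataloader(
#       dataset, batch_size=8, num_workers=0, collate_fn=collate_fn,
#       drop_last=False, shuffle=True, balanced=True, seed=0,
#   )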
3 | 4 | import typing as t 5 | from collections import defaultdict 6 | 7 | import torch 8 | import transformers 9 | 10 | 11 | class DictCollatorWithPadding(transformers.DataCollatorWithPadding): 12 | """Helper class that pads text sequences of multiple lengths 13 | to a contiguous batch by applying huggingface's 14 | DataCollatorWithPadding. For the rest of data types 15 | it converts them from list of dicts to dict of lists. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | tokenizer: transformers.PreTrainedTokenizer, 21 | return_tensors: str = "pt", 22 | pad_to_multiple_of: int = 2, 23 | **kwargs, 24 | ): 25 | super().__init__( 26 | tokenizer=tokenizer, 27 | return_tensors=return_tensors, 28 | pad_to_multiple_of=pad_to_multiple_of, 29 | **kwargs, 30 | ) 31 | self.tensor_set = set(["input_ids", "attention_mask"]) 32 | 33 | def __call__(self, batch: t.List) -> t.Tuple[t.Dict[str, torch.Tensor], t.Dict]: 34 | """Function to be applied on list of samples from dataloader to form a batch 35 | 36 | Args: 37 | batch (t.List): list of samples 38 | 39 | Returns: 40 | t.Tuple[t.Dict[str, torch.Tensor], t.Dict]: Tuple with tokens and additional metadata like labels 41 | """ 42 | ret = defaultdict(list) 43 | tensors = [] 44 | meta_set = set(batch[0].keys()) - self.tensor_set 45 | for sample in batch: 46 | tensors.append({k: sample[k] for k in self.tensor_set}) 47 | for k in meta_set: 48 | ret[k].append(sample[k]) 49 | tensors = transformers.DataCollatorWithPadding.__call__( 50 | self, tensors 51 | ) # huggingface does the padding 52 | ret.update(tensors) 53 | return ret 54 | -------------------------------------------------------------------------------- /act/datasets/dummy_text2image_prompts.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import json 5 | import typing as t 6 | from pathlib import Path 7 | 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset, default_collate 10 | 11 | 12 | class DummyText2ImagePrompts(Dataset): 13 | def __init__(self, file_path): 14 | super().__init__() 15 | 16 | with open(file_path, "r") as f: 17 | data = json.load(f) 18 | 19 | self.data = [] 20 | self.subsets = [] 21 | for subset in data: 22 | for i, string in enumerate(data[subset]): 23 | self.data.append( 24 | {"id": len(self.data), "prompt": string, "subset": subset} 25 | ) 26 | self.subsets.append(subset) 27 | 28 | def get_all_subsets(self): 29 | return set(self.subsets) 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, idx): 35 | return self.data[idx] 36 | 37 | 38 | def get_dummy_text2image_prompts( 39 | *args, 40 | **kwargs, 41 | ) -> torch.utils.data.Dataset: 42 | return ( 43 | DummyText2ImagePrompts(Path("data/dummy_text2img_prompts.json")), 44 | default_collate, 45 | ) 46 | -------------------------------------------------------------------------------- /act/datasets/jigsaw_dataset.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
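# NOTE: a hedged sketch of how DictCollatorWithPadding (defined in
# act/datasets/collators.py above) behaves: fields in tensor_set
# ("input_ids", "attention_mask") are padded into stacked tensors, every other
# field is gathered into a list. The GPT-2 tokenizer and the sample texts are
# assumptions for illustration only.
#
#   from transformers import AutoTokenizer
#   from act.datasets.collators import DictCollatorWithPadding
#
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token = tok.eos_token
#   collate = DictCollatorWithPadding(tokenizer=tok)
#   batch = collate([
#       {"subset": "toxic", **tok("a rude comment", truncation=True)},
#       {"subset": "non-toxic", **tok("a kind and thoughtful comment", truncation=True)},
#   ])
#   # batch["input_ids"]: LongTensor of shape (2, padded_len)
#   # batch["subset"]:    ["toxic", "non-toxic"]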
3 | 4 | import typing as t 5 | from collections import OrderedDict 6 | from pathlib import Path 7 | 8 | import pandas as pd 9 | import torch 10 | import transformers 11 | 12 | from act.datasets.collators import DictCollatorWithPadding 13 | 14 | 15 | class JigsawDataset(torch.utils.data.Dataset): 16 | """ 17 | Implements a loader for the Jigsaw toxicity dataset. 18 | To get the files download from the following URL into `path`: 19 | https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data 20 | """ 21 | 22 | SUPER_TOXIC_FLAG = "all" 23 | SUPER_NONTOXIC_FLAG = "non-toxic" 24 | JIGSAW_CATEGORIES = [ 25 | "toxic", 26 | "severe_toxic", 27 | "obscene", 28 | "threat", 29 | "insult", 30 | "identity_hate", 31 | ] 32 | 33 | def __init__( 34 | self, 35 | path: Path, 36 | split: str, 37 | subsets: t.Set[str], 38 | tokenizer: transformers.PreTrainedTokenizer, 39 | ) -> torch.utils.data.Dataset: 40 | self.split = split 41 | self.path = path 42 | self.target_subsets = set(subsets) 43 | # Some magic, if we select "all" it appends all the jigsaw categories. 44 | self.using_all = False 45 | if "all" in self.target_subsets: 46 | self.target_subsets.remove("all") 47 | self.target_subsets.update(set(self.JIGSAW_CATEGORIES)) 48 | self.using_all = True 49 | 50 | assert self.target_subsets.issubset(set(self.JIGSAW_CATEGORIES + ["non-toxic"])) 51 | self.tokenizer = tokenizer 52 | 53 | if self.split == "train": 54 | train_data = pd.read_csv(path / "train.csv", index_col="id") 55 | self.data = self._preprocess(train_data) 56 | elif self.split == "test": 57 | test_data = pd.read_csv(path / "test.csv", index_col="id") 58 | test_labels = pd.read_csv(path / "test_labels.csv", index_col="id") 59 | test_dataset = pd.concat( 60 | [test_data, test_labels], axis=1, ignore_index=False 61 | ) 62 | # test dataset comes with unannotated data (label=-1) 63 | test_dataset = test_dataset.loc[ 64 | (test_dataset[test_dataset.columns[1:]] > -1).all(axis=1) 65 | ] 66 | self.data = self._preprocess(test_dataset) 67 | _ = self.data[0] # small test 68 | 69 | def get_label(elem): 70 | """ 71 | Returns the (binary) label string of a given Jigsaw datapoint. 72 | 73 | If at least one of the categories in self.target_subsets is satisfied, return the "toxic" class. 74 | If no category is satisfied and we require non-toxic sentences, return the "non-toxic" class. 75 | If some category is satisfied NOT in self.target_subsets, returns None (meaning we should skip this datapoint). 
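For example, with target_subsets = {"toxic", "non-toxic"}: a row flagged only
"toxic" returns "toxic"; a row with every category flag at 0 returns
"non-toxic"; and a row flagged only "obscene" returns None, so it is dropped.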
76 | """ 77 | is_other = 0 78 | for categ in self.JIGSAW_CATEGORIES: 79 | if elem[categ] == 1 and categ in self.target_subsets: 80 | return self.SUPER_TOXIC_FLAG if self.using_all else categ 81 | if elem[categ] == 1 and categ not in self.target_subsets: 82 | is_other = True 83 | return ( 84 | self.SUPER_NONTOXIC_FLAG 85 | if (not is_other and self.SUPER_NONTOXIC_FLAG in self.target_subsets) 86 | else None 87 | ) 88 | 89 | self.subsets = [get_label(d) for d in self.data] 90 | self.data = [d for s, d in zip(self.subsets, self.data) if s is not None] 91 | self.subsets = [s for s in self.subsets if s is not None] 92 | 93 | def _preprocess(self, df: pd.DataFrame): 94 | return df.reset_index().to_dict("records") 95 | 96 | def __getitem__(self, item) -> t.Dict: 97 | datum: t.Dict = self.data[item] 98 | tokens = self.tokenizer(datum["comment_text"], truncation=True, padding=False) 99 | datum.update(tokens) 100 | datum["subset"] = self.subsets[item] 101 | return datum 102 | 103 | def __len__(self) -> int: 104 | return len(self.data) 105 | 106 | 107 | def get_jigsaw_dataset( 108 | path: Path, 109 | split: str, 110 | subsets: t.Set[str], 111 | tokenizer: transformers.PreTrainedTokenizer, 112 | **kwargs 113 | ) -> torch.utils.data.Dataset: 114 | return JigsawDataset( 115 | Path(path) / "jigsaw", split, subsets=subsets, tokenizer=tokenizer 116 | ), DictCollatorWithPadding(tokenizer) 117 | -------------------------------------------------------------------------------- /act/datasets/json_prompts.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import json 5 | import typing as t 6 | from fnmatch import fnmatch 7 | from pathlib import Path 8 | 9 | import torch 10 | from torch.utils.data import DataLoader, Dataset, default_collate 11 | 12 | 13 | class JsonPromptsDataset(Dataset): 14 | """ 15 | A PyTorch Dataset that loads data from a JSON file. 16 | 17 | This dataset loads a list of prompts from a JSON file, and allows for filtering the subsets to include based on wildcard patterns. 18 | Each item in the dataset is a dictionary containing 'id', 'prompt' and 'subset' fields. The 'id' is an automatically generated index, 19 | 'prompt' is the prompt string loaded from the JSON file, and 'subset' is the subset name from which this prompt belongs to. 20 | 21 | Args: 22 | file_path (str): Path to the JSON file containing the prompts data. 23 | subsets (Optional[Union[List[str], str]]): A list of wildcard patterns or a single pattern to include only certain subsets of the data. 24 | If None, all subsets are included. Default is None. 25 | 26 | Attributes: 27 | data (List[Dict]): List of dictionaries where each dictionary represents an item in the dataset with keys 'id', 'prompt' and 'subset'. 28 | subsets (List[str]): List of subset names that were loaded from the JSON file. 29 | 30 | Methods: 31 | get_all_subsets(): Returns a set of all unique subsets. 32 | __len__(): Returns the number of items in the dataset. 33 | __getitem__(idx): Retrieves an item by its index, returns it as a dictionary with keys 'id', 'prompt' and 'subset'. 
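Example (illustrative; assumes the JSON maps subset names to lists of prompt
strings and contains a "trees" subset, as in data/diffusion_prompts.json):

    >>> ds = JsonPromptsDataset("data/diffusion_prompts.json", subsets=["trees"])
    >>> ds[0]  # doctest: +SKIP
    {'id': 0, 'original_prompt': '...', 'prompt': '...', 'subset': 'trees'}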
34 | """ 35 | 36 | def __init__(self, file_path, subsets: t.Optional[t.List[str]] = None): 37 | super().__init__() 38 | if subsets is None: 39 | subsets = "*" 40 | if not isinstance(subsets, (list, tuple, set)): 41 | subsets = [subsets] 42 | with open(file_path, "r") as f: 43 | data = json.load(f) 44 | 45 | self.data = [] 46 | self.subsets = [] 47 | for subset in data: 48 | if not any(map(lambda pattern: fnmatch(subset, pattern), subsets)): 49 | continue 50 | for i, string in enumerate(data[subset]): 51 | self.data.append( 52 | { 53 | "id": len(self.data), 54 | "original_prompt": string, 55 | "prompt": string, 56 | "subset": subset, 57 | } 58 | ) 59 | self.subsets.append(subset) 60 | 61 | def get_all_subsets(self): 62 | return set(self.subsets) 63 | 64 | def __len__(self): 65 | return len(self.data) 66 | 67 | def __getitem__(self, idx): 68 | return self.data[idx] 69 | 70 | 71 | def get_example_text2img_prompts_dataset( 72 | root: Path, 73 | *args, 74 | subsets: t.List = None, 75 | json_path: Path = None, 76 | **kwargs, 77 | ) -> torch.utils.data.Dataset: 78 | return ( 79 | JsonPromptsDataset(json_path, subsets=subsets), 80 | default_collate, 81 | ) 82 | -------------------------------------------------------------------------------- /act/datasets/json_subsets_dataset.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import json 5 | import typing as t 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import transformers 11 | from torch.utils.data import DataLoader, Dataset, default_collate 12 | 13 | from act.datasets.collators import DictCollatorWithPadding 14 | 15 | 16 | class JsonSubsetsDataset(Dataset): 17 | """ 18 | A dataset class for loading and handling text data from a JSON file, 19 | where the data is organized into subsets. Each subset corresponds to 20 | a key in the JSON file, and the associated value is a list of sentences 21 | or text entries. 22 | 23 | Attributes: 24 | data (list): A list containing all the sentences from the specified subsets. 25 | subsets (list): A list indicating the subset each sentence belongs to. 26 | 27 | Methods: 28 | get_all_subsets(): 29 | Returns a set of all unique subsets present in the dataset. 30 | __len__(): 31 | Returns the total number of text entries in the dataset. 32 | __getitem__(idx): 33 | Retrieves a dictionary containing the ID, subset, and text entry 34 | for the specified index. 35 | 36 | Args: 37 | json_path (str): Path to the JSON file containing the dataset. 38 | subsets (list, optional): A list of keys corresponding to the subsets 39 | to be included in the dataset. If not specified or set to None, 40 | all subsets will be included. Defaults to "*" (all subsets).
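Example (illustrative; assumes data/giraffes.json contains a "giraffe" key,
as referenced by act/configs/task_params/giraffes.yaml):

    >>> from transformers import AutoTokenizer
    >>> tok = AutoTokenizer.from_pretrained("gpt2")
    >>> ds = JsonSubsetsDataset("data/giraffes.json", subsets=["giraffe"], tokenizer=tok)
    >>> ds[0]["subset"]  # doctest: +SKIP
    'giraffe'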
41 | """ 42 | 43 | def __init__( 44 | self, 45 | json_path: Path, 46 | subsets: t.List[str], 47 | tokenizer: transformers.PreTrainedTokenizer, 48 | **kwargs, 49 | ): 50 | super().__init__() 51 | if subsets is None: 52 | subsets = "*" 53 | if not isinstance(subsets, (list, tuple, set)): 54 | subsets = [subsets] 55 | with open(json_path, "r") as f: 56 | data: t.Dict = json.load(f) 57 | 58 | self.tokenizer = tokenizer 59 | self.data = [] 60 | self.subsets = [] 61 | self.idx_in_subset = [] 62 | for key, sentences in data.items(): 63 | if subsets == "*" or key in subsets: 64 | subset = key 65 | else: 66 | continue 67 | self.data += sentences 68 | self.subsets += [subset] * len(sentences) 69 | self.idx_in_subset += np.arange(len(sentences)).tolist() 70 | 71 | def get_all_subsets(self): 72 | return set(self.subsets) 73 | 74 | def __len__(self): 75 | return len(self.data) 76 | 77 | def __getitem__(self, idx): 78 | datum = { 79 | "id": idx, 80 | "subset": self.subsets[idx], 81 | "text": self.data[idx], 82 | "idx_in_subset": self.idx_in_subset[idx], 83 | } 84 | tokens = self.tokenizer(datum["text"], truncation=True, padding=False) 85 | datum.update(tokens) 86 | return datum 87 | 88 | 89 | def get_json_subsets_dataset( 90 | *args, 91 | # file_path=Path("data/giraffe_eagle_situations.json"), 92 | subsets=None, 93 | tokenizer=None, 94 | **kwargs, 95 | ) -> torch.utils.data.Dataset: 96 | assert tokenizer is not None, "Must pass a tokenizer" 97 | return ( 98 | JsonSubsetsDataset(subsets=subsets, tokenizer=tokenizer, **kwargs), 99 | DictCollatorWithPadding(tokenizer), 100 | ) 101 | -------------------------------------------------------------------------------- /act/datasets/onesec_dataset.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
3 | 4 | import typing as t 5 | from collections import OrderedDict 6 | from pathlib import Path 7 | 8 | import pandas as pd 9 | import torch 10 | import transformers 11 | 12 | from act.utils import utils 13 | 14 | from .collators import DictCollatorWithPadding 15 | 16 | 17 | class OneSecDataset(torch.utils.data.Dataset): 18 | LABEL_MAP = OrderedDict([("negative", 0), ("positive", 1)]) 19 | LABEL_NAMES = ["negative", "positive"] 20 | 21 | def __init__( 22 | self, 23 | path: Path, 24 | split: str, 25 | subsets: t.List[str], 26 | tokenizer: transformers.PreTrainedTokenizer, 27 | ) -> torch.utils.data.Dataset: 28 | assert split == "train", "Only train split is supported for now" 29 | self.path = path 30 | self.split = split 31 | self.tokenizer = tokenizer 32 | all_concepts = pd.read_csv(self.path / "concept_list.csv") 33 | all_concepts = list(sorted(all_concepts["concept"].values)) 34 | self.concepts = set([s.replace("non-", "") for s in subsets]) 35 | assert self.concepts.issubset(set(all_concepts)) 36 | self.target_subsets = set(subsets) 37 | self.data = [] 38 | self.subsets = [] 39 | self.ids = [] 40 | for concept in self.concepts: 41 | self.data_path = self.path / "sense" / f"{concept}.json" 42 | data = utils.load_json(self.data_path) 43 | if f"non-{concept}" in self.target_subsets: 44 | self.data += data["sentences"]["negative"] 45 | self.subsets.extend( 46 | [f"non-{concept}"] * len(data["sentences"]["negative"]) 47 | ) 48 | self.ids.extend(list(range(len(data["sentences"]["negative"])))) 49 | if concept in self.target_subsets: 50 | self.data += data["sentences"]["positive"] 51 | self.subsets.extend([concept] * len(data["sentences"]["positive"])) 52 | self.ids.extend(list(range(len(data["sentences"]["positive"])))) 53 | 54 | def __getitem__(self, item) -> dict: 55 | datum = { 56 | "text": self.data[item], 57 | "subset": self.subsets[item], 58 | "id": f"{self.ids[item]:04d}", 59 | } 60 | tokens = self.tokenizer(datum["text"], truncation=True, padding=False) 61 | datum.update(tokens) 62 | return datum 63 | 64 | def __len__(self) -> int: 65 | return len(self.data) 66 | 67 | 68 | def get_onesec_dataset( 69 | path: Path, 70 | split: str, 71 | subsets: t.List[str], 72 | tokenizer: transformers.PreTrainedTokenizer, 73 | **kwargs, 74 | ) -> OneSecDataset: 75 | return OneSecDataset( 76 | Path(path) / "OneSecConcepts-100_1.5.0", 77 | split="train", 78 | subsets=subsets, 79 | tokenizer=tokenizer, 80 | ), DictCollatorWithPadding(tokenizer=tokenizer) 81 | -------------------------------------------------------------------------------- /act/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
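# NOTE: a hedged sketch of how the StratifiedSampler defined below behaves —
# it yields an equal number of indices per label, interleaved across labels,
# which is why act/datasets/__init__.py pairs it with shuffle=False. The toy
# labels are assumptions for illustration only.
#
#   labels = ["toxic", "non-toxic", "toxic", "non-toxic", "toxic"]
#   sampler = StratifiedSampler(labels, seed=0)
#   len(sampler)         # 4: 2 labels * min per-label count (2)
#   list(iter(sampler))  # balanced order, alternating between the two labels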
3 | 4 | import typing as t 5 | 6 | import torch 7 | 8 | 9 | class StratifiedSampler(torch.utils.data.Sampler): 10 | """Stratified Sampling 11 | 12 | Provides equal representation of target classes in each batch 13 | """ 14 | 15 | def __init__(self, labels: t.List[str], seed: int): 16 | """ 17 | Arguments 18 | --------- 19 | labels : list of str 20 | subset/class label for each dataset element 21 | seed : int 22 | seed for the sampling random generator 23 | """ 24 | label_map = {l: i for i, l in enumerate(set(labels))} 25 | self.class_vector = torch.tensor([label_map[l] for l in labels]) 26 | self.idx = torch.arange(len(self.class_vector)) 27 | self.idx_per_label = {} 28 | uniques, counts = torch.unique(self.class_vector, return_counts=True) 29 | for label in uniques: 30 | self.idx_per_label[label] = self.idx[self.class_vector == label] 31 | self.min_count = torch.min(counts) 32 | self.seed = seed 33 | self.set_epoch(0) 34 | 35 | def set_epoch(self, epoch: int) -> None: 36 | """Puts RNG to the correct epoch 37 | 38 | Args: 39 | epoch (int): epoch to put the generator in 40 | """ 41 | self.epoch = epoch 42 | self.rng = torch.Generator() 43 | self.rng.manual_seed(self.seed) 44 | for _ in range(epoch): 45 | iter(self) 46 | 47 | def __iter__(self): 48 | indices = [] 49 | tail = [] 50 | for label, idx in self.idx_per_label.items(): 51 | indices.append( 52 | idx[torch.randperm(len(idx), generator=self.rng)[: self.min_count]] 53 | ) 54 | indices = torch.stack(indices, 1).ravel() 55 | return iter(indices) 56 | 57 | def __len__(self): 58 | return self.min_count * len(self.idx_per_label) 59 | -------------------------------------------------------------------------------- /act/evaluations/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /act/evaluations/calculate_clip_score.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
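# NOTE: the directional similarity computed by DirectionalSimilarity below
# follows the diffusers evaluation guide cited below — it compares the edit
# direction in image space with the edit direction in text space, using
# L2-normalized CLIP embeddings:
#
#   sim_direction = cos( E_img(edited image)   - E_img(original image),
#                        E_txt(edited caption) - E_txt(original caption) )
#
# A score near 1 means the image moved in the direction the caption asked for.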
3 | 4 | import argparse 5 | import json 6 | import logging 7 | from collections import defaultdict 8 | from pathlib import Path 9 | 10 | import hydra 11 | import pandas as pd 12 | import torch 13 | from omegaconf import DictConfig, OmegaConf 14 | from PIL import Image 15 | from transformers import ( 16 | CLIPImageProcessor, 17 | CLIPTextModelWithProjection, 18 | CLIPTokenizer, 19 | CLIPVisionModelWithProjection, 20 | ) 21 | 22 | from act.utils import utils 23 | 24 | # Set up logging 25 | logging.basicConfig( 26 | level=logging.INFO, format="%(asctime)s - %(name)-12s %(levelname)-8s %(message)s" 27 | ) 28 | logger = logging.getLogger(__name__) 29 | 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 | 33 | 34 | # https://huggingface.co/docs/diffusers/en/conceptual/evaluation 35 | class DirectionalSimilarity(nn.Module): 36 | def __init__(self, device="cuda"): 37 | super().__init__() 38 | clip_id = "openai/clip-vit-large-patch14" 39 | self.tokenizer = CLIPTokenizer.from_pretrained(clip_id) 40 | self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to( 41 | device 42 | ) 43 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to( 44 | device 45 | ) 46 | self.image_processor = CLIPImageProcessor.from_pretrained(clip_id) 47 | 48 | self.device = device 49 | 50 | def preprocess_image(self, image): 51 | image = self.image_processor(image, return_tensors="pt")["pixel_values"] 52 | return {"pixel_values": image.to(self.device)} 53 | 54 | def tokenize_text(self, text): 55 | inputs = self.tokenizer( 56 | text, 57 | max_length=self.tokenizer.model_max_length, 58 | padding="max_length", 59 | truncation=True, 60 | return_tensors="pt", 61 | ) 62 | return { 63 | "input_ids": inputs.input_ids.to(self.device), 64 | "attention_mask": inputs.attention_mask.to(self.device), 65 | } 66 | 67 | def encode_image(self, image): 68 | preprocessed_image = self.preprocess_image(image) 69 | image_features = self.image_encoder(**preprocessed_image).image_embeds 70 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 71 | return image_features 72 | 73 | def encode_text(self, text): 74 | tokenized_text = self.tokenize_text(text) 75 | text_features = self.text_encoder(**tokenized_text).text_embeds 76 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 77 | return text_features 78 | 79 | def compute_directional_similarity( 80 | self, img_feat_one, img_feat_two, text_feat_one, text_feat_two 81 | ): 82 | sim_direction = F.cosine_similarity( 83 | img_feat_two - img_feat_one, text_feat_two - text_feat_one 84 | ) 85 | return sim_direction 86 | 87 | @torch.inference_mode() 88 | def forward( 89 | self, 90 | image_one, 91 | image_two, 92 | caption_one, 93 | caption_two, 94 | caption_zero_shot_one, 95 | caption_zero_shot_two, 96 | ): 97 | img_feat_one = self.encode_image(image_one) 98 | img_feat_two = self.encode_image(image_two) 99 | text_feat_one = self.encode_text(caption_one) 100 | text_feat_two = self.encode_text(caption_two) 101 | text_feat_zero_shot_one = self.encode_text(caption_zero_shot_one) 102 | text_feat_zero_shot_two = self.encode_text(caption_zero_shot_two) 103 | text_similarity = ( 104 | F.cosine_similarity(text_feat_one, text_feat_two).detach().cpu().numpy() 105 | ) 106 | image_similarity = ( 107 | F.cosine_similarity(img_feat_one, img_feat_two).detach().cpu().numpy() 108 | ) 109 | conditional_similarity = ( 110 | F.cosine_similarity(img_feat_two, text_feat_two).detach().cpu().numpy() 111 | ) 112 | 
unconditional_similarity = (
113 |             F.cosine_similarity(img_feat_two, text_feat_one).detach().cpu().numpy()
114 |         )
115 |         directional_similarity = (
116 |             self.compute_directional_similarity(
117 |                 img_feat_one, img_feat_two, text_feat_one, text_feat_two
118 |             )
119 |             .detach()
120 |             .cpu()
121 |             .numpy()
122 |         )
123 |         unconditional_zero_shot_similarity = F.cosine_similarity(
124 |             img_feat_two, text_feat_zero_shot_one
125 |         )
126 |         conditional_zero_shot_similarity = F.cosine_similarity(
127 |             img_feat_two, text_feat_zero_shot_two
128 |         )
129 |         zero_shot_score = (
130 |             F.softmax(
131 |                 torch.stack(
132 |                     [
133 |                         unconditional_zero_shot_similarity,
134 |                         conditional_zero_shot_similarity,
135 |                     ],
136 |                     dim=1,
137 |                 ),
138 |                 dim=1,
139 |             )
140 |             .detach()
141 |             .cpu()
142 |             .numpy()[:, 1]
143 |         )
144 |         return {
145 |             "text_similarity": text_similarity,
146 |             "image_similarity": image_similarity,
147 |             "conditional_similarity": conditional_similarity,
148 |             "unconditional_similarity": unconditional_similarity,
149 |             "directional_similarity": directional_similarity,
150 |             "conditional_zero_shot_score": zero_shot_score,
151 |         }
152 | 
153 | 
154 | def calculate_clip_score(cfg: DictConfig) -> pd.DataFrame:
155 |     """
156 |     Calculates CLIP-based similarity scores for generated images, using the prompts
157 |     stored in the JSON metadata files saved alongside each image.
158 | 
159 |     Args:
160 |         cfg (DictConfig): Hydra configuration with the input folder, device,
161 |             results directory, and wandb settings.
162 |     """
163 |     meta_dict = defaultdict(list)
164 |     logger.info(f"Processing directory: {cfg.input_folder}")
165 |     for img_path in sorted(Path(cfg.input_folder).glob("**/*.png")):
166 |         # images += [Image.open(img_path)]
167 |         with (Path(img_path).with_suffix(".json")).open("r") as fp:
168 |             meta = json.load(fp)
169 |         for k, v in meta.items():
170 |             if isinstance(v, list):
171 |                 meta_dict[k].extend(v)
172 |             else:
173 |                 meta_dict[k].append(v)
174 |     df = pd.DataFrame(meta_dict)
175 |     assert len(df) > 0, "No images found in input folder."
176 |     similarity = DirectionalSimilarity(cfg.device)
177 |     results = []
178 |     for id in df["id"].unique():
179 |         df_id = df[df["id"] == id].copy()  # copy so the score columns can be added safely
180 |         unconditional_image_data = df_id[df_id["strength"] == 0]
181 |         unconditional_image = Image.open(unconditional_image_data["image_path"].iloc[0])
182 |         unconditional_prompt = [unconditional_image_data["original_prompt"].iloc[0]]
183 |         conditional_images = []
184 |         conditional_prompts = []
185 |         conditional_zero_shot_prompt = []
186 |         unconditional_zero_shot_prompt = []
187 |         for idx, row in df_id.iterrows():
188 |             condition = (
189 |                 row["src_subsets"]
190 |                 if "none" in row["dst_subsets"]
191 |                 else row["dst_subsets"]
192 |             )
193 |             conditional_images += [Image.open(row["image_path"])]
194 |             conditional_prompts += [row["conditional_prompt"]]
195 |             conditional_zero_shot_prompt += [
196 |                 f"A picture of {condition.replace('_', ' ').replace('-', ' ')}."
197 |             ]
198 |             unconditional_zero_shot_prompt += ["A picture of something."]
199 |         clip_score = similarity.forward(
200 |             unconditional_image,
201 |             conditional_images,
202 |             unconditional_prompt,
203 |             conditional_prompts,
204 |             unconditional_zero_shot_prompt,
205 |             conditional_zero_shot_prompt,
206 |         )
207 |         for k, v in clip_score.items():
208 |             df_id[k] = v
209 |         results += [df_id]
210 |     results = pd.concat(results)
211 |     if cfg.results_dir is not None:
212 |         output_path = Path(Path(__file__).stem)
213 |         output_path = Path(cfg.results_dir, output_path)
214 |         output_path.mkdir(exist_ok=True, parents=True)
215 |     else:
216 |         output_path = None
217 |     if output_path is not None:
218 |         results.to_csv(Path(output_path) / "clip_score.csv")
219 |     if cfg.wandb.mode != "disabled":
220 |         utils.log_wandb(clip_score=results)
221 |     return results
222 | 
223 | 
224 | @hydra.main(config_path="../act/configs", config_name="config", version_base="1.3")
225 | def main(cfg: DictConfig) -> None:
226 |     calculate_clip_score(cfg)
227 | 
228 | 
229 | if __name__ == "__main__":
230 |     main()
231 | 
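For reference, `DirectionalSimilarity` can also be driven on its own, outside `calculate_clip_score`. A minimal sketch for a single before/after pair; the file names and prompts are placeholders, and `device` should match your hardware:

```python
from PIL import Image

from act.evaluations.calculate_clip_score import DirectionalSimilarity

sim = DirectionalSimilarity(device="cuda")
img_before = Image.open("outputs/strength_0.0.png")  # hypothetical paths
img_after = Image.open("outputs/strength_1.0.png")

scores = sim.forward(
    img_before,
    [img_after],
    ["A photo of a city street."],          # original prompt
    ["A photo of a city street, anime."],   # conditioned prompt
    ["A picture of something."],            # unconditional zero-shot prompt
    ["A picture of anime."],                # conditional zero-shot prompt
)
print(scores["directional_similarity"])     # alignment of the image change with the text change
```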
-------------------------------------------------------------------------------- /act/evaluations/evaluate_eleuther.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | import logging
5 | import pickle
6 | import typing as t
7 | from pathlib import Path
8 | 
9 | import hydra
10 | from lm_eval import evaluator
11 | from lm_eval.models.huggingface import HFLM
12 | from lm_eval.utils import make_table
13 | from omegaconf import DictConfig, OmegaConf
14 | 
15 | # Local imports
16 | from act.models import get_model
17 | from act.models.model_with_hooks import ModelWithHooks
18 | from act.utils import utils
19 | 
20 | logger = logging.getLogger(__name__)
21 | logger.setLevel(logging.INFO)
22 | 
23 | 
24 | def save_results(results: t.Dict, output_path: str) -> None:
25 |     """
26 |     Saves the results dictionary as a pickle named `eleuther.pkl` inside `output_path` (which must already exist).
27 |     Args:
28 |         results: nested dictionary containing the results to be saved
29 |         output_path: directory the pickle file is written to
30 | 
31 |     Returns: None
32 | 
33 |     """
34 |     out_file = Path(output_path) / "eleuther.pkl"
35 |     logger.info(f"Saving eleuther eval results to {out_file}")
36 |     with out_file.open("wb") as fp:
37 |         pickle.dump(results, fp)
38 | 
39 | 
40 | def evaluate(cfg: DictConfig) -> dict:
41 |     """
42 |     Runs the lm-eval-harness on a model wrapped with the configured intervention hooks.
43 | 
44 |     Args:
45 |         cfg: Hydra config with model, intervention, and evaluation parameters
46 | 
47 |     Returns: a results object which is the standard output of the lm-eval-harness
48 |     """
49 |     module, tokenizer = get_model(
50 |         cfg.model_params.model_path,
51 |         cfg.cache_dir,
52 |         cfg.dtype,
53 |         cfg.device,
54 |         model_task="text-generation",
55 |         seq_len=128,
56 |     )
57 | 
58 |     assert cfg.model_params.module_names is not None, (
59 |         f"Intervention specified as {cfg.intervention_params.name}, but no module names passed (passed {cfg.model_params.module_names})"
60 |     )
61 | 
62 |     # Create hooked model
63 |     model_with_hooks = ModelWithHooks(
64 |         module=module,
65 |     )
66 |     model_with_hooks.load_hooks_from_folder(
67 |         folder=Path(cfg.intervention_params.state_path),
68 |         module_names=cfg.model_params.module_names,
69 |         hook_type=cfg.intervention_params.name,
70 |         **cfg.intervention_params.hook_params,
71 |     )
72 |     model_with_hooks.register_hooks()
73 |     model = model_with_hooks.module
74 | 
75 |     # Convert into an "evaluable" HF model
76 |     lm = HFLM(model, tokenizer=tokenizer)
77 | 
78 |     # Run evaluation
79 |     results = evaluator.simple_evaluate(
80 |         lm,
81 |         tasks=list(cfg.tasks),
82 |         num_fewshot=cfg.num_fewshot,
83 |         limit=cfg.limit,  # can set a limit for quicker testing
84 |         bootstrap_iters=cfg.bootstrap_iters,  # for statistical significance estimation
85 |         random_seed=cfg.rs,
86 |         numpy_random_seed=cfg.nrs,
87 |         torch_random_seed=cfg.trs,
88 |         device=cfg.device,
89 |         batch_size=cfg.batch_size,
90 |         cache_requests=True,
91 |     )
92 | 
93 |     return results
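The hook-free version of this pattern is handy as a sanity check: wrap any Hugging Face model in `HFLM` and call `simple_evaluate` directly. A minimal sketch (model, task, and limit chosen arbitrarily):

```python
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

lm = HFLM(pretrained="gpt2", batch_size=8)
results = evaluator.simple_evaluate(lm, tasks=["hellaswag"], limit=50)
print(results["results"]["hellaswag"])  # per-task metrics (accuracy, stderr, ...)
```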
94 | 
95 | 
96 | def run_eleuther_eval(cfg: DictConfig) -> None:
97 |     # Run actual evaluations
98 |     results = evaluate(cfg)
99 | 
100 |     # Save complete results dict as pickle (skipped when no results_dir is set)
101 |     if cfg.results_dir is not None:
102 |         output_path = Path(Path(__file__).stem)
103 |         output_path = Path(cfg.results_dir) / output_path
104 |         output_path.mkdir(parents=True, exist_ok=True)
105 |         save_results(results, output_path)
106 |     if cfg.wandb.mode != "disabled":
107 |         overall_results: t.Dict = results["results"]["mmlu"]
108 |         overall_results = {"mmlu-" + k: v for k, v in overall_results.items()}
109 |         utils.log_wandb(overall_results)
110 | 
111 |     # Pop generated samples for simplified console printing and separate wandb logging
112 |     # TODO: Save samples?
113 |     samples = results.pop("samples")
114 | 
115 |     # Console printing of summarized results
116 |     logger.info(
117 |         f"{cfg.model_params.model_path}, limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}"
118 |     )
119 |     logger.info(make_table(results))
120 |     if "groups" in results:
121 |         logger.info(make_table(results, "groups"))
122 | 
123 | 
124 | @hydra.main(
125 |     config_path="../act/configs", config_name="text_generation", version_base="1.3"
126 | )
127 | def main(cfg: DictConfig) -> None:
128 |     run_eleuther_eval(cfg)
129 | 
130 | 
131 | if __name__ == "__main__":
132 |     main()
133 | 
-------------------------------------------------------------------------------- /act/evaluations/evaluate_perplexity.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | import logging
5 | import os
6 | import typing as t
7 | from pathlib import Path
8 | from sys import platform
9 | 
10 | import hydra
11 | import numpy as np
12 | import pandas as pd
13 | import torch
14 | from omegaconf import DictConfig, OmegaConf
15 | 
16 | from act.models import get_model
17 | from act.utils import utils
18 | from act.utils.perplexity import measure_perplexity
19 | 
20 | if platform == "darwin":
21 |     # Xavi: MacOS, remove dynamo errors, or test_perplexity will fail.
22 | # Xavi: This happened for transformers==4.44.2 23 | import torch._dynamo 24 | 25 | torch._dynamo.config.suppress_errors = True 26 | 27 | 28 | logging.getLogger().setLevel(logging.INFO) 29 | 30 | # Already run in parallel inside DataLoader 31 | os.environ["TOKENIZERS_PARALLELISM"] = "False" 32 | 33 | 34 | @torch.inference_mode() 35 | def evaluate(cfg: DictConfig) -> t.Dict[str, float]: 36 | if cfg.results_dir is not None: 37 | output_path = Path(Path(__file__).stem) 38 | output_path = Path(cfg.results_dir, output_path) 39 | output_path.mkdir(parents=True, exist_ok=True) 40 | else: 41 | output_path = None 42 | 43 | # Set random seed 44 | if cfg.seed: 45 | np.random.seed(cfg.seed) 46 | torch.manual_seed(cfg.seed) 47 | 48 | # Setup device and distributed learning 49 | if cfg.device in ["cuda", None] and torch.cuda.is_available(): 50 | cfg.device = "cuda" 51 | elif cfg.device == "cuda": 52 | raise (RuntimeError("Cuda not available")) 53 | elif cfg.device is None: 54 | cfg.device = "cpu" 55 | 56 | # Models and Tokenizers 57 | module, tokenizer = get_model( 58 | model_path=cfg.perplexity_model_path, 59 | cache_dir=cfg.data_dir, 60 | device=cfg.device, 61 | dtype=cfg.dtype, 62 | rand_weights=False, 63 | seq_len=cfg.seq_len, 64 | model_task="text-generation", 65 | ) 66 | # module = torch.compile(module) 67 | 68 | # Trying , and ; as delimiters. 69 | try: 70 | df = pd.read_csv(cfg.data_path, index_col=0) 71 | except: 72 | try: 73 | df = pd.read_csv(cfg.data_path, delimiter=";", index_col=0) 74 | except Exception as exc: 75 | raise RuntimeError(exc) 76 | 77 | sentences = df[cfg.column_sentences[0]].values.tolist() 78 | sentences = [s.replace(" ", "") for s in sentences] 79 | if len(cfg.column_sentences) > 1: 80 | prompts = df[cfg.column_sentences[1]].values.tolist() 81 | prompts = [s.replace(" ", "").strip() for s in prompts] 82 | else: 83 | prompts = None 84 | 85 | logging.info( 86 | f"Computing PPL with {cfg.perplexity_model_path} on {len(sentences)} sentences." 87 | ) 88 | 89 | ppl = measure_perplexity( 90 | continuations=sentences, 91 | prompts=prompts, 92 | model=module, 93 | tokenizer=tokenizer, 94 | device=cfg.device, 95 | batch_size=cfg.batch_size, 96 | autoregressive=( 97 | cfg.intervention_params.hook_params.intervention_position == "last" 98 | ), 99 | ) 100 | 101 | # Add PPL column! 102 | model_name = Path(cfg.perplexity_model_path).name 103 | col = f"ppl_{model_name}" 104 | if col in df.columns: 105 | col = col + "-v2" 106 | df[col] = ppl 107 | logging.info(f"Average ppl_{model_name}: {df[col].mean()}") 108 | 109 | logging.info(df) 110 | if output_path is not None: 111 | output_file = Path(output_path) / "model_perplexity.csv" 112 | df.to_csv(output_file) 113 | logging.info(output_file) 114 | if cfg.wandb.mode != "disabled": 115 | utils.log_wandb(model_perplexity=df) 116 | 117 | return df 118 | 119 | 120 | @hydra.main( 121 | config_path="../act/configs", config_name="text_generation", version_base="1.3" 122 | ) 123 | def main(cfg: DictConfig) -> None: 124 | evaluate(cfg) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /act/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
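A quick sketch of calling `measure_perplexity` directly, mirroring the keyword arguments `evaluate_perplexity.py` above passes at its call site (the model and sentence are arbitrary, and the argument list is assumed from that call site, not from the function's definition):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from act.utils.perplexity import measure_perplexity

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

ppl = measure_perplexity(
    continuations=["The giraffe walked across the savanna."],
    prompts=None,          # no conditioning prefix: score the full sentence
    model=model,
    tokenizer=tokenizer,
    device="cpu",
    batch_size=1,
    autoregressive=False,
)
print(ppl)                 # one perplexity value per continuation
```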
3 | 
4 | from .aura_hook import AURAHook
5 | from .identity import IdentityHook
6 | from .postprocess_and_save_hook import PostprocessAndSaveHook
7 | from .responses_hook import ResponsesHook
8 | from .return_outputs_hook import ReturnOutputsHook
9 | from .transport import GaussianOTHook, LinearOTHook, OnlyMeanHook
10 | 
11 | HOOK_REGISTRY = {
12 |     "postprocess_and_save": PostprocessAndSaveHook,
13 |     "return_outputs": ReturnOutputsHook,
14 |     "aura": AURAHook,
15 |     "mean_ot": OnlyMeanHook,
16 |     "gaussian_ot": GaussianOTHook,
17 |     "linear_ot": LinearOTHook,
18 |     "identity": IdentityHook,
19 |     "none": IdentityHook,
20 | }
21 | 
22 | 
23 | def get_hook(name: str, *args, **kwargs) -> ResponsesHook:
24 |     hook_cls = HOOK_REGISTRY[name]
25 |     hook = hook_cls(*args, **kwargs)
26 |     return hook
27 | 
-------------------------------------------------------------------------------- /act/hooks/aura_hook.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | import logging
5 | import multiprocessing
6 | import typing as t
7 | 
8 | import torch
9 | 
10 | from act.hooks.intervention_hook import InterventionHook
11 | from act.utils.auroc import compute_auroc
12 | 
13 | logger = logging.getLogger(__name__)
14 | logger.setLevel(logging.INFO)
15 | 
16 | 
17 | class AURAHook(InterventionHook):
18 |     """Applies AurA (https://arxiv.org/abs/2407.12824) to a given module's output.
19 | 
20 |     AurA is an intervention technique that modifies the output of a neural network
21 |     module based on the AUROC of its outputs at classifying a concept. It mitigates
22 |     the concept by dampening each output proportionally to its AUROC.
23 | 
24 |     Attributes:
25 |         module_name (str): The name of the module to apply the AurA intervention to.
26 |         device (str, optional): The device on which the model and hook reside. Defaults to None.
27 |         intervention_position (str, optional): Specifies where in the module's forward pass to apply
28 |             the intervention. Options are 'all' or 'last'. Defaults to 'all'.
29 |         dtype (torch.dtype, optional): The data type for tensors used by the hook. Defaults to torch.float32.
30 |         strength (float, optional): Controls the intensity of the AurA intervention. A value of 1.0 applies
31 |             AurA fully, while 0.0 disables it. Defaults to 1.0.
32 | 
33 |     """
34 | 
35 |     def __init__(
36 |         self,
37 |         module_name: str,
38 |         device: str = None,
39 |         intervention_position: str = "all",
40 |         dtype: torch.dtype = torch.float32,
41 |         strength: float = 1.0,
42 |         **kwargs,
43 |     ):
44 |         super().__init__(
45 |             module_name=module_name,
46 |             device=device,
47 |             intervention_position=intervention_position,
48 |             dtype=dtype,
49 |         )
50 |         self.strength = float(strength)
51 |         self.register_buffer("auroc", torch.empty(0))
52 | 
53 |     def __str__(self):
54 |         txt = (
55 |             f"AurA("
56 |             f"module_name={self.module_name}, "
57 |             f"strength={self.strength}"
58 |             f")"
59 |         )
60 |         return txt
61 | 
62 |     def _post_load(self) -> None:
63 |         super()._post_load()
64 |         # Pre-compute the dampening factors once.
65 |         self.alpha = torch.ones_like(self.auroc, dtype=self.dtype)
66 |         mask = self.auroc > 0.5
67 |         self.alpha[mask] = 1 - 2 * (self.auroc[mask] - 0.5)
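A tiny numeric check of the dampening rule `_post_load` just computed: units whose activations do not predict the concept (AUROC <= 0.5) keep `alpha = 1`, while a perfectly predictive unit (AUROC = 1.0) is fully silenced. The AUROC values below are made up:

```python
import torch

auroc = torch.tensor([0.4, 0.5, 0.75, 1.0])
alpha = torch.ones_like(auroc)
mask = auroc > 0.5
alpha[mask] = 1 - 2 * (auroc[mask] - 0.5)  # maps AUROC in (0.5, 1.0] down to alpha in [0.0, 1.0)
print(alpha)                               # tensor([1.0000, 1.0000, 0.5000, 0.0000])
```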
68 | 
69 |     def load_state_dict(
70 |         self,
71 |         state_dict: t.Mapping[str, t.Any],
72 |         strict: bool = True,
73 |         assign: bool = False,
74 |     ):
75 |         self.auroc = state_dict["auroc"].to(self.device).to(self.dtype)
76 |         self._post_load()
77 | 
78 |     def fit(
79 |         self,
80 |         responses: torch.Tensor,
81 |         labels: torch.Tensor,
82 |         pool: multiprocessing.Pool = None,
83 |         **kwargs,
84 |     ) -> None:
85 |         logger.info(f"Computing AUROC on {responses.shape} responses ...")
86 |         auroc = compute_auroc(
87 |             responses=responses.detach().cpu().numpy(),
88 |             labels=labels.detach().cpu().numpy(),
89 |             chunk_size=10,
90 |             pool=None,
91 |         )
92 |         self.auroc = torch.tensor(auroc, dtype=self.dtype, device=self.device)
93 |         self._post_load()
94 | 
95 |     def forward(self, module, input, output) -> t.Any:
96 |         if output.ndim == 4:
97 |             alpha = self.alpha.view(1, -1, 1, 1)
98 |         elif output.ndim == 3:
99 |             alpha = self.alpha.view(1, 1, -1)
100 |         else:
101 |             raise NotImplementedError()
102 | 
103 |         # Apply AurA dampening
104 |         output_aura = output * alpha
105 | 
106 |         # Interpolate between the original and the dampened output by strength
107 |         output = (1 - self.strength) * output + self.strength * output_aura
108 |         return output
109 | 
-------------------------------------------------------------------------------- /act/hooks/custom_exceptions.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | 
5 | class TargetModuleReached(Exception):
6 |     """
7 |     Custom Exception to stop the model after reaching a certain module
8 |     """
9 | 
10 |     pass
11 | 
-------------------------------------------------------------------------------- /act/hooks/identity.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | import torch
5 | 
6 | from act.hooks.intervention_hook import InterventionHook
7 | 
8 | 
9 | class IdentityHook(InterventionHook):
10 |     """
11 |     A "do nothing" intervention.
12 |     """
13 | 
14 |     def __init__(
15 |         self,
16 |         module_name: str,
17 |         device: str = None,
18 |         intervention_position: str = "original",
19 |         dtype: torch.dtype = torch.float32,
20 |         **kwargs,
21 |     ):
22 |         super().__init__(
23 |             module_name=module_name,
24 |             device=device,
25 |             intervention_position=intervention_position,
26 |             dtype=dtype,
27 |         )
28 | 
29 |     def __str__(self):
30 |         txt = f"Identity(module_name={self.module_name})"
31 |         return txt
32 | 
33 |     def fit(self, *args, **kwargs):
34 |         pass
35 | 
36 |     def forward(self, module, input_, output):
37 |         return self(module, input_, output)
38 | 
39 |     def __call__(self, module, input, output):
40 |         return output
41 | 
-------------------------------------------------------------------------------- /act/hooks/intervention_hook.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | import abc
5 | import typing as t
6 | from abc import ABC, abstractmethod
7 | from pathlib import Path
8 | 
9 | import torch
10 | 
11 | 
12 | class InterventionHook(torch.nn.Module):
13 |     """
14 |     Abstract base class for a hook that intervenes during the forward pass of a PyTorch module.
15 | 16 | This class allows you to specify which part of the output tensor (or all outputs) to modify and at what point in the computation this modification should occur. 17 | 18 | Args: 19 | module_name (str): The name of the module or layer where the intervention is needed. If a specific output index is required, it can be specified after a colon (e.g., "module_name:output_index"). 20 | intervention_position (str): Specifies when to intervene in the forward pass. 'all' means at every step, 'last' means only on the last element of the output tensor sequence. 21 | dtype (torch.dtype): The desired data type for the intervention. Default is torch.float32. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | module_name: str, 27 | device: str, 28 | intervention_position: t.Literal["all", "last"], 29 | dtype: torch.dtype = torch.float32, 30 | ): 31 | super().__init__() 32 | self.module_name = module_name 33 | self.device = device 34 | if ":" in module_name: 35 | self.select_output = int(module_name.split(":")[1]) 36 | else: 37 | self.select_output = None 38 | self.intervention_position = intervention_position 39 | self.dtype = dtype 40 | 41 | def register_named_buffers(self, **kwargs) -> None: 42 | for k, v in kwargs.items(): 43 | self.register_buffer(k, v.to(self.dtype)) 44 | 45 | def save_state_dict(self, state_path: Path) -> None: 46 | torch.save(self.state_dict(), state_path) 47 | 48 | def from_state_path(self, state_path: Path) -> None: 49 | """ 50 | Loads intervention state from a state path pointing to a torch-saved state_dict. 51 | 52 | :param state_path: The state path to load. 53 | """ 54 | self.load_state_dict(torch.load(state_path)) 55 | 56 | @abc.abstractmethod 57 | def fit(self, *args, **kwargs): 58 | raise NotImplementedError("Method fit() must be implemented.") 59 | 60 | def _post_load(self) -> None: 61 | """ 62 | This method should be called after loading the states of the hook. 63 | So calls must be placed at the end of .fit() and at the end of .load_state_dict(). 64 | 65 | Re-implement as needed, but do not forget to call super()._post_load() in the implementation. 66 | """ 67 | # Check all buffers are duly initialized. 68 | for buffer_name, buffer in self.named_buffers(): 69 | assert buffer.numel() > 0, f"Buffer {buffer_name} has not been initialized." 70 | 71 | def update(self, *args, **kwargs): 72 | """ 73 | Updates the state or arguments of this hook with new input data at runtime. 74 | 75 | This method can be overridden by subclasses to provide custom updating logic. By default, it does nothing and returns None. 76 | 77 | Parameters: 78 | *args : variable-length argument list 79 | Variable length argument list that will be used as is for the update operation. 80 | 81 | **kwargs : keyworded arguments 82 | Keyworded arguments that can also be used to update state or arguments of this hook. 83 | 84 | Returns: 85 | None 86 | """ 87 | return None 88 | 89 | def __call__(self, module, input_, output): 90 | """ 91 | PyTorch call method overridden to implement the intervention logic. 92 | 93 | Args: 94 | module (torch.nn.Module): The module for which the forward pass is being evaluated. 95 | input_ (tuple): Input tensors to the module. 96 | output (tuple or torch.Tensor): Output of the module's forward function. If `select_output` is specified, it will be a tuple containing this single element; otherwise, it's expected to be a tuple of outputs. 97 | 98 | Returns: 99 | The modified output after intervention. 
If `select_output` is specified, returns a modified version of the corresponding output in the tuple. Otherwise, returns the entire sequence of modified outputs.
100 |         """
101 |         if isinstance(output, tuple) and self.select_output is not None:
102 |             _output = output[self.select_output]
103 |         else:
104 |             _output = output
105 |         original_ndim = _output.ndim
106 | 
107 |         if original_ndim == 2:
108 |             _output = _output[:, None, :]
109 | 
110 |         if self.intervention_position == "last":
111 |             if len(_output.shape) == 3:
112 |                 __output = _output[:, -1, None, ...]
113 |         else:
114 |             __output = _output
115 | 
116 |         dtype = __output.dtype
117 |         device = __output.device
118 |         __output = __output.to(dtype=self.dtype, device=self.device)
119 |         __output = self.forward(module, input_, __output)
120 |         __output = __output.to(dtype=dtype, device=device)
121 | 
122 |         if self.intervention_position == "last":
123 |             if len(_output.shape) == 3:
124 |                 _output[:, -1, ...] = __output[:, 0, ...]
125 |         else:
126 |             _output = __output
127 | 
128 |         if original_ndim == 2:
129 |             _output = _output[:, 0, :]
130 | 
131 |         if isinstance(output, tuple) and self.select_output is not None:
132 |             output = list(output)
133 |             output[self.select_output] = _output
134 |             output = tuple(output)
135 |         else:
136 |             output = _output
137 | 
138 |         return output
139 | 
140 |     @abc.abstractmethod
141 |     def forward(self, module, input_, output):
142 |         """
143 |         Abstract method to be implemented by subclasses. This method defines the logic for how the intervention should modify the output.
144 | 
145 |         Args:
146 |             module (torch.nn.Module): The module for which the forward pass is being evaluated.
147 |             input_ (tuple): Input tensors to the module.
148 |             output (torch.Tensor): Output tensor of the module's forward function, modified according to the intervention logic.
149 | 
150 |         Returns:
151 |             A modified version of the output tensor after applying the intervention logic.
152 |         """
153 |         raise NotImplementedError("Subclasses must implement this method.")
154 | 
-------------------------------------------------------------------------------- /act/hooks/pooling_ops.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
4 | import typing as t
5 | from functools import partial
6 | 
7 | import torch
8 | 
9 | 
10 | def nanmax(tensor, dim=None, keepdim=False):
11 |     """
12 |     Computes the maximum of tensor elements along dimension 'dim'; 'keepdim' controls whether that dimension is retained.
13 |     NaN entries are ignored: they are replaced by the smallest representable value
14 |     of the dtype before taking the maximum, so they never win the comparison.
15 | 
16 |     Parameters:
17 |         tensor (Tensor): Input tensor from which to compute max
18 |         dim (int or None): Dimension along which the maximum is computed
19 |         keepdim (bool): Determines whether the output tensors have 'dim' retained or not
20 | 
21 |     Returns:
22 |         Tensor: The resultant tensor after applying nanmax
23 |     """
24 |     min_value = torch.finfo(tensor.dtype).min
25 |     output = tensor.nan_to_num(min_value).amax(dim=dim, keepdim=keepdim)
26 |     return output
27 | 
28 | 
29 | def nanmin(tensor, dim=None, keepdim=False):
30 |     """
31 |     Compute the minimum of tensor elements along a specified axis.
32 | 
33 |     Parameters:
34 |         tensor (Tensor): Input Tensor.
35 |         dim (int or tuple of ints, optional): Dimensions to reduce along. Default is None, which will return the minimum over all elements.
36 |         keepdim (bool, optional): If True, retains reduced dimensions with length 1. Default is False.
37 | 
38 |     Returns:
39 |         Tensor: The minimum value along the specified dimension(s).
40 | 
41 |     Note:
42 |         This function ignores NaN values and finds the minimum among non-NaN elements.
43 |     """
44 |     max_value = torch.finfo(tensor.dtype).max
45 |     output = tensor.nan_to_num(max_value).amin(dim=dim, keepdim=keepdim)
46 |     return output
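A short check of the two NaN-aware reductions above (values chosen arbitrarily); NaN entries are masked out rather than propagated:

```python
import torch

from act.hooks.pooling_ops import nanmax, nanmin

x = torch.tensor([[1.0, float("nan"), 3.0]])
print(nanmax(x, dim=1))  # tensor([3.]) -- NaN replaced by the dtype's min before amax
print(nanmin(x, dim=1))  # tensor([1.])
```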
47 | 
48 | 
49 | class TorchPoolingOP(torch.nn.Module):
50 |     """
51 |     A module that applies a pooling operation on an input tensor along a given dimension.
52 | 
53 |     Parameters:
54 |         op_name (str): Name of the pooling function to be used, from TORCH_POOLING_FUNCTIONS.
55 |         dim (int): Dimension along which the operation is performed.
56 | 
57 |     Attributes:
58 |         name (str): The name of the pooling function being applied.
59 |         dim (int): The dimension along which the operation is performed.
60 |         op (function): The actual pooling function to be used.
61 | 
62 |     """
63 | 
64 |     TORCH_POOLING_FUNCTIONS = {
65 |         "min": nanmin,
66 |         "max": nanmax,
67 |         "mean": torch.nanmean,
68 |         "median": torch.nanmedian,
69 |         "last": partial(torch.select, index=-1),  # equivalent to array[-1]
70 |         "all": lambda x, *args, **kwargs: x,
71 |     }
72 | 
73 |     def __init__(self, op_name: str, dim: t.Union[int, str]):
74 |         super().__init__()
75 |         self.name = op_name
76 |         self.dim = dim
77 |         self.op = self.TORCH_POOLING_FUNCTIONS[self.name]
78 | 
79 |     def forward(
80 |         self,
81 |         tensor: torch.Tensor,
82 |         attention_mask: torch.Tensor = None,
83 |         **kwargs,
84 |     ) -> torch.Tensor:
85 |         """
86 |         Applies the pooling operation on the input tensor along the given dimension.
87 | 
88 |         Parameters:
89 |             tensor (torch.Tensor): The input tensor to which the operation is applied.
90 | 
91 |         Returns:
92 |             torch.Tensor: Result of applying the pooling function on the input tensor,
93 |                 along the specified dimension.
94 | 
95 |         """
96 |         tensor_to_op = tensor
97 |         if self.name == "all":
98 |             return tensor
99 |         if attention_mask is not None:
100 |             attention_mask = attention_mask.bool()
101 |             # assert (
102 |             #     attention_mask[:, 0].all() == True
103 |             # ), "Attention mask contains 0s at the end while assuming right padding."
104 |             # nans will be ignored by the nan-aware ops (used w/ attention mask)
105 |             tensor_to_op[~attention_mask] = torch.nan
106 |         if self.dim == "auto":
107 |             if len(tensor.shape) == 2:  # Single token ops are directly returned
108 |                 return tensor
109 |             elif len(tensor.shape) == 3:
110 |                 dim = 1
111 |             elif len(tensor.shape) == 4:
112 |                 dim = (2, 3)
113 |             else:
114 |                 raise RuntimeError(
115 |                     f"Tensor shape {tensor.shape} not supported in pooling op auto mode."
116 |                 )
117 |         else:
118 |             dim = 1
119 |         ret = self.op(tensor_to_op, dim=dim)
120 |         assert not torch.any(ret != ret), "NaNs in output of pooling op."
121 |         return ret
122 | 
123 | 
124 | POOLING_FUNCTIONS_REGISTRY = {
125 |     "min": TorchPoolingOP,
126 |     "max": TorchPoolingOP,
127 |     "mean": TorchPoolingOP,
128 |     "median": TorchPoolingOP,
129 |     # NOTE: "std" is not implemented in TORCH_POOLING_FUNCTIONS, so it is not registered here.
130 |     "last": TorchPoolingOP,  # equivalent to array[-1]
131 |     "all": TorchPoolingOP,
132 | }
133 | 
134 | 
135 | def get_pooling_op(pooling_op_name: str, dim: t.Union[int, str]):
136 |     """
137 |     Returns a pooling operation based on the provided name and dimension.
138 | 
139 |     Parameters:
140 |         pooling_op_name (str): The name of the pooling operation to be returned.
141 |         dim (int or str): The dimension along which the pooling will be performed ("auto" infers it from the tensor shape).
142 | 
143 |     Returns:
144 |         A callable object representing the desired pooling function.
145 | 146 | Raises: 147 | KeyError: If an invalid `pooling_op_name` is provided. 148 | 149 | Note: 150 | This function relies on a global registry of available pooling functions (POOLING_FUNCTIONS_REGISTRY). 151 | The specifics of this dictionary are not included in the docstring for brevity, but should be consulted if you want to use this function. 152 | """ 153 | 154 | return POOLING_FUNCTIONS_REGISTRY[pooling_op_name](op_name=pooling_op_name, dim=dim) 155 | -------------------------------------------------------------------------------- /act/hooks/postprocess_and_save_hook.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import logging 5 | import os 6 | import threading 7 | import typing as t 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | from pathlib import Path 11 | 12 | import torch 13 | 14 | from act.hooks.custom_exceptions import TargetModuleReached 15 | from act.hooks.responses_hook import ResponsesHook 16 | 17 | from .pooling_ops import get_pooling_op 18 | 19 | 20 | class PostprocessAndSaveHook(ResponsesHook): 21 | """ 22 | A PyTorch module to post-process and save outputs of other modules. 23 | 24 | This hook takes as input the output of another module in a model, applies some 25 | operations on it (like pooling), optionally saves it somewhere, and if required, 26 | returns the processed output back for further use. Since it is a pytorch 27 | module, it supports loading and saving state dicts 28 | 29 | Parameters: 30 | module_name : str 31 | Name of the parent module whose outputs are being hooked to. 32 | 33 | pooling_op_names : list[str] 34 | List of pooling operation names to be applied on the input data, e.g., ['max', 'mean']. 35 | 36 | output_path : Path 37 | The location where the processed outputs (if any) should be saved or not (None). 38 | 39 | save_fields : list[str] 40 | Fields of interest in the inputs and/or outputs to be stored for later use, e.g., ['features', 'labels']. 41 | 42 | return_outputs : bool, optional 43 | If True, returns the processed output back from this hooked module for further use. Default is False. 44 | 45 | threaded : bool, optional 46 | If True, the hook will be launched in a different thread. 47 | 48 | raise_exception : bool, optional 49 | If True, raises an exception if there are issues saving outputs to disk. Default is False. 
50 | 
51 |     """
52 | 
53 |     def __init__(
54 |         self,
55 |         module_name: str,
56 |         pooling_op_names: t.List[str],
57 |         output_path: t.Optional[Path],
58 |         save_fields: t.List[str],
59 |         return_outputs: bool = False,
60 |         raise_exception: bool = False,
61 |         threaded: bool = True,
62 |         **kwargs,
63 |     ):
64 |         super().__init__()
65 |         self.module_name = module_name
66 |         # Storing as modules to allow stateful ops
67 |         # these are applied independently, not one after the other
68 |         self.pooling_ops = torch.nn.ModuleList(
69 |             [get_pooling_op(name, dim="auto") for name in pooling_op_names]
70 |         )
71 |         self.output_path = output_path
72 |         self.save_fields = save_fields
73 |         self.return_outputs = return_outputs
74 |         self.raise_exception = raise_exception
75 |         self.batch_idx = None
76 |         self.attention_mask = None
77 |         self.threaded = threaded
78 |         self.thread_handle = None
79 | 
80 |     def __str__(self):
81 |         txt = (
82 |             f"PostprocessAndSaveHook(module_name={self.module_name}, pooling_ops={self.pooling_ops}, "
83 |             f"output_path={self.output_path}, raise_exception={self.raise_exception})\n"
84 |         )
85 |         txt += super().__str__()
86 |         return txt
87 | 
88 |     def update(
89 |         self,
90 |         batch_idx: int,
91 |         batch: dict,
92 |     ) -> None:
93 |         """
94 |         Updates the state of this hook with new input data.
95 | 
96 |         This includes setting the current batch index and updating the inputs,
97 |         which are then processed by the pooling operations defined in __init__().
98 | 
99 |         Parameters:
100 |             batch_idx : int
101 |                 The index of the current mini-batch in a full epoch or dataset.
102 | 
103 |             batch : dict
104 |                 A dictionary containing the input data for this hooked module, e.g., features and labels.
105 | 
106 |         Returns:
107 |             None
108 |         """
109 |         assert "id" in batch
110 |         self.batch_idx = batch_idx
111 |         self.batch = batch
112 |         self.outputs = defaultdict(list)
113 | 
114 |     def save(self, module_name: str, output: torch.Tensor) -> None:
115 |         """
116 |         Applies pooling operations on the captured output and saves the results to disk and/or keeps them in memory.
117 | 
118 |         The processed outputs are saved in torch pickle format at the specified location, with each file named after a specific
119 |         combination of module_name and pooling operation name. These files can later be loaded back into memory for further use.
120 | 
121 |         Parameters:
122 |             module_name : str
123 |                 Name of the parent module whose outputs are being hooked to.
124 | 
125 |             output : torch.Tensor
126 |                 The raw output tensor captured from the parent module; pooling is applied inside this method.
127 | 128 | Returns: 129 | None 130 | """ 131 | attention_mask = self.batch.get("attention_mask", None) 132 | 133 | if "unet" in module_name: 134 | output = output.to(torch.float32) # got some infs 135 | if len(self.batch["id"]) < output.shape[0]: 136 | output = output.chunk(2)[1] 137 | 138 | for pooling_op in self.pooling_ops: 139 | pooled_output = pooling_op( 140 | output.detach().clone(), attention_mask=attention_mask 141 | ) 142 | 143 | for sample_index in range(len(pooled_output)): 144 | datum = {} 145 | sample_id = self.batch["id"][sample_index] 146 | sample_outputs = pooled_output[sample_index].cpu() 147 | for field in self.save_fields: 148 | datum[field] = self.batch[field][sample_index] 149 | datum.update( 150 | { 151 | "responses": sample_outputs.cpu(), 152 | } 153 | ) 154 | subset = self.batch["subset"][sample_index] 155 | if self.output_path is not None: 156 | output_path = ( 157 | self.output_path 158 | / subset 159 | / module_name 160 | / pooling_op.name 161 | / f"{sample_id}.pt" 162 | ) 163 | os.makedirs(output_path.parent, exist_ok=True) 164 | torch.save(datum, output_path) 165 | if self.return_outputs: 166 | self.outputs[module_name].append(datum) 167 | 168 | def __call__(self, module, input, output) -> None: 169 | """ 170 | Called when this hooked module's output changes. 171 | 172 | This method applies pooling operations to the inputs and saves them either on disk or optionally returns them. 173 | If `raise_exception` is set to True in the initialization, an exception will be raised if the parent module's name 174 | matches that specified during initialization. 175 | 176 | Parameters: 177 | module : torch.nn.Module 178 | The module whose output has changed. 179 | 180 | input : tuple or torch.Tensor or dict of them 181 | Input to this module. 182 | 183 | output : torch.Tensor 184 | Output from this module. 185 | 186 | Returns: 187 | None 188 | """ 189 | assert ( 190 | self.batch_idx is not None 191 | ), "update() must be called before executing the hook" 192 | 193 | def _hook(module_name: str, output: t.Any): 194 | if isinstance(output, torch.Tensor): 195 | if ":" not in module_name: 196 | module_name = f"{module_name}:0" 197 | self.save(module_name, output.detach()) 198 | elif isinstance(output, (list, tuple)): 199 | if ":" in module_name: 200 | name, idx = module_name.split(":") 201 | _hook(module_name, output[int(idx)]) 202 | else: 203 | for idx in range(len(output)): 204 | _hook(f"{module_name}:{idx}", output[idx]) 205 | 206 | else: 207 | logging.warn(f"Found {type(output)} in {self.module_name}") 208 | 209 | if self.threaded: 210 | self.thread_handle = threading.Thread( 211 | target=_hook, args=(self.module_name, output) 212 | ) 213 | self.thread_handle.start() 214 | else: 215 | _hook(self.module_name, output) 216 | 217 | if self.raise_exception: 218 | raise TargetModuleReached(self.module_name) 219 | -------------------------------------------------------------------------------- /act/hooks/responses_hook.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | 6 | 7 | class ResponsesHook(torch.nn.Module): 8 | """ 9 | A base class that provides a structure to define hooks in PyTorch models. 10 | 11 | The `__call__` method is the main entry point for any custom Hook, and it needs to be implemented by subclasses. 
12 | This method should contain the logic that will trigger when the hooked module's output changes. 13 | 14 | The `update` method is optional and can provide a way to update state or arguments of this hook with new input data 15 | during runtime, but it's not required for all Hooks. 16 | 17 | """ 18 | 19 | def __call__(self, module, input, output): 20 | raise NotImplementedError 21 | 22 | def register_named_buffers(self, **kwargs) -> None: 23 | for k, v in kwargs.items(): 24 | self.register_buffer(k, v.to(self.device)) 25 | 26 | def update(self, *args, **kwargs): 27 | """ 28 | Updates the state or arguments of this hook with new input data at runtime. 29 | 30 | This method can be overridden by subclasses to provide custom updating logic. By default, it does nothing and returns None. 31 | 32 | Parameters: 33 | *args : variable-length argument list 34 | Variable length argument list that will be used as is for the update operation. 35 | 36 | **kwargs : keyworded arguments 37 | Keyworded arguments that can also be used to update state or arguments of this hook. 38 | 39 | Returns: 40 | None 41 | """ 42 | return None 43 | 44 | def get_thread_handle(self): 45 | """ 46 | Get thread handle for current instance of class. If no thread has been set or if it does not exist, return None. 47 | 48 | Returns: 49 | The `thread_handle` attribute of the instance if it exists; otherwise, returns `None`. 50 | """ 51 | if hasattr(self, "thread_handle"): 52 | return self.thread_handle 53 | else: 54 | return None 55 | 56 | def join(self) -> None: 57 | """ 58 | Blocks execution of the main thread until all threads started by this instance are done. 59 | 60 | This is useful to ensure that all spawned threads have finished their tasks before the main program continues. 61 | Without calling `join`, it's possible for your program to exit before all background tasks finish. 62 | 63 | If there are no spawned threads, this method will return immediately. 64 | 65 | Returns: 66 | None 67 | """ 68 | if self.get_thread_handle() is not None: 69 | self.thread_handle.join() 70 | self.thread_handle = None 71 | -------------------------------------------------------------------------------- /act/hooks/return_outputs_hook.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import logging 5 | import typing as t 6 | 7 | import torch 8 | 9 | from act.hooks.custom_exceptions import TargetModuleReached 10 | from act.hooks.responses_hook import ResponsesHook 11 | 12 | 13 | class ReturnOutputsHook(ResponsesHook): 14 | """ 15 | A PyTorch module that captures outputs of a specified submodule during forward pass. 16 | If `raise_exception` flag is set to True, it raises TargetModuleReached exception when the target module reached. 17 | 18 | Attributes: 19 | module_name (str): The name of the module for which you want to capture outputs. 20 | raise_exception (bool): Flag indicating whether to raise an exception or not. Default is False. 21 | outputs (dict): Dictionary storing names and corresponding tensors of submodules during forward pass. 
22 | """ 23 | 24 | def __init__(self, module_name: str, raise_exception: bool = False): 25 | super().__init__() 26 | self.module_name = module_name 27 | self.raise_exception = raise_exception 28 | self.outputs = {} 29 | 30 | def __call__(self, module, input, output): 31 | def _hook(module_name: str, output: t.Any): 32 | if isinstance(output, torch.Tensor): 33 | self.outputs[module_name] = output.detach() 34 | elif isinstance(output, (list, tuple)): 35 | for idx in range(len(output)): 36 | _hook(f"{module_name}:{idx}", output[idx]) 37 | else: 38 | logging.warn(f"Found {type(output)} in {self.module_name}") 39 | 40 | _hook(self.module_name, output) 41 | 42 | if self.raise_exception: 43 | raise TargetModuleReached(self.module_name) 44 | -------------------------------------------------------------------------------- /act/optimal_transport/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /act/optimal_transport/archs.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import functools 5 | import logging 6 | import time 7 | import typing as t 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class LinearProj(nn.Module): 15 | def __init__( 16 | self, 17 | dim: int, 18 | ): 19 | super().__init__() 20 | self.w1 = nn.Parameter(torch.randn(1, dim)) 21 | self.b1 = nn.Parameter(torch.zeros((1, dim))) 22 | 23 | def forward(self, x: torch.Tensor, reverse: bool = False): 24 | assert x.shape[-1] == self.w1.shape[-1] 25 | if not reverse: 26 | return x * self.w1 + self.b1 27 | else: 28 | return (x - self.b1) / (self.w1 + 1e-10) 29 | 30 | def optimize(self, x: np.ndarray, y: np.ndarray) -> t.Tuple[np.ndarray, t.Dict]: 31 | x, y = x.astype(np.float64), y.astype(np.float64) 32 | 33 | m_x = np.mean(x, axis=0, keepdims=True) 34 | m_y = np.mean(y, axis=0, keepdims=True) 35 | 36 | # Add small noise to prevent divisions by 0 37 | x += 1e-8 * np.random.randn(*x.shape) 38 | 39 | x_bar = x - m_x 40 | y_bar = y - m_y 41 | beta = np.sum((x_bar * y_bar), axis=0, keepdims=True) / np.sum( 42 | (x_bar**2), axis=0, keepdims=True 43 | ) 44 | alpha = m_y - beta * m_x 45 | params = np.concatenate([beta, alpha], 0) 46 | beta = torch.tensor(beta, dtype=self.w1.dtype, device=self.w1.device) 47 | alpha = torch.tensor(alpha, dtype=self.w1.dtype, device=self.w1.device) 48 | self.load_state_dict( 49 | { 50 | "w1": beta, 51 | "b1": alpha, 52 | } 53 | ) 54 | return params, {} 55 | -------------------------------------------------------------------------------- /act/optimal_transport/ot_maps.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import typing as t 5 | 6 | import torch 7 | 8 | 9 | def solve_ot_1d( 10 | p: torch.Tensor, q: torch.Tensor 11 | ) -> t.Tuple[torch.Tensor, torch.Tensor]: 12 | """ 13 | OT 1D for same number of points amounts to sorting. 14 | """ 15 | assert len(p) == len(q), ( 16 | f"Very simple 1D OT matching for now. " 17 | f"Please use the same number of samples for p, q." 
18 |     )
19 |     p_sort, _ = torch.sort(p, 0)
20 |     q_sort, _ = torch.sort(q, 0)
21 |     return p_sort, q_sort
22 | 
-------------------------------------------------------------------------------- /act/scripts/__init__.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | 
-------------------------------------------------------------------------------- /act/scripts/download_external_data.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 | import io
4 | import os
5 | import zipfile
6 | from pathlib import Path
7 | 
8 | import requests
9 | from tqdm import tqdm
10 | 
11 | 
12 | def download_file(url: str, filename: Path) -> None:
13 |     """
14 |     Downloads a file and saves it as `filename`
15 | 
16 |     :param url: The url of the file
17 |     :param filename: Destination file
18 |     """
19 |     response = requests.get(url, stream=True)
20 |     total_size = int(response.headers.get("content-length", 0))
21 |     if response.status_code == 200:
22 |         with open(filename, "wb") as file, tqdm(
23 |             desc="Downloading",
24 |             total=total_size,
25 |             unit="B",
26 |             unit_scale=True,
27 |             unit_divisor=1024,
28 |         ) as progress_bar:
29 |             for chunk in response.iter_content(1024):
30 |                 file.write(chunk)
31 |                 progress_bar.update(len(chunk))
32 |         print(f"File downloaded as '{filename}'")
33 |     else:
34 |         print("Failed to download file")
35 | 
36 | 
37 | def download_and_extract_zip(url, extract_to="."):
38 |     # Step 1: Send a GET request to start the download
39 |     response = requests.get(url, stream=True)
40 | 
41 |     # Step 2: Check if the request was successful
42 |     if response.status_code == 200:
43 |         # Get the total file size from the headers, if available
44 |         total_size = int(response.headers.get("content-length", 0))
45 | 
46 |         # Step 3: Initialize the progress bar
47 |         with tqdm(
48 |             total=total_size, unit="B", unit_scale=True, desc="Downloading"
49 |         ) as progress_bar:
50 |             # Step 4: Download the file in chunks and update the progress bar
51 |             file_bytes = io.BytesIO()
52 |             for chunk in response.iter_content(chunk_size=1024):
53 |                 file_bytes.write(chunk)
54 |                 progress_bar.update(len(chunk))
55 | 
56 |         # Step 5: Load the zip file from memory and extract
57 |         zip_file = zipfile.ZipFile(file_bytes)
58 |         zip_file.extractall(path=extract_to)
59 |         print(f"Files extracted to {extract_to}")
60 |     else:
61 |         print("Failed to download the file:", response.status_code)
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     data_dir = Path(os.environ.get("DATA_DIR", "/mnt/cache"))
66 | 
67 |     # Coco captions
68 |     url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
69 |     download_and_extract_zip(url, data_dir / "coco_captions_2017")
70 |     os.system(
71 |         f"mv {data_dir / 'coco_captions_2017/annotations/*'} {data_dir / 'coco_captions_2017'}"
72 |     )
73 |     os.system(f"rm -rf {data_dir / 'coco_captions_2017/annotations'}")
74 | 
75 |     # RTP
76 |     url = "https://raw.githubusercontent.com/alisawuffles/DExperts/refs/heads/main/generations/toy_prompt/gpt2/prompted_gens_gpt2.jsonl"
77 |     download_file(url, data_dir / "prompted_gens_gpt2.jsonl")
78 | 
79 |     # Jigsaw
80 |     (data_dir / "jigsaw").mkdir(exist_ok=True, parents=True)
81 |     url = "https://huggingface.co/datasets/dirtycomputer/Toxic_Comment_Classification_Challenge/resolve/main/train.csv"
82 |     download_file(url, 
data_dir / "jigsaw/train.csv") 83 | url = "https://huggingface.co/datasets/dirtycomputer/Toxic_Comment_Classification_Challenge/resolve/main/test.csv" 84 | download_file(url, data_dir / "jigsaw/test.csv") 85 | url = "https://huggingface.co/datasets/dirtycomputer/Toxic_Comment_Classification_Challenge/resolve/main/test_labels.csv" 86 | download_file(url, data_dir / "jigsaw/test_labels.csv") 87 | -------------------------------------------------------------------------------- /act/scripts/generate_with_hooks.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | # Loads a model and a dataset and extracts intermediate responses 5 | import functools 6 | import logging 7 | import os 8 | import typing as t 9 | from pathlib import Path 10 | 11 | import hydra 12 | import numpy as np 13 | import pandas as pd 14 | import tqdm 15 | from omegaconf import DictConfig 16 | from transformers import pipeline, set_seed 17 | 18 | from act.models import get_model 19 | from act.models.model_with_hooks import ModelWithHooks 20 | from act.utils import utils 21 | 22 | logger = logging.getLogger(__name__) 23 | logger.setLevel(logging.INFO) 24 | 25 | # Already run in parallel inside DataLoader 26 | os.environ["TOKENIZERS_PARALLELISM"] = "False" 27 | 28 | # Taken from ActAdd (https://colab.research.google.com/drive/1X2ZfC4y8Jx8FbkR7m-bLi8Ifrq-8MPTO#scrollTo=uDRWo4_xOH3A&line=11&uniqifier=1) 29 | SAMPLING_KWARGS = dict(temperature=1.0, top_p=0.3, repetition_penalty=1.2) 30 | 31 | from torch.utils.data import DataLoader, Dataset 32 | 33 | 34 | # Custom Dataset class 35 | class TextDataset(Dataset): 36 | """A PyTorch Dataset class for loading text data from a file. 37 | 38 | This class reads text sentences from a given file, cleans them up by stripping whitespace, 39 | and allows for limiting the number of sentences loaded. 40 | 41 | Attributes: 42 | file_path (str): The path to the file containing the text sentences. 43 | max_sentences (int, optional): The maximum number of sentences to load. If None, all sentences are loaded. 44 | 45 | """ 46 | 47 | def __init__(self, file_path: str, max_sentences: int = None): 48 | """Initializes TextDataset. 49 | 50 | Args: 51 | file_path (str): The path to the text file containing sentences. 52 | max_sentences (int, optional): The maximum number of sentences to load. If None, all sentences are loaded. 
53 | 54 | """ 55 | # Read the file and store sentences 56 | with open(file_path, "r") as f: 57 | self.sentences = f.readlines() 58 | 59 | if max_sentences and max_sentences < len(self.sentences): 60 | self.sentences = self.sentences[:max_sentences] 61 | 62 | # Clean up the sentences by stripping leading/trailing whitespace 63 | self.sentences = [ 64 | sentence.strip() for sentence in self.sentences if sentence.strip() 65 | ] 66 | 67 | def __len__(self): 68 | return len(self.sentences) 69 | 70 | def __getitem__(self, idx): 71 | # Return the sentence at the given index 72 | return self.sentences[idx] 73 | 74 | 75 | # Helper function to create DataLoader 76 | def create_dataloader( 77 | file_path, batch_size=1, max_sentences=None, shuffle=False, num_workers=0 78 | ): 79 | dataset = TextDataset(file_path, max_sentences=max_sentences) 80 | dataloader = DataLoader( 81 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers 82 | ) 83 | return dataloader 84 | 85 | 86 | def print_generated_sentences(output: t.List[t.Dict[str, str]]) -> None: 87 | for o in output: 88 | print("-" * 80) 89 | print(o[0]["generated_text"]) 90 | 91 | 92 | def generate(cfg: DictConfig) -> None: 93 | """Generates text using a pretrained language model with optional interventions. 94 | 95 | This function generates text using a pretrained language model specified in the 96 | configuration (`cfg`). It allows for applying interventions during the generation 97 | process, controlled by parameters defined in `cfg.intervention_params`. The generated 98 | text is then logged and optionally saved to a CSV file. 99 | 100 | Args: 101 | cfg (DictConfig): A Hydra configuration object containing all necessary 102 | parameters for text generation and intervention. See the 103 | `configs/text_generation.yaml` example for details. 104 | 105 | Raises: 106 | ValueError: If `cfg.prompt` ends with ".txt" but the file does not exist. 
107 | 
108 |     Returns:
109 |         None
110 |     """
111 |     if cfg.verbose == 1:
112 |         logging.basicConfig(level=logging.INFO)
113 |     elif cfg.verbose >= 2:
114 |         logging.basicConfig(level=logging.DEBUG)
115 | 
116 |     if cfg.results_dir is not None:
117 |         output_path = Path(Path(__file__).stem)
118 |         output_path = Path(cfg.results_dir) / output_path
119 |         output_path.mkdir(parents=True, exist_ok=True)
120 |     else:
121 |         output_path = None
122 | 
123 |     model, tokenizer = get_model(
124 |         cache_dir=cfg.cache_dir,
125 |         device=cfg.device,
126 |         model_task="text-generation",
127 |         **cfg.model_params,
128 |     )
129 |     # Create hooked model
130 |     model_hooks = ModelWithHooks(
131 |         module=model,
132 |     )
133 | 
134 |     results = []
135 |     for strength in np.linspace(
136 |         cfg.min_strength, cfg.max_strength, cfg.strength_sample_size
137 |     ):
138 |         model_hooks.remove_hooks()
139 |         hook_params = dict(cfg.intervention_params.hook_params)
140 |         hook_params["strength"] = strength
141 |         model_hooks.load_hooks_from_folder(
142 |             folder=Path(cfg.intervention_params.state_path),
143 |             module_names=cfg.model_params.module_names,
144 |             hook_type=cfg.intervention_params.name,
145 |             **hook_params,
146 |         )
147 | 
148 |         # Build the generation pipeline (hooks are registered just before generating)
149 |         generator = pipeline(
150 |             "text-generation",
151 |             model=model_hooks.module,
152 |             tokenizer=tokenizer,
153 |         )
154 | 
155 |         generate_fn = functools.partial(
156 |             generator,
157 |             max_new_tokens=cfg.new_seq_len,
158 |             do_sample=True,
159 |             **SAMPLING_KWARGS,
160 |         )
161 | 
162 |         # Register hooks
163 |         model_hooks.register_hooks()
164 | 
165 |         # Generate with hooks
166 |         set_seed(cfg.seed)
167 |         batch_size = min(cfg.batch_size, cfg.num_sentences)
168 |         if cfg.prompt.endswith(".txt"):
169 |             prompt_loader = create_dataloader(
170 |                 cfg.prompt,
171 |                 batch_size=batch_size,
172 |                 max_sentences=cfg.num_sentences,
173 |             )
174 |         else:
175 |             assert (
176 |                 len(cfg.prompt) > 0
177 |             ), "This script does not handle empty prompts for now."
178 | batch_sizes = ( 179 | [ 180 | len(batch_indices) 181 | for batch_indices in np.array_split( 182 | np.arange(cfg.num_sentences), cfg.num_sentences / batch_size 183 | ) 184 | ] 185 | if cfg.num_sentences >= batch_size 186 | else [ 187 | cfg.num_sentences, 188 | ] 189 | ) 190 | prompt_loader = [[cfg.prompt] * bs for bs in batch_sizes] 191 | 192 | decoded_hook = [] 193 | for prompts in tqdm.tqdm(prompt_loader, desc=f"Generation {strength:0.2f}"): 194 | gen = generate_fn(prompts, num_return_sequences=1) 195 | decoded_hook.extend(gen) 196 | 197 | print("\n") 198 | logger.info("With hook") 199 | logger.info("=========") 200 | print_generated_sentences(decoded_hook[:10]) 201 | model_hooks.remove_hooks() 202 | 203 | for d in decoded_hook: 204 | gen_without_prompt = d[0]["generated_text"].replace(cfg.prompt, "") 205 | results.append([strength, cfg.prompt, gen_without_prompt]) 206 | 207 | if output_path is not None: 208 | df = pd.DataFrame(data=results, columns=["strength", "prompt", "generation"]) 209 | df.to_csv(output_path / "text_generation.csv") 210 | 211 | if cfg.wandb.mode != "disabled": 212 | utils.log_wandb(text_generation=df) 213 | 214 | 215 | @hydra.main(config_path="../configs", config_name="text_generation", version_base="1.3") 216 | def main(cfg: DictConfig) -> None: 217 | generate(cfg) 218 | 219 | 220 | if __name__ == "__main__": 221 | main() 222 | -------------------------------------------------------------------------------- /act/scripts/pipeline.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | """ 5 | This script defines a pipeline for training and evaluating generative models with interventions. 6 | It handles loading configurations, managing responses, and learning interventions on specified modules. 7 | 8 | The pipeline leverages Hydra for configuration management and provides both "atonce" and "incr" modes for intervention learning. 9 | "atonce" mode learns interventions on all modules simultaneously, while "incr" mode allows for incremental learning, leveraging previously learned interventions when training on subsequent modules. 
10 | 11 | """ 12 | 13 | import logging 14 | 15 | import hydra 16 | from omegaconf import DictConfig, OmegaConf 17 | 18 | from act.evaluations import ( 19 | calculate_clip_score, 20 | evaluate_0shot, 21 | evaluate_eleuther, 22 | evaluate_perplexity, 23 | evaluate_toxicity, 24 | ) 25 | from act.scripts import generate_with_hooks, generate_with_hooks_diffusion 26 | from act.scripts.learn_intervention import InterventionsManager, learn_intervention 27 | from act.utils import utils 28 | 29 | logger = logging.getLogger(__name__) 30 | logger.setLevel(logging.INFO) 31 | 32 | 33 | def _rtp(cfg: DictConfig): 34 | intervention_state_path = InterventionsManager.get_output_path(cfg.interventions) 35 | cfg.rtp.intervention_params.state_path = intervention_state_path 36 | if cfg.rtp.fast: 37 | cfg.rtp.ppl_sentences = 3 38 | cfg.rtp.rtp_sentences = 3 39 | 40 | evaluate_toxicity.measure_toxicity(cfg.rtp) 41 | 42 | 43 | def _text_generation(cfg: DictConfig): 44 | intervention_state_path = InterventionsManager.get_output_path(cfg.interventions) 45 | cfg.text_generation.intervention_params.state_path = intervention_state_path 46 | if cfg.text_generation.fast: 47 | cfg.text_generation.new_seq_len = 10 48 | cfg.text_generation.num_sentences = 3 49 | cfg.text_generation.max_strength = 1 50 | cfg.text_generation.strength_sample_size = 2 51 | 52 | generate_with_hooks.generate(cfg.text_generation) 53 | 54 | 55 | def _text_to_image_generation(cfg: DictConfig): 56 | intervention_state_path = InterventionsManager.get_output_path(cfg.interventions) 57 | cfg.text_to_image_generation.intervention_params.state_path = ( 58 | intervention_state_path 59 | ) 60 | 61 | generate_with_hooks_diffusion.generate(cfg.text_to_image_generation) 62 | 63 | 64 | def _zero_shot(cfg: DictConfig): 65 | evaluate_0shot.evaluate(cfg.zero_shot) 66 | 67 | 68 | def _mmlu(cfg: DictConfig): 69 | intervention_state_path = InterventionsManager.get_output_path(cfg.interventions) 70 | cfg.mmlu.intervention_params.state_path = intervention_state_path 71 | if cfg.mmlu.fast: 72 | cfg.mmlu.limit = 10 73 | cfg.mmlu.bootstrap_iters = 2 74 | evaluate_eleuther.run_eleuther_eval(cfg.mmlu) 75 | 76 | 77 | def _model_perplexity(cfg: DictConfig): 78 | if cfg.model_perplexity.fast: 79 | cfg.model_perplexity.perplexity_model_path = "EleutherAI/pythia-70m" 80 | evaluate_perplexity.evaluate(cfg.model_perplexity) 81 | 82 | 83 | def _clip_score(cfg: DictConfig): 84 | calculate_clip_score.calculate_clip_score(cfg.clip_score) 85 | 86 | 87 | EVAL_REGISTRY = { 88 | "rtp": _rtp, 89 | "text-generation": _text_generation, 90 | "zero_shot": _zero_shot, 91 | "mmlu": _mmlu, 92 | "model_perplexity": _model_perplexity, 93 | "text-to-image-generation": _text_to_image_generation, 94 | "clip_score": _clip_score, 95 | } 96 | 97 | 98 | @hydra.main(config_path="../configs", config_name="text_generation", version_base="1.3") 99 | def main(cfg: DictConfig) -> None: 100 | logger.info(cfg) 101 | 102 | wandb_run = utils.setup_wandb(cfg) 103 | 104 | # Learn intervention first, which includes computing responses. 105 | learn_intervention(cfg) 106 | 107 | # Now evaluate the intervention. 
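# Each name in cfg.evaluation must be a key of EVAL_REGISTRY above; an unknown name raises KeyError.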
108 | for eval_name in cfg.evaluation:
109 | logger.info(f"Running {eval_name}")
110 | EVAL_REGISTRY[eval_name](cfg)
111 |
112 | if wandb_run is not None:
113 | wandb_run.finish()
114 |
115 |
116 | if __name__ == "__main__":
117 | main()
118 |
-------------------------------------------------------------------------------- /act/utils/__init__.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 |
-------------------------------------------------------------------------------- /act/utils/auroc.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 |
4 | import logging
5 | import multiprocessing
6 | from itertools import repeat, starmap
7 |
8 | import numpy as np
9 | from sklearn.metrics import roc_auc_score
10 |
11 |
12 | def _compute_auroc_chunk(responses: np.ndarray, labels: np.ndarray):
13 | """
14 | Function to compute the Area Under the Receiver Operating Characteristic Curve (AUROC) for a chunk of data.
15 |
16 | Parameters:
17 | responses (np.ndarray): Array of model responses, one column per scored unit.
18 | labels (np.ndarray): Array of true labels, one per row of `responses`.
19 |
20 | Returns:
21 | np.ndarray: Array of AUROC scores, one per column of `responses`.
22 | """
23 |
24 | # Broadcast labels across all response columns and score each column separately
25 | return roc_auc_score(
26 | labels[:, None].repeat(responses.shape[1], 1),
27 | responses,
28 | average=None,
29 | )
30 |
31 |
32 | def compute_auroc(
33 | responses: np.ndarray,
34 | labels: np.ndarray,
35 | pool: multiprocessing.Pool = None,
36 | chunk_size: int = 10,
37 | ) -> np.ndarray:
38 | """
39 | This function computes the Area Under the Receiver Operating Characteristic (AUROC) scores.
40 |
41 | Parameters:
42 | responses (np.ndarray): The array of predicted responses.
43 | labels (np.ndarray): The array of actual labels.
44 | pool (multiprocessing.pool.Pool, optional): Process pool used to score chunks in parallel. Defaults to None (sequential).
45 | chunk_size (int, optional): The size of each chunk of data to process at a time. Defaults to 10.
46 |
47 | Returns:
48 | np.ndarray: The array of computed AUROC scores.
49 | """
50 | # Split the response columns into chunks of at most `chunk_size`
51 | responses_map = [
52 | responses[:, start : (start + chunk_size)]
53 | for start in np.arange(0, responses.shape[1], chunk_size)
54 | ]
55 | args = zip(responses_map, repeat(labels))
56 | if pool is not None:
57 | ret = pool.starmap(_compute_auroc_chunk, args)
58 | else:
59 | ret = list(starmap(_compute_auroc_chunk, args))
60 | return np.concatenate(ret, 0)
-------------------------------------------------------------------------------- /act/utils/get_module_names.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
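# Standalone, environment-specific helper script: it scans /mnt/data for Hugging Face
# cache directories containing a config.json, instantiates each model without
# allocating weights via accelerate's init_empty_weights(), and dumps all
# named_modules() names to model_modules.json. The /mnt/data root and the
# "model--org--name" cache layout are assumptions about the local setup.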
3 |
4 | import json
5 | from pathlib import Path
6 |
7 | import torch
8 | from accelerate import init_empty_weights
9 | from transformers import AutoConfig, AutoModel
10 |
11 | # Find every cached model config.json under /mnt/data
12 | directories = list(Path("/mnt/data").glob("**/config.json"))
13 | results = {}
14 | for dir_name in directories:
15 | dir_name = str(Path(dir_name).parent)
16 | # Recover "org/name" from cache directories named like "model--org--name"
17 | model_name = str(Path(dir_name).parts[3]).replace("model--", "").replace("--", "/")
18 | print(model_name)
19 | try:
20 | # Instantiate the model on empty weights and record its module names
21 | with init_empty_weights():
22 | config = AutoConfig.from_pretrained(
23 | dir_name, torch_dtype=torch.float16, device_map="auto"
24 | )
25 | model = AutoModel.from_config(config)
26 |
27 | # named_modules() preserves registration order, so the JSON lists modules in order
28 | modules = []
29 | for name, _ in model.named_modules():
30 | modules.append(name)
31 | results[model_name] = modules
32 | except Exception as e:
33 | print("Error with directory {}:".format(dir_name), str(e))
34 |
35 | # Save the results to a JSON file
36 | with open("model_modules.json", "w") as fp:
37 | json.dump(results, fp, sort_keys=True, indent=2)
-------------------------------------------------------------------------------- /act/utils/quantiles.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 |
4 | import typing as t
5 |
6 | import torch
7 |
8 |
9 | def compute_quantiles(z: torch.Tensor) -> t.Dict[str, t.List[torch.Tensor]]:
10 | # TODO: Only keep the quantile of interest to reduce footprint! Using many now for research purposes only.
11 | qs = {
12 | "q_0_100": [
13 | torch.quantile(z, q=0.00, dim=0),
14 | torch.quantile(z, q=1.0, dim=0),
15 | ],
16 | "q_0.5_99.5": [
17 | torch.quantile(z, q=0.005, dim=0),
18 | torch.quantile(z, q=0.995, dim=0),
19 | ],
20 | "q_1_99": [
21 | torch.quantile(z, q=0.01, dim=0),
22 | torch.quantile(z, q=0.99, dim=0),
23 | ],
24 | "q_2_98": [
25 | torch.quantile(z, q=0.02, dim=0),
26 | torch.quantile(z, q=0.98, dim=0),
27 | ],
28 | "q_5_95": [
29 | torch.quantile(z, q=0.05, dim=0),
30 | torch.quantile(z, q=0.95, dim=0),
31 | ],
32 | "q_10_90": [
33 | torch.quantile(z, q=0.10, dim=0),
34 | torch.quantile(z, q=0.90, dim=0),
35 | ],
36 | "q_20_80": [
37 | torch.quantile(z, q=0.20, dim=0),
38 | torch.quantile(z, q=0.80, dim=0),
39 | ],
40 | "q_30_70": [
41 | torch.quantile(z, q=0.30, dim=0),
42 | torch.quantile(z, q=0.70, dim=0),
43 | ],
44 | "q_40_60": [
45 | torch.quantile(z, q=0.40, dim=0),
46 | torch.quantile(z, q=0.60, dim=0),
47 | ],
48 | }
49 | return qs
50 |
-------------------------------------------------------------------------------- /act/utils/utils.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
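# Note on load_yaml below: plain yaml.SafeLoader parses scalars like "1e-3" as
# strings; load_yaml patches the implicit float resolver so that, illustratively,
# a file containing `lr: 1e-3` loads as {"lr": 0.001} rather than {"lr": "1e-3"}.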
3 |
4 | import json
5 | import logging
6 | import os
7 | import re
8 | import typing as t
9 | from collections import defaultdict
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import pandas as pd
14 | import torch
15 | import wandb
16 | import yaml
17 | from omegaconf import DictConfig, OmegaConf
18 | from PIL import Image
19 |
20 | logger = logging.getLogger(__name__)
21 | logger.setLevel(logging.INFO)
22 |
23 |
24 | def load_yaml(path: Path) -> t.Union[t.List, t.Dict]:
25 | # Adding float resolver that includes "1e-3" like floats. Otherwise they are loaded as strings.
26 | # https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number
27 | loader = yaml.SafeLoader
28 | loader.add_implicit_resolver(
29 | "tag:yaml.org,2002:float",
30 | re.compile(
31 | """^(?:
32 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
33 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
34 | |\\.[0-9_]+(?:[eE][-+][0-9]+)?
35 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
36 | |[-+]?\\.(?:inf|Inf|INF)
37 | |\\.(?:nan|NaN|NAN))$""",
38 | re.X,
39 | ),
40 | list("-+0123456789."),
41 | )
42 |
43 | with open(path, "r") as infile:
44 | return yaml.load(infile, Loader=loader)
45 |
46 |
47 | def load_json(path: Path) -> t.Union[t.List, t.Dict]:
48 | with open(path, "r") as infile:
49 | return json.load(infile)
50 |
51 |
52 | def setup_wandb(cfg: DictConfig) -> t.Optional[wandb.apis.public.Run]:
53 | if cfg.wandb.mode == "disabled":
54 | return None
55 |
56 |
57 | if Path(".wandb.yaml").exists():
58 | wandb_config = load_yaml(".wandb.yaml")
59 | os.environ["WANDB_API_KEY"] = wandb_config["WANDB_API_KEY"]
60 | os.environ["WANDB_BASE_URL"] = wandb_config["WANDB_BASE_URL"]
61 | cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
62 | run = wandb.init(
63 | config=cfg_dict,
64 | **cfg.wandb,
65 | )
66 | else:
67 | raise FileNotFoundError(
68 | "Cannot find '.wandb.yaml'. You must set it if you want to use WandB, with content:\n"
69 | "WANDB_API_KEY: your_api_key\n"
70 | "WANDB_BASE_URL: your_base_url"
71 | )
72 | return run
73 |
74 |
75 | def seed_all(seed=42):
76 | """Set all random seeds to a given value for reproducibility."""
77 | torch.manual_seed(seed)
78 | np.random.seed(seed)
79 | if torch.cuda.is_available():
80 | torch.cuda.manual_seed(seed)
81 |
82 |
83 | def is_module_name_in_regex(module_name: str, regex_list: t.List[str]) -> t.List[str]:
84 | """Returns the module names derived from `module_name` that match any of the regular expressions in the list
85 |
86 | The intended behavior is the following:
87 |
88 | - module_name="foo.bar", regex=".*bar" --> ["foo.bar"]
89 | - module_name="foo.bar", regex=".*bar:0" --> ["foo.bar:0"]
90 | - module_name="foo.bar", regex=".*bar:1" --> ["foo.bar:1"]
91 | - module_name="foo.bar:0", regex=".*bar" --> ["foo.bar:0"]
92 | - module_name="foo.bar:1", regex=".*bar" --> ["foo.bar:1"]
93 | - module_name="foo.bar:0", regex=".*bar:0" --> ["foo.bar:0"]
94 | - module_name="foo.bar:0", regex=".*bar:1" --> []
95 |
96 | The goal is to signal on which tensors we should create hooks.
97 | If any of the returned tensor names does not actually exist, the flow will fail at hook creation/save.
98 | This might happen especially for the case:
99 | - module_name="foo.bar", regex=".*bar:1" --> ["foo.bar:1"]
100 |
101 | Args:
102 | module_name (str): name of the pytorch module to find
103 | regex_list (t.List[str]): list with regex expressions
104 |
105 | Returns:
106 | list: the list of module names that match the expression
107 | """
108 |
109 | ret = []
110 | for regex in regex_list:
111 | # Just for the weird case that module_name has :num. Unlikely with current torch api.
112 | # In such case, we match the base part of the module name if no specific :num is requested through regex.
113 | module_name_base = module_name
114 | if re.fullmatch(r".*(:[0-9]+)", module_name) is not None and ":" not in regex:
115 | module_name_base = module_name.split(":")[0]
116 | elif re.fullmatch(r".*(:[0-9]+)", module_name) is None and ":" in regex:
117 | # In case module_name does not contain :num but regex does, remove :num from regex
118 | regex, regex_num_tensor = regex.split(":")
119 | module_name = module_name + f":{regex_num_tensor}"
120 | match = re.fullmatch(regex, module_name_base)
121 | if match is not None:
122 | ret.append(module_name)
123 | ret = list(set(ret))
124 | return ret
125 |
126 |
127 | def log_image_folder_wandb(
128 | folder: Path, limit: int = 100
129 | ) -> None:
130 | """
131 | Process all png images in a given folder and log them to wandb.
132 |
133 | Args:
134 | folder (Path): The root folder containing the image files.
135 | limit (int, optional): Maximum number of images to process from each parent directory. Defaults to 100.
136 |
137 | Returns:
138 | None
139 | """
140 | image_paths_dict = defaultdict(lambda: defaultdict(list))
141 | image_paths = folder.glob("**/*.png")
142 | for path in image_paths:
143 | id = path.stem
144 | parent = str(path.parent.parent)
145 | image_paths_dict[parent][id].append(path)
146 | for parent in image_paths_dict:
147 | for id in list(sorted(image_paths_dict[parent]))[:limit]:
148 | images = [
149 | np.asarray(Image.open(str(p)))
150 | for p in sorted(image_paths_dict[parent][id])
151 | ]
152 | images = np.concatenate(images, axis=1)
153 | description = str(Path(parent).relative_to(folder))
154 | images = wandb.Image(images, caption=str(id))
155 | wandb.log({description: images})
156 |
157 |
158 | def log_wandb(*args, **kwargs):
159 | """Log data to the Weights & Biases (W&B) platform.
160 |
161 | Args:
162 | *args: Variable length positional arguments of types allowed for direct logging into W&B. Only dictionary arguments are supported.
163 | **kwargs: Keyworded variable-length arguments with values to be logged in W&B. If a value is a pandas DataFrame, it will be converted to a wandb.Table and logged as such.
164 |
165 | Returns:
166 | None
167 |
168 | Raises:
169 | ValueError: If any positional argument in *args is not a dictionary.
170 |
171 | Note:
172 | - This function requires the `wandb` and `pandas` libraries to be installed.
173 | - A W&B run is assumed to have been initialized already (e.g. via `setup_wandb`).
174 | - Positional dictionary arguments and keyword arguments may be mixed in a single call.
175 | - Keyword values that are not DataFrames are logged directly with `wandb.log({key: value})`.
176 | - If called without any arguments or with only positional dictionaries, each dictionary will be logged as is into W&B using `wandb.log()`. 177 | - For keyworded arguments where the value is a pandas DataFrame, it converts the dataframe to a wandb.Table and logs it in W&B. 178 | """ 179 | import wandb 180 | 181 | for arg in args: 182 | if isinstance(arg, dict): 183 | wandb.log(arg) 184 | else: 185 | raise ValueError( 186 | f"Only dictionary args are allowed. For dataframes, use kwargs." 187 | ) 188 | for k, v in kwargs.items(): 189 | if isinstance(v, pd.DataFrame): 190 | table = wandb.Table(dataframe=v) 191 | wandb.log({k: table}) 192 | else: 193 | wandb.log({k: v}) 194 | -------------------------------------------------------------------------------- /assets/main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/assets/main_figure.png -------------------------------------------------------------------------------- /assets/main_figure.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/assets/main_figure.webp -------------------------------------------------------------------------------- /data/diffusion_concept_prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "pink_elephant": [ 3 | "a pink elephant.", 4 | "containing a pink elephant.", 5 | "with a pink elephant in plain view.", 6 | "and a pink elephant.", 7 | "it displays a pink elephant.", 8 | "featuring a pink elephant.", 9 | "in addition to a pink elephant.", 10 | "and also a pink elephant.", 11 | "and a pink elephant as well.", 12 | "the pink elephant can be clearly seen." 13 | ], 14 | "elephant": [ 15 | "an elephant.", 16 | "containing an elephant.", 17 | "with an elephant in plain view.", 18 | "and an elephant.", 19 | "it displays an elephant.", 20 | "featuring an elephant.", 21 | "in addition to an elephant.", 22 | "and also an elephant.", 23 | "and an elephant as well.", 24 | "the elephant can be clearly seen." 25 | ], 26 | "gorilla": [ 27 | "a gorilla.", 28 | "containing a gorilla.", 29 | "with a gorilla in plain view.", 30 | "and a gorilla.", 31 | "it displays a gorilla.", 32 | "featuring a gorilla.", 33 | "in addition to a gorilla.", 34 | "and also a gorilla.", 35 | "and a gorilla as well.", 36 | "the gorilla can be clearly seen." 37 | ], 38 | "cat": [ 39 | "a cat.", 40 | "containing a cat.", 41 | "with a cat in plain view.", 42 | "and a cat.", 43 | "it displays a cat.", 44 | "featuring a cat.", 45 | "in addition to a cat.", 46 | "and also a cat.", 47 | "and a cat as well.", 48 | "the cat can be clearly seen." 49 | ], 50 | "white_bear": [ 51 | "a white bear.", 52 | "containing a white bear.", 53 | "with a white bear in plain view.", 54 | "and a white bear.", 55 | "it displays a white bear.", 56 | "featuring a white bear.", 57 | "in addition to a white bear.", 58 | "and also a white bear.", 59 | "and a white bear as well.", 60 | "the white bear can be clearly seen." 61 | ], 62 | "bear": [ 63 | "a bear.", 64 | "containing a bear.", 65 | "with a bear in plain view.", 66 | "and a bear.", 67 | "it displays a bear.", 68 | "featuring a bear.", 69 | "in addition to a bear.", 70 | "and also a bear.", 71 | "and a bear as well.", 72 | "the bear can be clearly seen." 
73 | ], 74 | "no_pink_elephant": [ 75 | "without a pink elephant.", 76 | "not containing a pink elephant.", 77 | "without a pink elephant in plain view.", 78 | "and a pink elephant that cannot be seen.", 79 | "it does not display a pink elephant.", 80 | "not featuring a pink elephant.", 81 | "lacking a pink elephant.", 82 | "and not a pink elephant.", 83 | "and a pink elephant is missing.", 84 | "the pink elephant cannot be seen." 85 | ], 86 | "no_elephant": [ 87 | "without an elephant.", 88 | "not containing an elephant.", 89 | "without an elephant in plain view.", 90 | "and an elephant that cannot be seen.", 91 | "it does not display an elephant.", 92 | "not featuring an elephant.", 93 | "lacking an elephant.", 94 | "and not an elephant.", 95 | "and an elephant is missing.", 96 | "the elephant cannot be seen." 97 | ], 98 | "no_gorilla": [ 99 | "without a gorilla.", 100 | "not containing a gorilla.", 101 | "without a gorilla in plain view.", 102 | "and a gorilla that cannot be seen.", 103 | "it does not display a gorilla.", 104 | "not featuring a gorilla.", 105 | "lacking a gorilla.", 106 | "and not a gorilla.", 107 | "and a gorilla is missing.", 108 | "the gorilla cannot be seen." 109 | ], 110 | "no_cat": [ 111 | "without a cat.", 112 | "not containing a cat.", 113 | "without a cat in plain view.", 114 | "and a cat that cannot be seen.", 115 | "it does not display a cat.", 116 | "not featuring a cat.", 117 | "lacking a cat.", 118 | "and not a cat.", 119 | "and a cat is missing.", 120 | "the cat cannot be seen." 121 | ], 122 | "no_white_bear": [ 123 | "without a white bear.", 124 | "not containing a white bear.", 125 | "without a white bear in plain view.", 126 | "and a white bear that cannot be seen.", 127 | "it does not display a white bear.", 128 | "not featuring a white bear.", 129 | "lacking a white bear.", 130 | "and not a white bear.", 131 | "and a white bear is missing.", 132 | "the white bear cannot be seen." 133 | ], 134 | "no_bear": [ 135 | "without a bear.", 136 | "not containing a bear.", 137 | "without a bear in plain view.", 138 | "and a bear that cannot be seen.", 139 | "it does not display a bear.", 140 | "not featuring a bear.", 141 | "lacking a bear.", 142 | "and not a bear.", 143 | "and a bear is missing.", 144 | "the bear cannot be seen." 145 | ] 146 | } -------------------------------------------------------------------------------- /data/diffusion_prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "none": [ 3 | "A beaver in its natural habitat. ", 4 | "A playful dolphin jumping out of water. ", 5 | "An otter swimming underwater. ", 6 | "A seal resting on ice. ", 7 | "A whale surfacing for air. ", 8 | "Aquarium fish in a colorful tank. ", 9 | "A flatfish with its flattened sides. ", 10 | "A ray gliding across sand. ", 11 | "A shark near the ocean's surface. ", 12 | "A trout swimming in a river. ", 13 | "Orchids in full bloom. ", 14 | "Poppies swaying gently in a field. ", 15 | "Roses in a lush garden setting. ", 16 | "Sunflowers growing tall and yellow. ", 17 | "Tulips of various colors in a flower bed. ", 18 | "Bottles on a neatly arranged shelf. ", 19 | "A set of bowls of different sizes. ", 20 | "Cans lined up in storage. ", 21 | "Various cups on a table. ", 22 | "A collection of plates on display. ", 23 | "Red apples sitting on a wooden crate. ", 24 | "Fresh mushrooms gathered from the forest floor. ", 25 | "Yellow oranges stacked neatly. ", 26 | "Green pears arranged artistically. 
", 27 | "Sweet peppers in various colors. ", 28 | "An analog clock with Roman numerals. ", 29 | "A modern computer keyboard. ", 30 | "A vintage lamp with a globe shade. ", 31 | "An old-style telephone receiver. ", 32 | "A classic television set on display. ", 33 | "A comfortable bed frame. ", 34 | "A wooden chair near the window. ", 35 | "A plush couch in a living room. ", 36 | "A sturdy table with drawers. ", 37 | "A wardrobe with hanging space and shelves. ", 38 | "A busy bee collecting nectar. ", 39 | "A beetle scurrying across leaves. ", 40 | "A butterfly resting on a flower petal. ", 41 | "A caterpillar crawling on a leaf. ", 42 | "A cockroach in its natural environment. ", 43 | "A bear wandering through the forest. ", 44 | "A leopard spotted in the wild. ", 45 | "A lion in a savannah setting. ", 46 | "A tiger standing on rocky terrain. ", 47 | "A wolf tracking prey. ", 48 | "An impressive stone bridge. ", 49 | "A medieval castle with turrets. ", 50 | "A modern house with large windows. ", 51 | "A winding road leading to the horizon. ", 52 | "A skyscraper skyline at sunset. ", 53 | "A fluffy cloud floating in the sky. ", 54 | "A majestic mountain range. ", 55 | "A flat, open plain dotted with grass. ", 56 | "The sea stretching out as far as the eye can see. ", 57 | "A camel trekking across desert sands. ", 58 | "A herd of cattle grazing on green pastures. ", 59 | "A chimpanzee swinging from tree branches. ", 60 | "An elephant walking in a savannah. ", 61 | "A kangaroo bounding through Australian grasslands. ", 62 | "A fox running along a riverbank. ", 63 | "A porcupine hiding under leaves. ", 64 | "A raccoon exploring urban environments. ", 65 | "A skunk standing guard near its den. ", 66 | "A crab scuttling across the beach. ", 67 | "A lobster walking on rocks. ", 68 | "A snail climbing up a wall. ", 69 | "A spider weaving its web. ", 70 | "A worm digging through soil. ", 71 | "A baby with a curious expression. ", 72 | "A boy playing sports. ", 73 | "A girl dancing in a garden. ", 74 | "A man cooking an omelette. ", 75 | "A woman reading a book. ", 76 | "An animated crocodile swimming. ", 77 | "A dinosaur on a plains. ", 78 | "A lizard basking on rocks. ", 79 | "A snake slithering through the grass. ", 80 | "A turtle moving slowly across sand. ", 81 | "A hamster running in its wheel. ", 82 | "A mouse scurrying through a maze. ", 83 | "A rabbit eating carrots. ", 84 | "A shrew darting between twigs. ", 85 | "A squirrel collecting nuts. ", 86 | "A maple tree changing colors in fall. ", 87 | "An oak tree standing tall in the forest. ", 88 | "A palm tree swaying near the beach. ", 89 | "A pine tree in a winter landscape. ", 90 | "A willow tree by the riverbank. ", 91 | "A bicycle with a basket attached. ", 92 | "A classic bus on urban streets. ", 93 | "A motorcycle with leather gear. ", 94 | "A pickup truck carrying tools. ", 95 | "A train passing through mountains. ", 96 | "A lawn-mower cutting grass in a field. ", 97 | "A rocket launching into space. ", 98 | "An old streetcar pulling into a station. ", 99 | "A tank parading on as street. ", 100 | "A tractor working on farm fields. 
" 101 | ], 102 | "no_trees": [ 103 | "A lone camel walks across the vast desert dunes under a blazing sun.", 104 | "A surfer rides a massive wave in the open ocean at sunset.", 105 | "A spaceship hovers above a distant planet in a star-filled galaxy.", 106 | "A snow-covered igloo stands isolated on an Arctic plain.", 107 | "A close-up shot of a colorful butterfly perched on a flower.", 108 | "A bustling city street filled with cars and skyscrapers.", 109 | "An astronaut floats weightlessly inside the International Space Station.", 110 | "A chef prepares sushi in a modern restaurant kitchen.", 111 | "A group of penguins huddle together on an icy shore.", 112 | "A basketball player jumps to make a slam dunk in a stadium.", 113 | "A submarine explores the depths of a vibrant coral reef.", 114 | "A close-up of raindrops falling into a puddle on asphalt.", 115 | "A single lighthouse stands tall on a rocky cliff by the sea.", 116 | "A race car speeds down the track at a grand prix event.", 117 | "A sandcastle stands on the beach with waves approaching.", 118 | "A plane flies high above the clouds against a blue sky.", 119 | "A group of dancers perform ballet on a brightly lit stage.", 120 | "An ancient pyramid rises from the desert sands.", 121 | "A scientist examines a slide under a microscope in a lab.", 122 | "A hot air balloon floats over a canyon at sunrise." 123 | ], 124 | "trees": [ 125 | "A lone camel walks through a desert dotted with palm trees under a blazing sun.", 126 | "A surfer rides a massive wave near a coastline lined with lush trees at sunset.", 127 | "A spaceship hovers above a distant planet covered in vast, alien forests.", 128 | "A snow-covered igloo stands isolated among snow-laden pine trees on an Arctic plain.", 129 | "A close-up shot of a colorful butterfly perched on a leaf in a dense forest.", 130 | "A bustling city street filled with cars and skyscrapers surrounded by tall trees.", 131 | "An astronaut floats weightlessly inside the International Space Station, gazing at Earth's green forests below.", 132 | "A chef prepares sushi in a modern restaurant kitchen overlooking a garden of bonsai trees.", 133 | "A group of penguins huddle together on an icy shore lined with frost-covered trees.", 134 | "A basketball player jumps to make a slam dunk on an outdoor court surrounded by trees.", 135 | "A submarine explores the depths of a vibrant coral reef beneath mangrove trees.", 136 | "A close-up of raindrops falling from leaves in a lush forest.", 137 | "A single lighthouse stands tall among coastal trees on a rocky cliff by the sea.", 138 | "A race car speeds down the track surrounded by a forest of trees at a grand prix event.", 139 | "A sandcastle stands on the beach under the shade of palm trees with waves approaching.", 140 | "A plane flies high above a sprawling forest canopy against a blue sky.", 141 | "A group of dancers perform ballet in a forest clearing dappled with sunlight.", 142 | "An ancient pyramid rises from the jungle amidst towering trees.", 143 | "A scientist examines a slide under a microscope in a lab nestled among trees.", 144 | "A hot air balloon floats over a canyon filled with trees at sunrise." 
145 | ]
146 | }
-------------------------------------------------------------------------------- /data/giraffes.json: --------------------------------------------------------------------------------
1 | {
2 | "giraffe": [
3 | "Giraffes primarily inhabit savannas, grasslands, and open woodlands in Africa.",
4 | "The distinctive coat patterns of giraffes, like human fingerprints, are unique to each individual.",
5 | "Despite their ability to reach high into trees, giraffes tend to eat more leaves from smaller trees and bushes.",
6 | "In the wild, giraffes often form small groups known as \"towers,\" which usually consist of females and their young.",
7 | "Many giraffe subspecies are listed under threatened categories on the IUCN Red List.",
8 | "The long necks of giraffes not only help in feeding but also in spotting predators from a distance.",
9 | "Female giraffes typically give birth to a single calf after a gestation period of approximately 450-460 days.",
10 | "Despite their towering height, giraffes can run at speeds of up to 35 miles per hour (56 kilometers per hour).",
11 | "Giraffes can go without water for long periods, but when they do drink, they spread their front legs wide apart to reach the ground.",
12 | "In some African cultures, giraffes are seen as symbols of good luck, prosperity, and peacefulness.",
13 | "Giraffes have a nearly 360-degree range of vision thanks to their uniquely positioned eyes, helping them detect predators from afar and observe their surroundings with ease.",
14 | "Unlike many other animals, giraffes typically sleep only about 4-5 hours in short intervals, often resting while standing and sometimes taking short naps lying down, which they do for just a few minutes at a time."
15 | ],
16 | "none": [
17 | "The next Mars rover is scheduled to launch in 2026 with a focus on searching for signs of past life.",
18 | "To prevent onions from making you cry while chopping, refrigerate them for about an hour beforehand.",
19 | "The first successful heart transplant was performed by Dr. Christiaan Barnard in 1967 at Groote Schuur Hospital in Cape Town.",
20 | "Single-use plastics are a major contributor to ocean pollution and global waste management issues.",
21 | "Quantum computing has the potential to solve complex problems that current computers cannot, or would take years to solve.",
22 | "The Grand Canyon is one of the most spectacular examples of erosion in the world, carved by the Colorado River.",
23 | "The piano was invented in the early 18th century by Bartolomeo Cristofori as an improvement over the harpsichord.",
24 | "Incorporating high-intensity interval training (HIIT) into your workout routine can significantly boost metabolism and endurance.",
25 | "The monarch butterfly migrates thousands of miles each year from Canada and the United States to Mexico, a journey spanning generations.",
26 | "Dolphins live in groups, called pods, which are known for their complex social structures and cooperative hunting behaviors.",
27 | "Polar bears have black skin under their white fur to absorb heat from the sun in the Arctic environment.",
28 | "The precise mechanism by which honeybees choose their new queen is still not fully understood, highlighting the complexity of their social hierarchy."
29 | ] 30 | } 31 | -------------------------------------------------------------------------------- /data/style_prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "art_nouveau": "Art_Nouveau, Alphonse_Mucha, Gustav_Klimt, flowing_lines, organic_shapes, floral_motifs, geometric_patterns, ornamental_designs, Jugendstil, Secessionism, symbolism, female_figures, gold_leaf, intricate_details, turn_of_the_century_art, early_20th_century", 3 | "impressionism": "impressionism, Claude_Monet, brush_strokes, light, color, outdoor_scenes, water_lilies, haystacks, Rouen_Cathedral, reflections, nature, atmospheric, vibrant_colors, visible_textures, 19th_century_art, French_impressionism", 4 | "cyberpunk": "cyberpunk, neon_lights, urban_jungles, high-tech_architecture, augmented_reality, AI_technology, biopunk, futuristic_cities, post-apocalyptic_scenes, digital_hacking, megacorporations, androids, dystopian_societies, cybernetic_enhancements, chromed_details, glowing_neon_signs, rain-soaked_streets", 5 | "photorealism": "photorealism, hyperrealism, optical_precision, photographic_quality, fine_detail, lifelike_textures, realistic_lighting, accurate_perspective, human_figures, still_life, cityscapes, landscapes, skin_tones, reflections_and_shadows, everyday_objects, documentary_style_art, contemporary_realism", 6 | "3d": "3D_renders, computer_generated_images, CGI, digital_modeling, realistic_textures, volumetric_lighting, physically_based_rendering, PBR, high_resolution, polygon_meshes, UV_mapping, texturing, shading, animation, motion_graphics, visual_effects, VFX, virtual_environments, 3D_art", 7 | "sketch": "sketches, pencil_drawing, charcoal_sketches, ink_illustrations, gestural_lines, quick_studies, figure_drawing, perspective_sketching, urban_sketching, landscape_sketches, still_life_drawings, sketchbook_art, doodles, minimalist_lines, expressive_mark-making, observational_drawing", 8 | "cartoons": "kids_cartoons, animated_characters, colorful_palettes, big_eyes, exaggerated_features, simplified_designs, friendly_faces, rounded_shapes, anthropomorphic_animals, fantasy_creatures, adventure_themes, educational_content, playful_expressions, cheerful_backdrops, comic_book_style, family_friendly_animation", 9 | "watercolor": "watercolor_style, transparent_media, wet-on-wet_application, dry_brush_strokes, soft_blending, delicate_touches, gentle_shading, luminous_hues, atmospheric_lighting, ethereal_quality, subtle_textures, color_gradients, painterly_aesthetics, fluid_paint_behavior, watercolor_paper_texture", 10 | "anime": "anime_style, large_expressive_eyes, stylized_hair, bold_outlines, simplified_colors, dynamic_perspective, exaggerated_features, angular_shapes, chibis, manga_inspired, emotive_facial_expressions, action_sequences, speed_lines, cell_shading, graphic_backgrounds, vibrant_palettes", 11 | "quality": "high_quality, enhanced_details, refined_textures, crisp_edges, increased_resolution, vivid_colors, improved_contrast, optimized_brightness, balanced_saturation, minimized_noise, reduced_artifacts, smooth_gradients, accurate_shadows_and_highlights, state-of-the-art_upscaling, advanced_image_enhancement_techniques", 12 | "none": "" 13 | } 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
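# Dependencies are resolved dynamically from requirements.txt (see
# [tool.setuptools.dynamic] below), so `pip install .` (or `pip install -e .`
# for a development install) pulls in everything listed there.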
3 |
4 | [project]
5 | name = "AcT"
6 | version = "0.0.1"
7 | description = "A package to intervene on the internal responses of LLMs and text-to-image diffusion models."
8 | readme = "README.md"
9 | authors = [
10 | {name="Pau Rodríguez", email="pau.rodriguez@apple.com"},
11 | {name="Arno Blaas", email="ablaas@apple.com"},
12 | {name="Xavier Suau", email="xsuaucuadros@apple.com"},
13 | ]
14 | license = { text = "Apple Sample Code License" }
15 | dynamic = ["dependencies"]
16 |
17 | # This will fetch dependencies from requirements.txt when running `pip install .`.
18 | [tool.setuptools.dynamic]
19 | dependencies = {file = ["requirements.txt"]}
20 |
21 | [project.urls]
22 | homepage = "https://github.com/apple/ml-act"
23 |
24 | # Below taken from https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html
25 | [tool.setuptools.packages.find]
26 | where = ["."] # list of folders that contain the packages (["."] by default)
27 | include = [
28 | "datasets",
29 | "evaluations",
30 | "hooks",
31 | "models",
32 | "optimal_transport",
33 | "scripts",
34 | "utils",
35 | ] # package names should match these glob patterns (["*"] by default)
36 | exclude = ["tests*"] # exclude packages matching these glob patterns (empty by default)
37 | namespaces = false # to disable scanning PEP 420 namespaces (true by default)
38 |
39 | [tool.pytest.ini_options]
40 | pythonpath = [ # Adds code path to pytest
41 | "."
42 | ]
43 | addopts = "--capture=no" # Don't capture stdout/stderr (default: "auto")
44 |
45 | [tool.isort]
46 | profile = "black" # Sets isort to use Black-compatible formatting
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchmetrics
4 | transformers==4.43.4
5 | diffusers
6 | blobfile
7 | datasets
8 | black
9 | pre-commit
10 | scikit-learn
11 | pandas
12 | wandb
13 | numpy<2 # compatibility with hydra-core
14 | pytest
15 | lm-eval
16 | hydra-core
-------------------------------------------------------------------------------- /tests/__init__.py: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
3 |
-------------------------------------------------------------------------------- /tests/configs/conf_test_interventions.yaml: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
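# Values like ${model_path} are OmegaConf interpolations resolved at load time.
# ${dtype:torch.float32} additionally invokes a custom "dtype" resolver
# (assumed to be registered by the act package) that converts the string into
# a torch dtype object.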
3 |
4 | # Model related stuff
5 | model_path: "sshleifer/tiny-gpt2" # Path to model checkpoint
6 |
7 | # Data and Model Loading Settings
8 | data_dir: "tests/data"
9 | cache_dir: "tests/data"
10 |
11 | module_names: [".*"] # TODO make a defaults config instead of using the one in models/
12 | device: 'cuda'
13 |
14 | seed: 42
15 |
16 | compute_responses: true
17 | learn_interventions: true
18 |
19 | responses:
20 | # Response Generation Settings
21 | batch_size: 4
22 | balanced_data: 1
23 | save_fields: [] # Extra fields to save with responses
24 | shuffle: true
25 | device: "${device}"
26 | dtype: ${dtype:torch.float32}
27 | max_batches: None
28 | model_path: ${model_path}
29 | num_workers: 1
30 | seed: ${seed}
31 | resume: false
32 | # Diffusion-related fields
33 | num_inference_steps: 1
34 | guidance_scale: 0
35 | data_dir: ${data_dir}
36 | cache_dir: ${cache_dir}
37 | tag: "responses"
38 | module_names: ${module_names}
39 | # Subset Args (Specify which subsets to use)
40 | pooling_op: ${interventions.intervention_params.pooling_op}
41 | # see configs/intervention_params
42 | intervention_params:
43 | hook_params:
44 | strength: 1.0
45 | stop_at_first_hook: false
46 | # see configs/task_params
47 | task_params: null
48 | interventions:
49 | # Response Generation Settings
50 | batch_size: 2
51 | shuffle: true
52 | device: "cpu"
53 | dtype: ${dtype:torch.float32}
54 | load_fields: []
55 | max_batches: 1
56 | module_names: ${module_names}
57 | model_path: ${model_path}
58 | num_workers: 1
59 | seed: ${seed}
60 | resume: false
61 | cache_dir: ${cache_dir}
62 | tag: "test"
63 | # see configs/task_params
64 | task_params: null
65 | # see configs/intervention_params
66 | intervention_params:
67 | hook_params: # overridden by defaults
68 | strength: 1.0
69 | stop_at_first_hook: false
70 |
71 | #TODO (below): needed for diffusion and pipeline
72 | # Diffusion Args
73 | # text_to_image:
74 | # diffusion_guidance_scale: 0
75 | # num_inference_steps: 1
76 | # generation_resolution: 224
77 | # LLM-related fields
78 |
79 | # generation:
80 | # # Response Generation Settings
81 | # batch_size: 2
82 | # shuffle: true
83 | # device: "cuda"
84 | # dtype: ${dtype:torch.float32}
85 | # max_batches: None
86 | # module_names: ${module_names}
87 | # num_workers: 1
88 | # seed: 42
89 | # resume: false
90 | # cache_dir: ${cache_dir}
91 | # output_dir: ${cache_dir}/outputs
92 |
93 | # # Subset Args (Specify which subsets to use)
94 | # src_subsets: ${src_subsets}
95 | # dst_subsets: ${dst_subsets}
96 |
97 | # intervention_params: ${intervention_params}
98 | # hook_params:
99 | # strength: 1.0
100 | # dtype: ${dtype:torch.float32}
101 |
102 | # # Diffusion Args
103 | # text_to_image:
104 | # diffusion_guidance_scale: 0
105 | # num_inference_steps: 1
106 | # generation_resolution: 224
107 | # generation_strength: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
108 | # # LLM-related fields
109 | # text_generation:
110 | # seq_len: 128
-------------------------------------------------------------------------------- /tests/configs/hook_config.yaml: --------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
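# module_names entries are regular expressions matched against module names
# (see is_module_name_in_regex in act/utils/utils.py); '.*' hooks every module,
# and a ':0' suffix selects a specific output tensor of a module.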
3 | 4 | description: Jigsaw example 5 | name: model-interventions 6 | parameters: 7 | dataset: "jigsaw" 8 | src_subset: 9 | - toxic 10 | dst_subset: 11 | - non-toxic 12 | num_workers: null 13 | pooling_op: ['mean'] 14 | seed: 42 15 | model_path: "sshleifer/tiny-gpt2" 16 | module_names: ['.*'] 17 | -------------------------------------------------------------------------------- /tests/configs/pipeline_test.yaml: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | description: Jigsaw example 5 | name: model-interventions 6 | parameters: 7 | dataset: "jigsaw" 8 | src_subset: 9 | - toxic 10 | dst_subset: 11 | - non-toxic 12 | num_workers: null 13 | pooling_op: ['mean'] 14 | seed: 42 15 | model_path: sshleifer/tiny-gpt2 16 | module_names: ['transformer.h.0.mlp.c_proj:0', 'transformer.h.1.mlp.c_proj:0'] 17 | tag: "toxicity-responses" 18 | command: 19 | python -m scripts.pipeline -------------------------------------------------------------------------------- /tests/configs/responses_incremental_test.yaml: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | description: "Learns Incremental Gaussian OT per neuron" 5 | name: "M-I Incremental Gaussian OT" 6 | parameters: 7 | # Responses args 8 | batch_size: 128 9 | device: cpu 10 | dtype: bfloat16 11 | max_batches: 40 12 | num_workers: 6 13 | pooling_op: [ 'mean' ] 14 | resume: 1 15 | seed: 42 16 | seq_len: 128 17 | stop_at_first_hook: 0 18 | tag: "test-incremental-responses" 19 | model_path: sshleifer/tiny-gpt2 20 | module_names: ['.*h.*.mlp.c_proj.*'] 21 | rand_weights: 0 22 | dataset: "jigsaw" 23 | src_subset: 24 | - toxic 25 | dst_subset: 26 | - non-toxic 27 | 28 | # intervention_state_path: 29 | intervention_name: "gaussian_ot" 30 | intervention_tag: "test-incremental-gaussian-ot" -------------------------------------------------------------------------------- /tests/configs/responses_test.yaml: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
3 | 4 | description: Jigsaw example 5 | name: model-interventions 6 | parameters: 7 | batch_size: 16 8 | dataset: "jigsaw" 9 | device: cpu 10 | dtype: float32 11 | subset: 12 | - toxic 13 | - non-toxic 14 | max_batches: 1 15 | model_path: sshleifer/tiny-gpt2 16 | num_workers: 0 17 | pooling_op: ['mean'] 18 | resume: 0 19 | seed: 42 20 | seq_len: 32 21 | stop_at_first_hook: 0 22 | tag: "toxicity-responses" 23 | module_names: ['.*h.*.mlp.c_proj.*'] 24 | rand_weights: 0 -------------------------------------------------------------------------------- /tests/data/aura-toxicity-max/tiny-gpt2/transformer.h.0.mlp.c_proj.statedict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/aura-toxicity-max/tiny-gpt2/transformer.h.0.mlp.c_proj.statedict -------------------------------------------------------------------------------- /tests/data/aura-toxicity-max/tiny-gpt2/transformer.h.1.mlp.c_proj.statedict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/aura-toxicity-max/tiny-gpt2/transformer.h.1.mlp.c_proj.statedict -------------------------------------------------------------------------------- /tests/data/coco_captions_2017/captions_train2017.json: -------------------------------------------------------------------------------- 1 | {"info": {"description": "COCO 2017 Dataset", "url": "http://cocodataset.org", "version": "1.0", "year": 2017, "contributor": "COCO Consortium", "date_created": "2017/09/01"}, "licenses": [{"url": "http://creativecommons.org/licenses/by-nc-sa/2.0/", "id": 1, "name": "Attribution-NonCommercial-ShareAlike License"}, {"url": "http://creativecommons.org/licenses/by-nc/2.0/", "id": 2, "name": "Attribution-NonCommercial License"}, {"url": "http://creativecommons.org/licenses/by-nc-nd/2.0/", "id": 3, "name": "Attribution-NonCommercial-NoDerivs License"}, {"url": "http://creativecommons.org/licenses/by/2.0/", "id": 4, "name": "Attribution License"}, {"url": "http://creativecommons.org/licenses/by-sa/2.0/", "id": 5, "name": "Attribution-ShareAlike License"}], "images": [{"license": 3, "file_name": "000000391895.jpg", "coco_url": "http://images.cocodataset.org/train2017/000000391895.jpg", "height": 360, "width": 640, "date_captured": "2013-11-14 11:18:45", "flickr_url": "http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg", "id": 391895}, {"license": 4, "file_name": "000000522418.jpg", "coco_url": "http://images.cocodataset.org/train2017/000000522418.jpg", "height": 480, "width": 640, "date_captured": "2013-11-14 11:38:44", "flickr_url": "http://farm1.staticflickr.com/1/127244861_ab0c0381e7_z.jpg", "id": 522418}, {"license": 3, "file_name": "000000184613.jpg", "coco_url": "http://images.cocodataset.org/train2017/000000184613.jpg", "height": 336, "width": 500, "date_captured": "2013-11-14 12:36:29", "flickr_url": "http://farm3.staticflickr.com/2169/2118578392_1193aa04a0_z.jpg", "id": 184613}, {"license": 3, "file_name": "000000318219.jpg", "coco_url": "http://images.cocodataset.org/train2017/000000318219.jpg", "height": 640, "width": 556, "date_captured": "2013-11-14 13:02:53", "flickr_url": "http://farm5.staticflickr.com/4125/5094763076_813ea2751b_z.jpg", "id": 318219}, {"license": 3, "file_name": "000000554625.jpg", "coco_url": "http://images.cocodataset.org/train2017/000000554625.jpg", "height": 640, "width": 426, 
"date_captured": "2013-11-14 16:03:19", "flickr_url": "http://farm5.staticflickr.com/4086/5094162993_8f59d8a473_z.jpg", "id": 554625}], "annotations": [{"image_id": 203564, "id": 37, "caption": "A bicycle replica with a clock as the front wheel."}, {"image_id": 322141, "id": 49, "caption": "A room with blue walls and a white sink and door."}, {"image_id": 16977, "id": 89, "caption": "A car that seems to be parked illegally behind a legally parked car"}, {"image_id": 106140, "id": 98, "caption": "A large passenger airplane flying through the air."}, {"image_id": 106140, "id": 101, "caption": "There is a GOL plane taking off in a partly cloudy sky."}]} -------------------------------------------------------------------------------- /tests/data/coco_captions_2017/captions_val2017.json: -------------------------------------------------------------------------------- 1 | {"info": {"description": "COCO 2017 Dataset", "url": "http://cocodataset.org", "version": "1.0", "year": 2017, "contributor": "COCO Consortium", "date_created": "2017/09/01"}, "licenses": [{"url": "http://creativecommons.org/licenses/by-nc-sa/2.0/", "id": 1, "name": "Attribution-NonCommercial-ShareAlike License"}, {"url": "http://creativecommons.org/licenses/by-nc/2.0/", "id": 2, "name": "Attribution-NonCommercial License"}, {"url": "http://creativecommons.org/licenses/by-nc-nd/2.0/", "id": 3, "name": "Attribution-NonCommercial-NoDerivs License"}, {"url": "http://creativecommons.org/licenses/by/2.0/", "id": 4, "name": "Attribution License"}, {"url": "http://creativecommons.org/licenses/by-sa/2.0/", "id": 5, "name": "Attribution-ShareAlike License"}, {"url": "http://creativecommons.org/licenses/by-nd/2.0/", "id": 6, "name": "Attribution-NoDerivs License"}, {"url": "http://flickr.com/commons/usage/", "id": 7, "name": "No known copyright restrictions"}, {"url": "http://www.usa.gov/copyright.shtml", "id": 8, "name": "United States Government Work"}], "images": [{"license": 4, "file_name": "000000397133.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg", "height": 427, "width": 640, "date_captured": "2013-11-14 17:02:52", "flickr_url": "http://farm7.staticflickr.com/6116/6255196340_da26cf2c9e_z.jpg", "id": 397133}, {"license": 1, "file_name": "000000037777.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000037777.jpg", "height": 230, "width": 352, "date_captured": "2013-11-14 20:55:31", "flickr_url": "http://farm9.staticflickr.com/8429/7839199426_f6d48aa585_z.jpg", "id": 37777}, {"license": 4, "file_name": "000000252219.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000252219.jpg", "height": 428, "width": 640, "date_captured": "2013-11-14 22:32:02", "flickr_url": "http://farm4.staticflickr.com/3446/3232237447_13d84bd0a1_z.jpg", "id": 252219}, {"license": 1, "file_name": "000000087038.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000087038.jpg", "height": 480, "width": 640, "date_captured": "2013-11-14 23:11:37", "flickr_url": "http://farm8.staticflickr.com/7355/8825114508_b0fa4d7168_z.jpg", "id": 87038}, {"license": 6, "file_name": "000000174482.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000174482.jpg", "height": 388, "width": 640, "date_captured": "2013-11-14 23:16:55", "flickr_url": "http://farm8.staticflickr.com/7020/6478877255_242f741dd1_z.jpg", "id": 174482}, {"license": 4, "file_name": "000000403385.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000403385.jpg", "height": 511, "width": 640, "date_captured": "2013-11-15 00:09:17", 
"flickr_url": "http://farm4.staticflickr.com/3526/3768289025_b29315b582_z.jpg", "id": 403385}, {"license": 4, "file_name": "000000006818.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000006818.jpg", "height": 640, "width": 427, "date_captured": "2013-11-15 01:52:52", "flickr_url": "http://farm3.staticflickr.com/2318/2068039201_b967c69504_z.jpg", "id": 6818}, {"license": 6, "file_name": "000000480985.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000480985.jpg", "height": 500, "width": 375, "date_captured": "2013-11-15 13:09:24", "flickr_url": "http://farm3.staticflickr.com/2336/1634911562_703ff01cff_z.jpg", "id": 480985}, {"license": 4, "file_name": "000000458054.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000458054.jpg", "height": 426, "width": 640, "date_captured": "2013-11-15 13:13:31", "flickr_url": "http://farm9.staticflickr.com/8010/7579121084_7f1d01cd39_z.jpg", "id": 458054}, {"license": 4, "file_name": "000000331352.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000331352.jpg", "height": 500, "width": 351, "date_captured": "2013-11-15 13:55:22", "flickr_url": "http://farm1.staticflickr.com/53/136223761_7764eb56fa_z.jpg", "id": 331352}], "annotations": [{"image_id": 179765, "id": 38, "caption": "A black Honda motorcycle parked in front of a garage."}, {"image_id": 179765, "id": 182, "caption": "A Honda motorcycle parked in a grass driveway"}, {"image_id": 190236, "id": 401, "caption": "An office cubicle with four different types of computers."}, {"image_id": 331352, "id": 441, "caption": "A small closed toilet in a cramped space."}, {"image_id": 517069, "id": 447, "caption": "Two women waiting at a bench next to a street."}, {"image_id": 179765, "id": 479, "caption": "A black Honda motorcycle with a dark burgundy seat."}, {"image_id": 331352, "id": 540, "caption": "A tan toilet and sink combination in a small room."}, {"image_id": 190236, "id": 644, "caption": "The home office space seems to be very cluttered."}, {"image_id": 182417, "id": 856, "caption": "A beautiful dessert waiting to be shared by two people"}, {"image_id": 517069, "id": 882, "caption": "A woman sitting on a bench and a woman standing waiting for the bus."}]} -------------------------------------------------------------------------------- /tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/non-toxic/transformer.h.0.mlp.c_proj:0/mean/1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/non-toxic/transformer.h.0.mlp.c_proj:0/mean/1.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/non-toxic/transformer.h.1.mlp.c_proj:0/mean/1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/non-toxic/transformer.h.1.mlp.c_proj:0/mean/1.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/toxic/transformer.h.0.mlp.c_proj:0/mean/0.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/toxic/transformer.h.0.mlp.c_proj:0/mean/0.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/toxic/transformer.h.1.mlp.c_proj:0/mean/0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses-actadd/tiny-gpt2/act_add/toxic/transformer.h.1.mlp.c_proj:0/mean/0.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/0000997932d777bf.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/0000997932d777bf.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/000bfd0867774845.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/000bfd0867774845.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/000eefc67a2c930f.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/000eefc67a2c930f.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/000ffab30195c5e1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/000ffab30195c5e1.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/0010833a96e1f886.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/0010833a96e1f886.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/00128363e367d703.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/00128363e367d703.pt 
-------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/0015f4aa35ebe9b5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/0015f4aa35ebe9b5.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/001735f961a23fc4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.0.mlp.c_proj:0/mean/001735f961a23fc4.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/0000997932d777bf.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/0000997932d777bf.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/000bfd0867774845.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/000bfd0867774845.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/000eefc67a2c930f.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/000eefc67a2c930f.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/000ffab30195c5e1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/000ffab30195c5e1.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/0010833a96e1f886.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/0010833a96e1f886.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/00128363e367d703.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/00128363e367d703.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/0015f4aa35ebe9b5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/0015f4aa35ebe9b5.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/001735f961a23fc4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/non-toxic/transformer.h.1.mlp.c_proj:0/mean/001735f961a23fc4.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0002bcb3da6cb337.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0002bcb3da6cb337.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0005c987bdfc9d4b.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0005c987bdfc9d4b.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0007e25b2121310b.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0007e25b2121310b.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0020fd96ed3b8c8b.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0020fd96ed3b8c8b.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0028d62e8a5629aa.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0028d62e8a5629aa.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/003217c3eb469ba9.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/003217c3eb469ba9.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0036621e4c7e10b5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/0036621e4c7e10b5.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/00472b8e2d38d1ea.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.0.mlp.c_proj:0/mean/00472b8e2d38d1ea.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0002bcb3da6cb337.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0002bcb3da6cb337.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0005c987bdfc9d4b.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0005c987bdfc9d4b.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0007e25b2121310b.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0007e25b2121310b.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0020fd96ed3b8c8b.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0020fd96ed3b8c8b.pt -------------------------------------------------------------------------------- 
/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0028d62e8a5629aa.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0028d62e8a5629aa.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/003217c3eb469ba9.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/003217c3eb469ba9.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0036621e4c7e10b5.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/0036621e4c7e10b5.pt -------------------------------------------------------------------------------- /tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/00472b8e2d38d1ea.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/toxicity-responses/tiny-gpt2/jigsaw/toxic/transformer.h.1.mlp.c_proj:0/mean/00472b8e2d38d1ea.pt -------------------------------------------------------------------------------- /tests/data/whispx-test-max/tiny-gpt2/transformer.h.0.mlp.c_proj.statedict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/whispx-test-max/tiny-gpt2/transformer.h.0.mlp.c_proj.statedict -------------------------------------------------------------------------------- /tests/data/whispx-test-max/tiny-gpt2/transformer.h.1.mlp.c_proj.statedict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-act/d2c3560b7022da795d58f892c398ab77cff13590/tests/data/whispx-test-max/tiny-gpt2/transformer.h.1.mlp.c_proj.statedict -------------------------------------------------------------------------------- /tests/test_0shot.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
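# End-to-end checks for the zero-shot LLM-judge evaluation: small CSVs of sentences are scored against yes/no and multiple-choice system prompts, and the judge's answers come back as q{i}_llm_answer columns.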
3 | 4 | import tempfile 5 | from pathlib import Path 6 | 7 | import pandas as pd 8 | import pytest 9 | 10 | from act.evaluations import evaluate_0shot 11 | 12 | SYSTEM_PROMPT = [ 13 | "You are a chatbot that tells if a sentence is about fantasy.", 14 | "You are a chatbot that tells if a sentence is about football.", 15 | ] 16 | ANSWERS = ["no", "yes"] 17 | DATA = { 18 | "sentence": [ 19 | "Hello, my name is John and I play a sport with a ball.", 20 | "Unicorns are horses with a horn and some magic.", 21 | ], 22 | "prompt": [ 23 | "Can you introduce yourself?", 24 | "Dragons are like big lizards and breathe fire.", 25 | ], 26 | "sentence2": [ 27 | "Messi is the best player.", 28 | "I once saw a strange insect.", 29 | ], 30 | } 31 | 32 | SYSTEM_PROMPT_MC = [ 33 | "Given the following 2 sentences, select the one that talks more about fantasy.", 34 | "Given the following 2 sentences, select the one that talks more about football.", 35 | ] 36 | ANSWERS_MC = ["A", "B"] 37 | DATA_MC = { 38 | "sentence": [ 39 | "Hello, my name is John and I play a sport with a ball.", 40 | "Unicorns are horses with a horn and some magic.", 41 | ], 42 | "prompt": [ 43 | "Can you introduce yourself?", 44 | "Dragons are like big lizards and breathe fire.", 45 | ], 46 | "sentence2": [ 47 | "Messi is the best player.", 48 | "I once saw a strange insect.", 49 | ], 50 | } 51 | 52 | 53 | @pytest.mark.skip(reason="Uses Llama-3-8B-instruct, too large.") 54 | @pytest.mark.parametrize( 55 | "system_prompt,answers,data,use_second_csv,prepend_answers", 56 | [ 57 | (SYSTEM_PROMPT, ANSWERS, DATA, False, 0), 58 | (SYSTEM_PROMPT_MC, ANSWERS_MC, DATA_MC, False, 0), 59 | (SYSTEM_PROMPT_MC, ANSWERS_MC, DATA_MC, True, 0), 60 | (SYSTEM_PROMPT_MC, ANSWERS_MC, DATA_MC, False, 1), 61 | (SYSTEM_PROMPT_MC, ANSWERS_MC, DATA_MC, True, 1), 62 | ], 63 | ) 64 | def test_0shot_e2e(system_prompt, answers, data, use_second_csv, prepend_answers): 65 | with tempfile.TemporaryDirectory(dir="/tmp/") as tempfolder: 66 | csv_file = Path(tempfolder) / "test.csv" 67 | out_file = Path(tempfolder) / "out.csv" 68 | df = pd.DataFrame(data=data) 69 | df.to_csv(csv_file) 70 | 71 | second_csv_argv = [] 72 | if use_second_csv: 73 | csv_file2 = Path(tempfolder) / "test2.csv" # a distinct file for --data-path2 74 | df.to_csv(csv_file2) 75 | second_csv_argv = [ 76 | "--data-path2", 77 | str(csv_file2), 78 | ] 79 | 80 | parser = evaluate_0shot.get_parser() 81 | args = parser.parse_args( 82 | [ 83 | "--device", 84 | "cpu", 85 | "--system-prompt", 86 | *system_prompt, 87 | "--system-answers", 88 | *answers, 89 | "--prepend-answers", 90 | str(prepend_answers), 91 | "--col-sentence1", 92 | "sentence", 93 | "--data-path", 94 | str(csv_file), 95 | "--output-file", 96 | str(out_file), 97 | *second_csv_argv, 98 | ] 99 | ) 100 | evaluate_0shot.main(args) 101 | df_out = pd.read_csv(out_file) 102 | 103 | assert len(df_out) == 2 104 | assert "q0_llm_answer" in df_out.columns 105 | assert "q1_llm_answer" in df_out.columns 106 | assert df_out["q0_llm_answer"].values[1] in ["yes", "A"] 107 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
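# Dataset and dataloader tests built around the small jigsaw toxicity fixture in tests/data: dataset construction, batching, balanced sampling across subsets, and the COCO concept prompts.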
3 | 4 | import os 5 | from collections import Counter 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | from transformers import AutoTokenizer 12 | 13 | from act.datasets import get_dataloader, get_dataset 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def tokenizer(): 18 | tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2") 19 | tokenizer.pad_token = tokenizer.eos_token 20 | return tokenizer 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def dummy_data(tokenizer): 25 | dataset, collator = get_dataset( 26 | "jigsaw", 27 | Path("tests/data/"), 28 | split="train", 29 | subsets=["toxic", "non-toxic"], 30 | tokenizer=tokenizer, 31 | ) 32 | return {"dataset": dataset, "collator": collator} 33 | 34 | 35 | @pytest.fixture(scope="session") 36 | def toxic_data(tokenizer): 37 | dataset, collator = get_dataset( 38 | "jigsaw", 39 | Path("tests/data/"), 40 | split="train", 41 | subsets=["toxic"], 42 | tokenizer=tokenizer, 43 | ) 44 | return {"dataset": dataset, "collator": collator} 45 | 46 | 47 | @pytest.fixture(scope="session") 48 | def coco_concepts(tokenizer): 49 | dataset, collator = get_dataset( 50 | "coco-captions-concepts", 51 | Path("tests/data"), 52 | split="train", 53 | subsets=["pink_elephant", "none"], 54 | tokenizer=None, 55 | ) 56 | return {"dataset": dataset, "collator": collator} 57 | 58 | 59 | def test_get_dataset(dummy_data): 60 | assert ( 61 | dummy_data["dataset"] is not None 62 | ) # assuming non-empty datasets for simplicity 63 | 64 | 65 | def test_get_dataloader(dummy_data): 66 | dataloader = get_dataloader( 67 | dummy_data["dataset"], 68 | batch_size=2, 69 | num_workers=0, 70 | collate_fn=dummy_data["collator"], 71 | drop_last=True, 72 | shuffle=False, 73 | ) 74 | 75 | # check if the dataloader is iterable and returns correct batches 76 | for i, batch in enumerate(dataloader): 77 | assert len(batch["input_ids"]) == 2 # assuming a batch size of 2 78 | 79 | 80 | def test_get_dataloader_balanced(dummy_data, toxic_data): 81 | dataloader = get_dataloader( 82 | dummy_data["dataset"], 83 | batch_size=2, 84 | num_workers=0, 85 | collate_fn=dummy_data["collator"], 86 | drop_last=True, 87 | shuffle=False, 88 | balanced=True, 89 | seed=42, # A fixed seed for reproducibility 90 | ) 91 | batch = next(iter(dataloader)) 92 | subsets = batch["subset"] 93 | assert "toxic" in subsets and "non-toxic" in subsets 94 | 95 | dataloader = get_dataloader( 96 | toxic_data["dataset"], 97 | batch_size=2, 98 | num_workers=0, 99 | collate_fn=dummy_data["collator"], 100 | drop_last=True, 101 | shuffle=False, 102 | balanced=True, 103 | seed=42, # A fixed seed for reproducibility 104 | ) 105 | batch = next(iter(dataloader)) 106 | subsets = batch["subset"] 107 | assert "toxic" in subsets and "non-toxic" not in subsets 108 | 109 | 110 | def test_coco_concepts(coco_concepts): 111 | dataset = coco_concepts["dataset"] 112 | assert "pink_elephant" in dataset[0]["subset"] 113 | assert "pink elephant" in dataset[0]["prompt"] 114 | assert "pink_elephant" not in dataset[1]["subset"] 115 | assert "pink elephant" not in dataset[1]["prompt"] 116 | dataset[len(dataset) - 1] # smoke check: indexing the last item must not raise 117 | -------------------------------------------------------------------------------- /tests/test_interventions.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
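# Learns each intervention type on precomputed tiny-gpt2 responses, then restores a fresh hook from the saved state dict and checks that the forward transport is finite and that constructor arguments behave as expected.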
3 | 4 | import tempfile 5 | 6 | import numpy as np 7 | import pytest 8 | import torch 9 | from hydra import compose, initialize 10 | 11 | from act.hooks import get_hook 12 | from act.hooks.transport import GaussianOTHook 13 | from act.scripts.learn_intervention import ( 14 | InterventionsManager, 15 | learn_intervention, 16 | ) 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "intervention_name", 21 | [ 22 | ("none"), 23 | ("aura"), 24 | ("mean_ot"), 25 | ("gaussian_ot"), 26 | ("linear_ot"), 27 | ], 28 | ) 29 | def test_hook_with_hook_args(intervention_name): 30 | # Assuming that the main function doesn't have any side effects and returns None when successful 31 | with tempfile.TemporaryDirectory(dir="/tmp/") as tempfolder: 32 | with initialize(version_base=None, config_path="../act/configs"): 33 | # config is relative to a module 34 | cfg = compose( 35 | config_name="text_generation", 36 | overrides=[ 37 | "device=cpu", 38 | "model.model_path=sshleifer/tiny-gpt2", 39 | "model.module_names=['transformer.h.0.mlp.c_proj:0', 'transformer.h.1.mlp.c_proj:0']", 40 | f"responses.tag=toxicity-responses", 41 | f"intervention_params.name={intervention_name}", 42 | f"intervention_params.incremental=atonce", 43 | "intervention_params.hook_params.quantiles_src=q_all", 44 | "data_dir=tests/data", 45 | "cache_dir=tests/data", 46 | f"interventions.cache_dir={tempfolder}", 47 | "compute_responses=false", 48 | "wandb.mode=disabled", 49 | ], 50 | ) 51 | 52 | # interventions_manager = InterventionsManager( 53 | # ResponsesManager.get_output_path(cfg.responses), cfg.interventions 54 | # ) 55 | # # This call also tests the "fit()" method 56 | # interventions_manager.learn_intervention_all() 57 | learn_intervention(cfg) 58 | 59 | # Now create a new hook and load its state using the one learnt through "fit()" 60 | hook = get_hook( 61 | intervention_name, 62 | module_name="transformer.h.1.mlp.c_proj:0", 63 | device="cpu", 64 | std_eps=1e-7, # <-- This one is not in statedict, will be updated. 
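                # strength below is set in the constructor; an assert near the end of this test checks that it survives from_state_path().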
65 | strength=0.9, 66 | ) 67 | state_path = ( 68 | InterventionsManager.get_output_path(cfg.interventions) 69 | / "transformer.h.1.mlp.c_proj:0.statedict" 70 | ) 71 | hook.from_state_path(state_path) 72 | 73 | # Also testing the intervention forward 74 | zs = torch.randn(32, 3, 2) # mu=0, std=1 75 | zs_post = hook(None, None, zs) 76 | assert torch.isnan(zs_post).sum() == 0 77 | assert torch.isinf(zs_post).sum() == 0 78 | 79 | if intervention_name == "none": 80 | assert torch.allclose(zs, zs_post) 81 | return 82 | 83 | # Check that values in statedict override those in constructor 84 | if intervention_name == "gaussian_ot": 85 | assert hook.onlymean is False 86 | elif intervention_name == "mean_ot": 87 | assert hook.onlymean is True 88 | if hasattr(hook, "std_eps"): # std_eps is not stored in the statedict, so the constructor value persists 89 | assert hook.std_eps == 1e-7 90 | 91 | assert hook.strength == 0.9 92 | 93 | 94 | def test_gaussian_function(): 95 | b, d = 1000, 5 96 | mu1_gt = 1.0 97 | std1_gt = 1.0 98 | mu2_gt = 5.0 99 | std2_gt = 0.5 100 | zs = torch.randn(b, d) * std1_gt + mu1_gt # mu=1, std=1 101 | zd = torch.randn(b, d) * std2_gt + mu2_gt # mu=5.0, std=0.5 102 | 103 | # Test in forward mode 104 | hook = GaussianOTHook( 105 | module_name="test", 106 | dtype=torch.float32, 107 | intervention_position="all", 108 | ) 109 | 110 | labels = torch.cat([torch.ones([b]), torch.zeros([b])]).to(torch.int64) 111 | hook.fit( 112 | responses=torch.cat([zs, zd], 0), 113 | labels=labels, 114 | ) 115 | 116 | d = hook.state_dict() 117 | assert np.allclose(d["mu1"], mu1_gt, atol=0.1), f"mean1: {d['mu1']} !!" 118 | assert np.allclose(d["mu2"], mu2_gt, atol=0.1), f"mean2: {d['mu2']} !!" 119 | assert np.allclose(d["std1"], std1_gt, atol=0.1), f"std1: {d['std1']} !!" 120 | assert np.allclose(d["std2"], std2_gt, atol=0.1), f"std2: {d['std2']} !!" 121 | print(d["quantiles_dict_src"]) 122 | print(zs.shape) 123 | z_transport = hook(None, None, zs) 124 | assert np.allclose(z_transport.mean(0), mu2_gt, atol=0.1), f"transported mean: {z_transport.mean(0)} !!" 125 | assert np.allclose(z_transport.std(0), std2_gt, atol=0.1), f"transported std: {z_transport.std(0)} !!" 126 | 127 | hook2 = GaussianOTHook( 128 | module_name="test", 129 | dtype=torch.float32, 130 | intervention_position="all", 131 | ) 132 | hook2.load_state_dict(d) 133 | d = hook2.state_dict() 134 | assert np.allclose(d["mu1"], 1.0, atol=0.1), f"mean1: {d['mu1']} !!" 135 | assert np.allclose(d["mu2"], 5.0, atol=0.1), f"mean2: {d['mu2']} !!" 136 | assert np.allclose(d["std1"], 1.0, atol=0.1), f"std1: {d['std1']} !!" 137 | assert np.allclose(d["std2"], 0.5, atol=0.1), f"std2: {d['std2']} !!" 138 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | 7 | from act.models import get_model 8 | from act.models.model_with_hooks import get_model_with_hooks, is_module_name_in_regex 9 | 10 | 11 | # Define a dummy hook 12 | class DummyHook: 13 | def __init__(self, module_name=""): 14 | self.module_name = module_name 15 | 16 | def update(self): 17 | ...
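    # __call__ has the signature of a torch forward hook (module, input, output); it records a sentinel output so the registration tests can observe it.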
18 | 19 | def __call__(self, module, input, output): 20 | self.outputs = {"": "dummy"} 21 | 22 | 23 | def test_init(): 24 | model = torch.nn.Linear(5, 2) 25 | mwh = get_model_with_hooks(model, model_task="text-generation") 26 | assert mwh.get_module() == model 27 | assert len(mwh._forward_hook_handles) == 0 28 | 29 | 30 | def test_register_hooks(): 31 | model = torch.nn.Linear(5, 2) 32 | mwh = get_model_with_hooks(model, model_task="text-generation") 33 | 34 | hook = DummyHook() 35 | mwh.register_hooks([hook]) 36 | 37 | # Test if the hook is registered correctly and outputs are as expected 38 | assert len(mwh._forward_hook_handles) == 1 39 | x = torch.randn((2, 5)) 40 | _ = model(x) 41 | mwh.update_hooks() 42 | assert "dummy" in mwh.get_hook_outputs()[""] 43 | 44 | 45 | def test_remove_hooks(): 46 | model = torch.nn.Linear(5, 2) 47 | mwh = get_model_with_hooks(model, model_task="text-generation") 48 | 49 | hook = DummyHook() 50 | mwh.register_hooks([hook]) 51 | 52 | # Remove the hook and check if it's empty 53 | mwh.remove_hooks() 54 | assert len(mwh._forward_hook_handles) == 0 55 | assert len(mwh.get_hook_outputs()) == 0 56 | 57 | 58 | def test_find_module_names(): 59 | model, tokenizer = get_model( 60 | "sshleifer/tiny-gpt2", 61 | rand_weights=True, 62 | cache_dir="/tmp/cache", 63 | device="cpu", 64 | dtype="float32", 65 | model_task="text-generation", 66 | ) 67 | mwh = get_model_with_hooks(model, model_task="text-generation") 68 | 69 | # Test if the method can find modules correctly 70 | module_names = mwh.find_module_names(mwh.get_module(), [".*0.mlp.*"]) 71 | print(mwh.get_module()) 72 | assert len(module_names) == 5 73 | 74 | # Also checking with ".+" instead of ".*": the regex then requires at least one trailing character, omitting the module that ends with "mlp" itself 75 | module_names = mwh.find_module_names(mwh.get_module(), [".*0.mlp.+"]) 76 | assert len(module_names) == 4 77 | 78 | 79 | def test_get_target_module_names(): 80 | model, tokenizer = get_model( 81 | "sshleifer/tiny-gpt2", 82 | rand_weights=True, 83 | cache_dir="/tmp/cache", 84 | device="cpu", 85 | dtype="float32", 86 | model_task="text-generation", 87 | ) 88 | mwh = get_model_with_hooks(model, model_task="text-generation") 89 | 90 | hook = DummyHook("transformer.h.0.mlp.c_fc") 91 | mwh.register_hooks([hook]) 92 | 93 | # Test if the method can get target modules correctly 94 | target_module_names = mwh.get_target_module_names() 95 | assert "transformer.h.0.mlp.c_fc" in target_module_names 96 | 97 | 98 | def test_stable_diffusion(): 99 | pipe = StableDiffusionPipeline.from_pretrained( 100 | "hf-internal-testing/tiny-stable-diffusion-pipe" 101 | ) 102 | mwh = get_model_with_hooks(pipe, model_task="text-to-image-generation") 103 | assert len(mwh.find_module_names(mwh.module, ["vae.*"])) > 0 104 | assert len(mwh.find_module_names(mwh.module, ["unet.*"])) > 0 105 | assert len(mwh.find_module_names(mwh.module, ["text_encoder.*"])) > 0 106 | 107 | 108 | def test_is_module_name_in_regex(): 109 | # Test case where module name matches one or more regex expressions in the list 110 | assert len(is_module_name_in_regex("foo.bar", [".*", ""])) > 0 111 | assert len(is_module_name_in_regex("foo.bar", ["f.*", "b.*.r"])) > 0 112 | assert ( 113 | len(is_module_name_in_regex("hello.world", ["h.*.d", "he?lo.*", "wo?l?.r"])) > 0 114 | ) 115 | 116 | # Test regex with `:` 117 | assert is_module_name_in_regex("foo.bar:0", [".*bar.*"]) == ["foo.bar:0"] 118 | assert is_module_name_in_regex("foo.bar:1", [".*bar.*"]) == ["foo.bar:1"] 119 | assert is_module_name_in_regex("foo.bar:0",
[".*bar"]) == ["foo.bar:0"] 120 | assert is_module_name_in_regex("foo.bar:1", [".*bar"]) == ["foo.bar:1"] 121 | assert is_module_name_in_regex("foo.bar:0", [".*bar:0"]) == ["foo.bar:0"] 122 | assert is_module_name_in_regex("foo.bar:1", [".*bar:0"]) == [] 123 | assert is_module_name_in_regex("foo.bar:1", [".*bar:0", ".*bar:1"]) == ["foo.bar:1"] 124 | assert is_module_name_in_regex("foo.bar:1", [".*bar:0", ".*bar"]) == ["foo.bar:1"] 125 | assert is_module_name_in_regex("foo.bar", [".*bar:0"]) == ["foo.bar:0"] 126 | assert is_module_name_in_regex("foo.bar", [".*bar:1"]) == ["foo.bar:1"] 127 | assert is_module_name_in_regex("foo.bar", [".*bar"]) == ["foo.bar"] 128 | assert is_module_name_in_regex("foo.bar", [".*bar.*"]) == ["foo.bar"] 129 | 130 | # Test case where module name does not match any regex expressions in the list 131 | # assert is_module_name_in_regex("foo.bar", ["f.*o"]) is None 132 | assert len(is_module_name_in_regex("foo.bar", ["f.*o", "b?.r"])) == 0 133 | assert ( 134 | len(is_module_name_in_regex("hello.world", ["h.*.x", "he?l?.z", "wo?l?.y"])) 135 | == 0 136 | ) 137 | 138 | # Test case where the list of regex expressions is empty 139 | assert len(is_module_name_in_regex("foo.bar", [])) == 0 140 | -------------------------------------------------------------------------------- /tests/test_perplexity.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import logging 5 | import tempfile 6 | from pathlib import Path 7 | 8 | import pandas as pd 9 | import pytest 10 | from hydra import compose, initialize 11 | 12 | # Importing learn_intervention to resolve dtype in hydra 13 | from act.evaluations import evaluate_perplexity 14 | 15 | logger = logging.getLogger("TEST(compute responses)") 16 | logger.setLevel(logging.DEBUG) 17 | 18 | 19 | def input_csv(): 20 | data = { 21 | "prompt": [ 22 | "My name is Alice and I play", 23 | "Once upon a time", 24 | "I like music", 25 | "Arm tree yellow orthogonal", 26 | ], 27 | "sentence": [ 28 | " football at school.", 29 | " there was a Hobbit.", 30 | " and dancing.", 31 | " great pull car.", 32 | ], 33 | } 34 | df = pd.DataFrame(data=data) 35 | return df 36 | 37 | 38 | @pytest.mark.parametrize( 39 | "with_prompt", 40 | [ 41 | "with_prompt", 42 | "without_prompt", 43 | ], 44 | ) 45 | def test_evaluate_perplexity(with_prompt): 46 | # Assuming that the main function doesn't have any side effects and returns None when successful 47 | with tempfile.TemporaryDirectory(dir="/tmp/") as tempfolder: 48 | tmpfile = Path(tempfolder) / "ppl.csv" 49 | df = input_csv() 50 | print(input_csv) 51 | df.to_csv(tmpfile) 52 | 53 | column_sentences = ( 54 | ["sentence", "prompt"] 55 | if with_prompt == "with_prompt" 56 | else [ 57 | "sentence", 58 | ] 59 | ) 60 | 61 | with initialize(version_base=None, config_path="../act/configs"): 62 | # config is relative to a module 63 | cfg = compose( 64 | config_name="text_generation", 65 | overrides=[ 66 | "fast=true", 67 | "device=cpu", 68 | f"results_dir={tempfolder}", 69 | "model_perplexity.perplexity_model_path=EleutherAI/pythia-70m", 70 | f"model_perplexity.data_path={tmpfile}", 71 | f"model_perplexity.column_sentences={column_sentences}", 72 | "wandb.mode=disabled", 73 | ], 74 | ) 75 | evaluate_perplexity.evaluate(cfg.model_perplexity) 76 | 77 | df = pd.read_csv( 78 | Path(tempfolder) / "evaluate_perplexity" / "model_perplexity.csv" 79 | ) 80 | assert f"ppl_pythia-70m" in 
df.columns 81 | assert df[f"ppl_pythia-70m"].argmax() == 3 82 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import tempfile 5 | 6 | import pytest 7 | import torch 8 | from hydra import compose, initialize 9 | 10 | from act.scripts import pipeline 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "config_name,intervention_name,tasks", 15 | [ 16 | ("text_generation", "none", ["rtp"]), 17 | # ("text_generation", "none", ["text-generation"]), 18 | # ("text_generation", "none", ["text-generation", "model_perplexity"]), 19 | # pytest.param( 20 | # "text_generation", 21 | # "linear_ot", 22 | # ["text-generation", "zero_shot"], 23 | # marks=pytest.mark.slow, 24 | # ), 25 | # pytest.param("text_generation", "linear_ot", ["mmlu"], marks=pytest.mark.slow), 26 | # ("text_generation", "linear_ot", ["rtp"]), 27 | # ("text_generation", "linear_ot", ["text-generation"]), 28 | # ("text_generation", "linear_ot", ["text-generation", "model_perplexity"]), 29 | # pytest.param( 30 | # "text_generation", 31 | # "linear_ot", 32 | # ["text-generation", "zero_shot"], 33 | # marks=pytest.mark.slow, 34 | # ), 35 | # pytest.param("text_generation", "linear_ot", ["mmlu"], marks=pytest.mark.slow), 36 | ], 37 | ) 38 | def test_pipeline_main(config_name, intervention_name, tasks): 39 | device = "cuda" if torch.cuda.is_available() else "cpu" 40 | # Assuming that the main function doesn't have any side effects and returns None when successful 41 | with tempfile.TemporaryDirectory(dir="/tmp/") as tempfolder: 42 | with initialize(version_base=None, config_path="../act/configs"): 43 | cfg = compose( 44 | config_name=config_name, 45 | overrides=[ 46 | "fast=true", 47 | f"device={device}", 48 | "model.model_path=sshleifer/tiny-gpt2", 49 | "model.module_names=['transformer.h.0.mlp.c_proj:0', 'transformer.h.1.mlp.c_proj:0']", 50 | f"evaluation={tasks}", 51 | "responses.tag=toxicity-responses", 52 | f"intervention_params.name={intervention_name}", 53 | "intervention_params.hook_params.quantiles_src=q_all", 54 | "intervention_params.incremental=atonce", 55 | "data_dir=tests/data", 56 | "cache_dir=tests/data", 57 | f"interventions.cache_dir={tempfolder}", 58 | "compute_responses=false", 59 | f"results_dir={tempfolder}", 60 | "wandb.mode=disabled", 61 | ], 62 | ) 63 | pipeline.main(cfg) 64 | 65 | 66 | @pytest.mark.slow 67 | def test_pipeline_diffusion(): 68 | device = "cuda" if torch.cuda.is_available() else "cpu" 69 | # Assuming that the main function doesn't have any side effects and returns None when successful 70 | with tempfile.TemporaryDirectory(dir="/tmp/") as tempfolder: 71 | with initialize(version_base=None, config_path="../act/configs"): 72 | cfg = compose( 73 | config_name="text_to_image_generation", 74 | overrides=[ 75 | "fast=true", 76 | "responses.batch_size=2", 77 | "responses.max_batches=1", 78 | "interventions.batch_size=2", 79 | "interventions.max_batches=null", 80 | "text_to_image_generation.batch_size=1", 81 | "text_to_image_generation.max_batches=1", 82 | "text_to_image_generation.strength=[0.0, 1.0]", 83 | f"device={device}", 84 | # "model.model_path='hf-internal-testing/tiny-stable-diffusion-pipe'", 85 | "model.module_names=['unet.down_blocks.0.resnets.0.norm1']", 86 | f"evaluation=['text-to-image-generation', 'clip_score']", 87 | 
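                    # Unlike the text-generation tests above, this run computes responses from scratch (compute_responses=true below) and caches them in the temp dir.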
"responses.tag=diffusion-responses", 88 | f"intervention_params.name=linear_ot", 89 | "intervention_params.hook_params.quantiles_src=q_all", 90 | "intervention_params.incremental=incr", 91 | "data_dir=tests/data", 92 | "cache_dir=tests/data", 93 | f"responses.cache_dir={tempfolder}", 94 | f"interventions.cache_dir={tempfolder}", 95 | "compute_responses=true", 96 | f"results_dir={tempfolder}", 97 | "wandb.mode=disabled", 98 | "text_to_image_generation.num_inference_steps=1", 99 | ], 100 | ) 101 | pipeline.main(cfg) 102 | -------------------------------------------------------------------------------- /tests/test_responses.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import logging 5 | import tempfile 6 | import typing as t 7 | from pathlib import Path 8 | 9 | import pytest 10 | import torch 11 | from hydra import compose, initialize 12 | 13 | from act.scripts.learn_intervention import ResponsesManager 14 | 15 | logger = logging.getLogger("TEST(compute responses)") 16 | logger.setLevel(logging.DEBUG) 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "max_batches,batch_size", 21 | [ 22 | (2, 4), 23 | (1, 4), 24 | (4, 1), 25 | (3, 2), 26 | ], 27 | ) 28 | def test_compute_responses(max_batches, batch_size): 29 | # Assuming that the main function doesn't have any side effects and returns None when successful 30 | with tempfile.TemporaryDirectory(dir="/tmp/") as tempfolder: 31 | with initialize(version_base=None, config_path="../act/configs"): 32 | # config is relative to a module 33 | cfg = compose( 34 | config_name="text_generation", 35 | overrides=[ 36 | "device=cpu", 37 | f"responses.batch_size={batch_size}", 38 | f"responses.max_batches={max_batches}", 39 | "model.model_path=sshleifer/tiny-gpt2", 40 | "model.module_names=['transformer.h.0.mlp.c_proj:0', 'transformer.h.1.mlp.c_proj:0']", 41 | "responses.tag=toxicity-responses", 42 | "data_dir=tests/data", 43 | f"cache_dir={tempfolder}", 44 | "compute_responses=true", 45 | "wandb.mode=disabled", 46 | ], 47 | ) 48 | rm = ResponsesManager(cfg.responses) 49 | responses_path = rm.compute_responses() 50 | assert responses_path.exists() 51 | 52 | def match_subdirs(root: Path, expected: t.Set) -> bool: 53 | subdirs = list(root.glob("[!.]*")) # skip hidden files 54 | return expected == {elem.name for elem in subdirs} 55 | 56 | assert match_subdirs(responses_path, {"toxic", "non-toxic"}) 57 | assert match_subdirs( 58 | responses_path / "toxic", 59 | {"transformer.h.0.mlp.c_proj:0", "transformer.h.1.mlp.c_proj:0"}, 60 | ) 61 | assert match_subdirs( 62 | responses_path / "non-toxic", 63 | {"transformer.h.0.mlp.c_proj:0", "transformer.h.1.mlp.c_proj:0"}, 64 | ) 65 | 66 | pooling_op = cfg.responses.intervention_params.pooling_op 67 | batches = list( 68 | ( 69 | responses_path / "toxic" / f"transformer.h.1.mlp.c_proj:0" / pooling_op 70 | ).glob("*.pt") 71 | ) 72 | assert len(batches) == max_batches * batch_size / 2 73 | batch_data = torch.load(batches[0]) 74 | assert "id" in batch_data 75 | assert batch_data["responses"].numel() == 2 # tiny-gpt2 has 2 neurons 76 | -------------------------------------------------------------------------------- /tests/test_responses_io.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
3 | 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pytest 8 | 9 | from act.datasets.responses_io import ResponsesLoader 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def responses_loader(): 14 | loader = ResponsesLoader( 15 | Path("./tests/data/toxicity-responses/tiny-gpt2/jigsaw"), 16 | from_folders=[ 17 | Path("*/transformer*/mean"), 18 | ], 19 | ) 20 | return loader 21 | 22 | 23 | def test_get_attribute_values(responses_loader): 24 | attribute_name = "module_names" 25 | values = responses_loader.get_attribute_values(attribute_name) 26 | 27 | # Check if the returned set is not empty 28 | assert len(values) > 0, f"No values found for attribute {attribute_name}" 29 | 30 | 31 | def test_load_data_subset(responses_loader): 32 | filter = {"pooling_op": ["mean"]} 33 | data = responses_loader.load_data_subset(filter) 34 | 35 | # Check if the returned dictionary is not empty 36 | assert len(data) > 0, "No data loaded" 37 | 38 | filter = {"module_names": ["transformer.h.0.mlp.c_proj:0"]} 39 | data = responses_loader.load_data_subset(filter) 40 | assert set(data["subset"]) == set(["toxic", "non-toxic"]) 41 | assert set(data["module_names"]) == set(["transformer.h.0.mlp.c_proj:0"]) 42 | 43 | filter = {"module_names": ["transformer.h.0.mlp.c_proj:0"], "subset": ["non-toxic"]} 44 | data = responses_loader.load_data_subset(filter) 45 | assert set(data["subset"]) == set(["non-toxic"]) 46 | assert set(data["module_names"]) == set(["transformer.h.0.mlp.c_proj:0"]) 47 | 48 | # load_data_subset must reject unknown filter keys. 49 | with pytest.raises(Exception): 50 | responses_loader.load_data_subset({"APPLE": ["PIE"]}) 51 | 52 | 53 | 54 | 55 | 56 | def test_responses_loader(responses_loader): 57 | data_subset = { 58 | "responses": np.arange(10).reshape((10, 1)), 59 | "subset": np.asarray(["A"] * 5 + ["B"] * 4 + ["C"]), 60 | } 61 | labeled_data = ResponsesLoader.label_src_dst_subsets( 62 | data_subset, 63 | src_subset=["A", "B"], 64 | dst_subset=["B"], 65 | key="subset", 66 | balanced=False, 67 | seed=0, 68 | ) 69 | labels = labeled_data["label"] 70 | src_data = {k: v[labels == 1] for k, v in labeled_data.items()} 71 | dst_data = {k: v[labels == 0] for k, v in labeled_data.items()} 72 | assert "D" not in set(src_data["subset"]) and "D" not in set(dst_data["subset"]) 73 | assert len([s for s in src_data["subset"] if s == "A"]) == 5 74 | assert len([s for s in dst_data["subset"] if s == "A"]) == 0 75 | assert len([s for s in src_data["subset"] if s == "B"]) == 2 76 | assert len([s for s in dst_data["subset"] if s == "B"]) == 2 77 | labeled_data = ResponsesLoader.label_src_dst_subsets( 78 | data_subset, 79 | src_subset=["A", "B"], 80 | dst_subset=["B"], 81 | key="subset", 82 | balanced=True, 83 | seed=0, 84 | ) 85 | labels = labeled_data["label"] 86 | src_data = {k: v[labels == 1] for k, v in labeled_data.items()} 87 | dst_data = {k: v[labels == 0] for k, v in labeled_data.items()} 88 | assert len([s for s in src_data["subset"] if s == "A"]) == 1 89 | assert len([s for s in dst_data["subset"] if s == "A"]) == 0 90 | assert len([s for s in src_data["subset"] if s == "B"]) == 1 91 | assert len([s for s in dst_data["subset"] if s == "B"]) == 2 92 | --------------------------------------------------------------------------------