├── .gitignore ├── LICENSE ├── README.md ├── assets ├── fig_1.gif └── fig_2.gif ├── cluster_scripts ├── pbs │ └── pruning │ │ ├── helper.sh │ │ └── sd2-1_cc3m.pbs └── slurm │ ├── filtering │ ├── filter_coco.slurm │ └── fliter_cc3m.slurm │ ├── finetuning │ ├── sd2-1_cc3m.slurm │ └── sd2-1_coco.slurm │ ├── img_generation │ └── sd2-1_cc3m.slurm │ └── pruning │ ├── sd2-1_cc3m.slurm │ └── sd2-1_coco.slurm ├── cmmd-pytorch ├── cmmd_utils.py ├── compute_cmmd.py ├── distance.py ├── embedding.py ├── generate_images.py ├── io_util.py ├── requirements.txt └── save_refs.py ├── configs ├── baselines │ ├── finetuning │ │ └── sd-2-1_cc3m_magnitude.yaml │ ├── img_generation │ │ └── sd-2-1_cc3m_magnitude.yaml │ └── pruning │ │ ├── sd-2-1_coco2014_ot.yaml │ │ ├── sd-2-1_coco2014_single_arch.yaml │ │ └── sd-2-1_coco2014_without_ot.yaml ├── filtering │ ├── sd-2-1_cc3m.yaml │ └── sd-2-1_coco2014.yaml ├── finetuning │ ├── sd-2-1_cc3m.yaml │ └── sd-2-1_coco2014.yaml ├── img_generation │ └── sd-2-1_cc3m.yaml └── pruning │ ├── sd-2-1_cc3m.yaml │ └── sd-2-1_coco2014.yaml ├── env.yaml ├── pdm ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cc3m.py │ └── coco.py ├── losses │ ├── __init__.py │ ├── contrastive_loss.py │ └── resource_loss.py ├── models │ ├── __init__.py │ ├── hypernet │ │ ├── __init__.py │ │ └── hypernet.py │ ├── unet │ │ ├── __init__.py │ │ ├── blocks.py │ │ ├── gates.py │ │ └── unet_2d_conditional.py │ └── vq │ │ ├── __init__.py │ │ └── quantizer.py ├── pipelines │ ├── __init__.py │ └── pruning_pipelines.py ├── training │ ├── __init__.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── arg_utils.py │ ├── clip_utils.py │ ├── data_utils.py │ ├── dist_utils.py │ ├── estimation_utils.py │ ├── logging_utils.py │ ├── metric_utils.py │ └── op_counter.py ├── pyproject.toml └── scripts ├── aptp ├── filter_dataset.py ├── finetune.py └── prune.py ├── baselines ├── magnitude │ ├── finetune_magnitude.py │ └── generate_images.py ├── random │ └── finetune_random_arch.py ├── sd │ ├── finetune_sd.py │ └── generate_images.py ├── structural │ ├── finetune_structural.py │ └── generate_images.py └── uni_arch │ └── finetune_baseline_param.py ├── metrics ├── clip_features.py ├── clip_score.py ├── fid.py ├── generate_fid_images.py ├── resize_and_save_images.py ├── sample_coco_30k.py └── save_captions.py └── other ├── calculate_pruning_ratio.py └── depth_analysis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | slurm_scripts/ 156 | pbs_scripts/ 157 | 158 | .idea/ 159 | .vscode/ 160 | 161 | data/ 162 | 163 | *reza* 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
170 | #.idea/ 171 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 rezashkv 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICLR 2025] APTP: Adaptive Prompt-Tailored Pruning of T2I Diffusion Models 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2406.12042-red.svg)](https://arxiv.org/abs/2406.12042) 4 | [![Hugging Face Models](https://img.shields.io/badge/🤗%20Hugging%20Face-Models-yellow)](https://huggingface.co/rezashkv/APTP) 5 | 6 | 7 | The implementation of the 8 | paper ["Not All Prompts Are Made Equal: Prompt-based Pruning of Text-to-Image Diffusion Models"](https://arxiv.org/abs/2406.12042) 9 | 10 |

11 | APTP Overview 12 |

13 |

14 | APTP: We prune a text-to-image diffusion model like Stable Diffusion (left) into a mixture of efficient experts (right) in a prompt-based manner. Our prompt router routes distinct types of prompts to different experts, allowing experts' architectures to be separately specialized by removing layers or channels. 15 |

16 | 17 |

18 | APTP Pruning Scheme 19 |

20 |

21 | APTP pruning scheme. We train the prompt router and the set of architecture codes to prune a 22 | T2I diffusion model into a mixture of experts. The prompt router consists of three modules. We use 23 | a Sentence Transformer as the prompt encoder to encode the input prompt into a representation z. Then, 24 | the architecture predictor transforms z into an architecture embedding e with the same dimensionality as the 25 | architecture codes. Finally, the router routes the embedding e to an architecture code a(i). We use optimal 26 | transport to evenly distribute the prompts in a training batch among the architecture codes. The architecture code 27 | a(i) = (u(i), v(i)) determines how the model's width and depth are pruned. We train the prompt router's parameters and 28 | architecture codes in an end-to-end manner using the denoising objective of the pruned model L_DDPM, a distillation 29 | loss between the pruned and original models L_distill, the average resource usage of the samples in the batch R, and 30 | a contrastive objective L_cont that encourages the embeddings e to preserve the semantic similarity of the representations z. 31 | 32 |
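Schematically, the four training signals in the caption are combined into a single objective. The form below is illustrative only (it is not the exact equation from the paper); the λ coefficients stand in for the per-loss `weight` values in the pruning configs (e.g. `configs/pruning/sd-2-1_cc3m.yaml`), and R denotes the resource term derived from the average resource usage of the batch:

```math
\mathcal{L}_{\text{total}} \;=\; \mathcal{L}_{\text{DDPM}} \;+\; \lambda_{\text{distill}}\,\mathcal{L}_{\text{distill}} \;+\; \lambda_{R}\,\mathcal{R} \;+\; \lambda_{\text{cont}}\,\mathcal{L}_{\text{cont}}
```

The actual weighting, the resource-loss type, and any additional terms are defined by the config files rather than by this sketch.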

33 | 34 | ## Table of Contents 35 | 36 | 1. [Installation](#installation) 37 | 2. [Data Preparation](#data-preparation) 38 | - [Download Conceptual Captions](#1-download-conceptual-captions) 39 | - [Download MS-COCO 2014](#2-download-ms-coco-2014) 40 | 3. [Training](#training) 41 | - [Pruning](#1-pruning) 42 | - [Data Preparation for Fine-tuning](#2-data-preparation-for-fine-tuning) 43 | - [Fine-tuning](#3-fine-tuning) 44 | 4. [Image Generation](#image-generation) 45 | 5. [Evaluation](#evaluation) 46 | - [FID Score](#1-fid-score) 47 | - [CLIP Score](#2-clip-score) 48 | - [CMMD](#3-cmmd) 49 | 6. [Baselines](#baselines) 50 | 7. [License](#license) 51 | 8. [Citation](#citation) 52 | 53 | ## Installation 54 | 55 | Follow these steps to set up the project: 56 | 57 | ### 1. Create Conda Environment 58 | 59 | Use the provided [env.yaml](env.yaml) file: 60 | 61 | ```bash 62 | conda env create -f env.yaml 63 | ``` 64 | 65 | ### 2. Activate the Conda Environment 66 | 67 | Activate the environment: 68 | 69 | ```bash 70 | conda activate pdm 71 | ``` 72 | 73 | ### 3. Install Project Dependencies 74 | 75 | From the project root directory, install the dependencies: 76 | 77 | ```bash 78 | pip install -e . 79 | ``` 80 | 81 | ## Data Preparation 82 | 83 | Prepare the data for training as mentioned in the paper. You can also adapt aptp for your own dataset with minor code modifications. 84 | 85 | ### 1. Download Conceptual Captions 86 | 87 | Follow the instructions [here](https://github.com/igorbrigadir/DownloadConceptualCaptions) to download Conceptual Captions. Place the data in a directory of your choice, maintaining the structure: 88 | 89 | ``` 90 | conceptual_captions 91 | ├── Train_GCC-training.tsv 92 | ├── Val_GCC-1.1.0-Validation.tsv 93 | ├── training 94 | │ ├── 10007_560483514 95 | │ └── ... 96 | └── validation 97 | ├── 1852290_2006010568 98 | └── ... 99 | ``` 100 | 101 | #### 1.1 Remove Corrupt Images (Optional) 102 | 103 | There are some urls in the Conceptual Captions dataset that are not valid. The download will result in some corrupt files that can't be opened. These could be removed to ensure a more 104 | efficient training. 105 | 106 | ### 2. Download MS-COCO 2014 107 | 108 | #### 2.1 Download the training and validation images 109 | 110 | Download [2014 train](http://images.cocodataset.org/zips/train2014.zip) 111 | and [2014 val](http://images.cocodataset.org/zips/val2014.zip) images from 112 | the [COCO website](http://cocodataset.org/#download). Place them in your chosen directory. 113 | 114 | #### 2.2 Download the annotations 115 | 116 | Download the [2014 train/val annotations](http://images.cocodataset.org/annotations/annotations_trainval2014.zip) 117 | and place them in the same directory as the images. Your directory should look like this: 118 | 119 | ``` 120 | coco 121 | ├── annotations 122 | │ ├── captions_train2014.json 123 | │ ├── captions_val2014.json 124 | │ └── ... 125 | └── images 126 | ├── train2014 127 | │ ├── COCO_train2014_000000000009.jpg 128 | │ └── ... 129 | └── val2014 130 | ├── COCO_val2014_000000000042.jpg 131 | └── ... 132 | ``` 133 | 134 | ## Training 135 | 136 | Training is done in two stages: pruning the pretrained T2I model (Stable Diffusion 2.1 in this case) and fine-tuning each expert on the prompts assigned to it. 137 | Configuration files for both Conceptual Captions and MS-COCO are provided in the [configs](configs) directory. 138 | You can use these configuration files to run the pruning process. 
Sample multi-node [SLURM](https://slurm.schedmd.com/) and [PBS](https://www.openpbs.org/) scripts can be found in [cluster scripts](cluster_scripts). 139 | 140 | ### 1. Pruning 141 | 142 | You can use the following command to run pruning: 143 | 144 | ```bash 145 | 146 | accelerate launch scripts/aptp/prune.py \ 147 | --base_config_path path/to/configs/pruning/file.yaml \ 148 | --cache_dir /path/to/.cache/huggingface/ \ 149 | --wandb_run_name WANDB_PRUNING_RUN_NAME 150 | ``` 151 | This creates a checkpoint directory named after `WANDB_PRUNING_RUN_NAME` in the logging directory specified in the config file. 152 | 153 | ### 2. Data Preparation for Fine-tuning 154 | The pruning stage results in $K$ architecture codes (experts). We need to assign each training prompt to its corresponding expert for fine-tuning. Assuming the pruning checkpoint directory is `pruning_checkpoint_dir`, you can use the following command to run this filtering process: 155 | 156 | ```bash 157 | accelerate launch scripts/aptp/filter_dataset.py \ 158 | --pruning_ckpt_dir path/to/pruning_checkpoint_dir \ 159 | --base_config_path path/to/configs/filtering/dataset.yaml \ 160 | --cache_dir /path/to/.cache/huggingface/ 161 | ``` 162 | This creates two files named `{DATASET}_train_mapped_indices.pt` and `{DATASET}_validation_mapped_indices.pt` in the `pruning_checkpoint_dir` directory. These files contain the expert id for each prompt in the training and validation sets. 163 | 164 | 165 | ### 3. Fine-tuning 166 | Fine-tune an expert on the prompts assigned to it: 167 | 168 | ```bash 169 | accelerate launch scripts/aptp/finetune.py \ 170 | --pruning_ckpt_dir path/to/pruning_checkpoint_dir \ 171 | --expert_id INDEX \ 172 | --base_config_path path/to/configs/finetuning/dataset.yaml \ 173 | --cache_dir /path/to/.cache/huggingface/ \ 174 | --wandb_run_name WANDB_FINETUNING_RUN_NAME 175 | ``` 176 | 177 | ## Image Generation 178 | 179 | Generate images using the experts: 180 | 181 | ```bash 182 | accelerate launch scripts/metrics/generate_fid_images.py \ 183 | --finetuning_ckpt_dir path/to/pruning_checkpoint_dir \ 184 | --expert_id INDEX \ 185 | --base_config_path path/to/configs/img_generation/dataset.yaml \ 186 | --cache_dir /path/to/.cache/huggingface/ 187 | ``` 188 | The generated images will be saved in the `{DATASET}_fid_images` directory within the root finetuning checkpoint directory. 189 | 190 | ## Evaluation 191 | To evaluate APTP, we report FID, CLIP Score, and CMMD. 192 | 193 | ### 1. FID Score 194 | We use [clean-fid](https://github.com/GaParmar/clean-fid) to calculate the FID score. The numbers reported in the paper are calculated using its `legacy_pytorch` mode. 195 | 196 | #### 1.1 Conceptual Captions Preparation 197 | We report FID on the validation set of Conceptual Captions, so we can reuse the validation mapped-indices file created in the filtering process. First, we need to resize the images to 256x256. You can use the [provided script](scripts/metrics/resize_and_save_images.py). It will save the images as numpy arrays in the same root directory as the dataset. 198 | 199 | 200 | #### 1.2 MS-COCO Preparation 201 | We sample 30k images from the 2014 validation set of MS-COCO. Check out the [sample and resize script](scripts/metrics/sample_coco_30k.py). We need a [filtering step](#2-data-preparation-for-fine-tuning) on this subset to assign prompts to experts. 
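Both preparations above rely on the mapped-indices files written by the [filtering step](#2-data-preparation-for-fine-tuning). The snippet below is a minimal, hypothetical sketch (not part of the repository) of inspecting one of those files to see how many prompts each expert receives; it assumes the file stores one expert id per prompt, and the file name shown is only an example for `DATASET=coco`. The actual layout may differ.

```python
# Hypothetical inspection of a filtering output; the file name and the
# one-expert-id-per-prompt layout are assumptions, not guaranteed by the repo.
from collections import Counter
from pathlib import Path

import torch

ckpt_dir = Path("path/to/pruning_checkpoint_dir")  # placeholder path
mapping = torch.load(ckpt_dir / "coco_validation_mapped_indices.pt")  # assumed name for DATASET=coco

expert_ids = [int(e) for e in mapping]  # assumed: one expert id per validation prompt
for expert_id, num_prompts in sorted(Counter(expert_ids).items()):
    print(f"expert {expert_id}: {num_prompts} prompts")
```

The per-expert prompt subsets surfaced this way are what the fine-tuning and image-generation commands above operate on via `--expert_id`.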
202 | 203 | #### 1.3 Generate Custom Statistics 204 | 205 | Generate custom statistics for both sets of reference images: 206 | ```python 207 | from cleanfid import fid 208 | fid.make_custom_stats(dataset, dataset_path, mode="legacy_pytorch") # mode can also be "clean". 209 | ``` 210 | 211 | Now we can calculate the FID score for the generated images using [the provided script](scripts/metrics/fid.py). 212 | 213 | ### 2. CLIP Score 214 | To calculate the CLIP score, we use [this library](https://github.com/Taited/clip-score). Extract features of reference images with the [clip feature extraction script](scripts/metrics/clip_features.py) and calculate the score using the [clip score script](scripts/metrics/clip_score.py). 215 | 216 | ### 3. CMMD 217 | We use the [cmmd-pytorch](https://github.com/sayakpaul/cmmd-pytorch) library to calculate CMMD. Refer to the [save_refs.py](cmmd-pytorch/save_refs.py) and [compute_cmmd.py](cmmd-pytorch/compute_cmmd.py) scripts for reference-set feature extraction and distance calculation details. 218 | 219 | 220 | ## Baselines 221 | Refer to [configs/baselines](configs/baselines) for the config files of all baselines mentioned in the paper. The scripts to run these baselines are in [this directory](scripts/baselines). 222 | 223 | 224 | ## License 225 | 226 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 227 | 228 | ## Citation 229 | If you find this work useful, please consider citing the following paper: 230 | 231 | ```bibtex 232 | @article{2024aptp, 233 | title={Not All Prompts Are Made Equal: Prompt-based Pruning of Text-to-Image Diffusion Models}, 234 | author={Ganjdanesh, Alireza and Shirkavand, Reza and Gao, Shangqian and Huang, Heng}, 235 | journal={arXiv preprint arXiv:2406.12042}, 236 | year={2024} 237 | } 238 | ``` 239 | -------------------------------------------------------------------------------- /assets/fig_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezashkv/diffusion_pruning/d0709c3a750c6b9ebb25a2901284cffdca98d769/assets/fig_1.gif -------------------------------------------------------------------------------- /assets/fig_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezashkv/diffusion_pruning/d0709c3a750c6b9ebb25a2901284cffdca98d769/assets/fig_2.gif -------------------------------------------------------------------------------- /cluster_scripts/pbs/pruning/helper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MAIN_HOSTNAME=$1 4 | MAIN_PORT=$2 5 | NNODES=$3 6 | NGPUS_PER_NODE=$4 7 | CONFIG_PATH=$5 8 | WANDB_RUN_NAME=$6 9 | 10 | source ~/.bashrc 11 | conda activate pdm 12 | 13 | cd /path/to/diffusion_pruning/scripts/aptp || exit 14 | 15 | torchrun \ 16 | --nnodes $NNODES \ 17 | --nproc_per_node=$NGPUS_PER_NODE \ 18 | --rdzv_endpoint=$MAIN_HOSTNAME:$MAIN_PORT \ 19 | --rdzv_id=12345 \ 20 | --rdzv_backend=c10d \ 21 | prune.py \ 22 | --base_config_path $CONFIG_PATH \ 23 | --cache_dir \ 24 | /scratch/yucongd/rs/.cache/huggingface/ \ 25 | --seed \ 26 | 43 \ 27 | --wandb_run_name \ 28 | $WANDB_RUN_NAME 2>&1 -------------------------------------------------------------------------------- /cluster_scripts/pbs/pruning/sd2-1_cc3m.pbs: -------------------------------------------------------------------------------- 1 | #PBS -N sd-prune-cc3m 2 | #PBS -o out.log 3 | #PBS -e err.log 4 | #PBS -l 
select=2:ncpus=64:mem=128gb:ngpus=2:gpu_model=a100:interconnect=hdr,walltime=72:00:00 5 | 6 | set -x -e 7 | 8 | # log the PBS environment 9 | echo "start time: $(date)" 10 | echo "PBS_JOBID="$PBS_JOBID 11 | echo "PBS_NODELIST"=$PBS_NODELIST 12 | echo "PBS_NODEFILE"=$PBS_NODEFILE 13 | echo "PBS_O_WORKDIR"=$PBS_O_WORKDIR 14 | echo "PBS_GPU_FILE"=$PBS_GPU_FILE 15 | 16 | # Training setup 17 | NNODES=4 #or $PBS_NUM_NODES if it gets set automatically 18 | GPUS_PER_NODE=2 19 | 20 | MAIN_ADDR=$(cat $PBS_NODEFILE | head -n 1) 21 | MAIN_PORT=5560 22 | 23 | 24 | echo "MASTER_ADDR"=$MASTER_ADDR 25 | echo "NNODES"=$NNODES 26 | echo "NODE_RANK"=$NODE_RANK 27 | 28 | export WANDB_CACHE_DIR=/path/to/wandb 29 | 30 | export NCCL_DEBUG=INF 31 | 32 | 33 | CONFIG_PATH="/path/to/diffusion_pruning/configs/pruning/sd-2-1_cc3m.yaml" 34 | WANDB_RUN_NAME="prune-sd-2-1_cc3m" 35 | 36 | pbsdsh -- bash /path/to/diffusion_pruning/cluster_scripts/pbs/pruning/helper.sh $MAIN_ADDR $MAIN_PORT $NNODES $GPUS_PER_NODE $CONFIG_PATH $WANDB_RUN_NAME 37 | 38 | echo "END TIME: $(date)" -------------------------------------------------------------------------------- /cluster_scripts/slurm/filtering/filter_coco.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-filter-coco # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=8 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 8 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 9 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 10 | #SBATCH --mem=64G # Memory per node 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=5561 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | 35 | export WANDB_CACHE_DIR=/path/to/wandb 36 | 37 | export NCCL_DEBUG=INFO 38 | 39 | source ~/.bashrc 40 | conda activate pdm 41 | 42 | checkpoint_dir=$1 43 | echo "CHECKPOINT_DIR="$checkpoint_dir 44 | 45 | cd /path/to/diffusion_pruning/scripts/aptp || exit 46 | 47 | CMD=" \ 48 | filter_dataset.py \ 49 | --pruning_ckpt_dir \ 50 | $checkpoint_dir \ 51 | --base_config_path \ 52 | /path/to/diffusion_pruning/configs/pruning/sd-2-1_coco2014.yaml \ 53 | --cache_dir \ 54 | /path/to/.cache/huggingface/ \ 55 | --wandb_run_name \ 56 | filter_sd-2-1_coco \ 57 | " 58 | 59 | LAUNCHER="accelerate launch \ 60 | --num_machines $NNODES \ 61 | --num_processes $WORLD_SIZE \ 62 | --main_process_ip $MAIN_ADDR \ 63 | --main_process_port $MAIN_PORT \ 64 | --machine_rank \$SLURM_PROCID \ 65 | --role $SLURMD_NODENAME: \ 66 | --rdzv_conf rdzv_backend=c10d \ 67 | --max_restarts 0 \ 68 | --tee 3 \ 69 | " 70 | 71 | SRUN_ARGS=" \ 72 | --wait=60 \ 73 | --kill-on-bad-exit=1 \ 74 | " 75 | 76 | srun $SRUN_ARGS --jobid 
$SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 77 | 78 | echo "END TIME: $(date)" 79 | -------------------------------------------------------------------------------- /cluster_scripts/slurm/filtering/fliter_cc3m.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-filter-cc3m # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=8 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 8 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 9 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 10 | #SBATCH --mem=64G # Memory per node 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=5560 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | 35 | export WANDB_CACHE_DIR=/path/to/wandb 36 | 37 | export NCCL_DEBUG=INFO 38 | 39 | source ~/.bashrc 40 | conda activate pdm 41 | 42 | cd /path/to/diffusion_pruning/scripts/aptp || exit 43 | 44 | checkpoint_dir=$1 45 | echo "CHECKPOINT_DIR="$checkpoint_dir 46 | 47 | CMD=" \ 48 | prune.py \ 49 | --pruning_ckpt_dir \ 50 | $checkpoint_dir \ 51 | --base_config_path \ 52 | /path/to/diffusion_pruning/configs/pruning/sd-2-1_cc3m.yaml \ 53 | --cache_dir \ 54 | /path/to/.cache/huggingface/ \ 55 | --wandb_run_name \ 56 | filter_sd-2-1_cc3m \ 57 | " 58 | 59 | LAUNCHER="accelerate launch \ 60 | --num_machines $NNODES \ 61 | --num_processes $WORLD_SIZE \ 62 | --main_process_ip $MAIN_ADDR \ 63 | --main_process_port $MAIN_PORT \ 64 | --machine_rank \$SLURM_PROCID \ 65 | --role $SLURMD_NODENAME: \ 66 | --rdzv_conf rdzv_backend=c10d \ 67 | --max_restarts 0 \ 68 | --tee 3 \ 69 | " 70 | 71 | SRUN_ARGS=" \ 72 | --wait=60 \ 73 | --kill-on-bad-exit=1 \ 74 | " 75 | 76 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 77 | 78 | echo "END TIME: $(date)" 79 | -------------------------------------------------------------------------------- /cluster_scripts/slurm/finetuning/sd2-1_cc3m.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-finetune-cc3m # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=4 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 8 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 9 | #SBATCH --mem=64G # Memory per node 10 | #SBATCH --array=0-NUM_EXPERT-1 # Array job. 
one task per expert 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=$((SLURM_ARRAY_TASK_ID+5560)) 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | 35 | export WANDB_CACHE_DIR=/path/to/wandb 36 | 37 | export NCCL_DEBUG=INFO 38 | 39 | source ~/.bashrc 40 | conda activate pdm 41 | 42 | cd /fs/nexus-scratch/rezashkv/research/projects/diffusion_pruning/scripts || exit 43 | 44 | pruning_ckpt_dir=$1 45 | expert_id=$((SLURM_ARRAY_TASK_ID)) 46 | 47 | 48 | CMD=" \ 49 | finetune.py \ 50 | --pruning_ckpt_dir \ 51 | $pruning_ckpt_dir \ 52 | --base_config_path \ 53 | /path/to/diffusion_pruning/configs/finetuning/sd-2-1_cc3m.yaml \ 54 | --cache_dir \ 55 | /path/to/.cache/huggingface/ \ 56 | --expert_id \ 57 | ${expert_id} \ 58 | --wandb_run_name \ 59 | finetune_sd-2-1_cc3m/arch${expert_id} 60 | " 61 | LAUNCHER="accelerate launch \ 62 | --mixed_precision no \ 63 | --dynamo_backend no \ 64 | --multi_gpu \ 65 | --num_machines $NNODES \ 66 | --num_processes $WORLD_SIZE \ 67 | --main_process_ip "$MAIN_ADDR" \ 68 | --main_process_port $MAIN_PORT \ 69 | --machine_rank \$SLURM_PROCID \ 70 | --role $SLURMD_NODENAME: \ 71 | --rdzv_conf rdzv_backend=c10d \ 72 | --max_restarts 0 \ 73 | --tee 3 \ 74 | " 75 | SRUN_ARGS=" \ 76 | --wait=60 \ 77 | --kill-on-bad-exit=1 \ 78 | " 79 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 80 | echo "END TIME: $(date)" 81 | -------------------------------------------------------------------------------- /cluster_scripts/slurm/finetuning/sd2-1_coco.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-finetune-coco # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=4 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 8 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 9 | #SBATCH --mem=64G # Memory per node 10 | #SBATCH --array=0-NUM_EXPERT-1 # Array job. 
one task per expert 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=$((SLURM_ARRAY_TASK_ID+5560)) 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | 35 | export WANDB_CACHE_DIR=/path/to/wandb 36 | 37 | export NCCL_DEBUG=INFO 38 | 39 | source ~/.bashrc 40 | conda activate pdm 41 | 42 | cd /fs/nexus-scratch/rezashkv/research/projects/diffusion_pruning/scripts || exit 43 | 44 | pruning_ckpt_dir=$1 45 | expert_id=$((SLURM_ARRAY_TASK_ID)) 46 | 47 | 48 | CMD=" \ 49 | finetune.py \ 50 | --pruning_ckpt_dir \ 51 | $pruning_ckpt_dir \ 52 | --base_config_path \ 53 | /path/to/diffusion_pruning/configs/finetuning/sd-2-1_coco2014.yaml \ 54 | --cache_dir \ 55 | /path/to/.cache/huggingface/ \ 56 | --expert_id \ 57 | ${expert_id} \ 58 | --wandb_run_name \ 59 | finetune_sd-2-1_coco/arch${expert_id} 60 | " 61 | LAUNCHER="accelerate launch \ 62 | --mixed_precision no \ 63 | --dynamo_backend no \ 64 | --multi_gpu \ 65 | --num_machines $NNODES \ 66 | --num_processes $WORLD_SIZE \ 67 | --main_process_ip "$MAIN_ADDR" \ 68 | --main_process_port $MAIN_PORT \ 69 | --machine_rank \$SLURM_PROCID \ 70 | --role $SLURMD_NODENAME: \ 71 | --rdzv_conf rdzv_backend=c10d \ 72 | --max_restarts 0 \ 73 | --tee 3 \ 74 | " 75 | SRUN_ARGS=" \ 76 | --wait=60 \ 77 | --kill-on-bad-exit=1 \ 78 | " 79 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 80 | echo "END TIME: $(date)" 81 | -------------------------------------------------------------------------------- /cluster_scripts/slurm/img_generation/sd2-1_cc3m.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-gen-img-cc3m # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=8 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 
8 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 9 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 10 | #SBATCH --mem=64G # Memory per node 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=5560 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | export WANDB_CACHE_DIR=/path/to/wandb 35 | 36 | export NCCL_DEBUG=INFO 37 | 38 | source ~/.bashrc 39 | conda activate pdm 40 | 41 | cd /path/to/diffusion_pruning/scripts/metrics || exit 42 | 43 | finetuning_ckpt_dir=$1 44 | 45 | for arch_dir in "${finetuning_ckpt_dir}"/arch* 46 | do 47 | expert_id=$(echo "$arch_dir" | grep -o '[0-9]*$') 48 | checkpoint_dir=$(find -d "${arch_dir}/checkpoint-[0-9]*" | sort -rV | head -n 1) # the last checkpoint 49 | 50 | echo "$checkpoint_dir" 51 | echo "$expert_id" 52 | 53 | CMD=" \ 54 | generate_fid_images.py \ 55 | --base_config_path \ 56 | /path/to/diffusion_pruning/configs/img_generation/sd-2-1_cc3m.yaml \ 57 | --cache_dir \ 58 | /path/to/.cache/huggingface/ \ 59 | --finetuning_ckpt_dir \ 60 | $checkpoint_dir \ 61 | --expert_id \ 62 | $expert_id \ 63 | " 64 | LAUNCHER="accelerate launch \ 65 | --mixed_precision no \ 66 | --dynamo_backend no \ 67 | --multi_gpu \ 68 | --num_machines $NNODES \ 69 | --num_processes $WORLD_SIZE \ 70 | --main_process_ip $MAIN_ADDR \ 71 | --main_process_port $MAIN_PORT \ 72 | --machine_rank \$SLURM_PROCID \ 73 | --role $SLURMD_NODENAME: \ 74 | --rdzv_conf rdzv_backend=c10d \ 75 | --max_restarts 0 \ 76 | --tee 3 \ 77 | " 78 | SRUN_ARGS=" \ 79 | --wait=60 \ 80 | --kill-on-bad-exit=1 \ 81 | " 82 | srun "$SRUN_ARGS" --jobid "$SLURM_JOB_ID" bash -c "$LAUNCHER $CMD" 2>&1 83 | echo "END TIME: $(date)" 84 | done -------------------------------------------------------------------------------- /cluster_scripts/slurm/pruning/sd2-1_cc3m.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-prune-cc3m # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=8 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 
8 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 9 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 10 | #SBATCH --mem=64G # Memory per node 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=5560 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | 35 | export WANDB_CACHE_DIR=/path/to/wandb 36 | 37 | export NCCL_DEBUG=INFO 38 | 39 | source ~/.bashrc 40 | conda activate pdm 41 | 42 | cd /path/to/diffusion_pruning/scripts/aptp || exit 43 | 44 | CMD=" \ 45 | prune.py \ 46 | --base_config_path \ 47 | /path/to/diffusion_pruning/configs/pruning/sd-2-1_cc3m.yaml \ 48 | --cache_dir \ 49 | /path/to/.cache/huggingface/ \ 50 | --wandb_run_name \ 51 | sd-2-1_cc3m \ 52 | " 53 | 54 | LAUNCHER="accelerate launch \ 55 | --num_machines $NNODES \ 56 | --num_processes $WORLD_SIZE \ 57 | --main_process_ip $MAIN_ADDR \ 58 | --main_process_port $MAIN_PORT \ 59 | --machine_rank \$SLURM_PROCID \ 60 | --role $SLURMD_NODENAME: \ 61 | --rdzv_conf rdzv_backend=c10d \ 62 | --max_restarts 0 \ 63 | --tee 3 \ 64 | " 65 | 66 | SRUN_ARGS=" \ 67 | --wait=60 \ 68 | --kill-on-bad-exit=1 \ 69 | " 70 | 71 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 72 | 73 | echo "END TIME: $(date)" 74 | -------------------------------------------------------------------------------- /cluster_scripts/slurm/pruning/sd2-1_coco.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=sd-prune-coco # Specify a name for your job 3 | #SBATCH --output=logs/out-%x-%j.log # Specify the output log file 4 | #SBATCH --error=logs/err-%x-%j.log # Specify the error log file 5 | #SBATCH --cpus-per-task=8 # Number of CPU cores per task 6 | #SBATCH --nodes=1 # Number of nodes 7 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 
8 | #SBATCH --gres=gpu:a100:2 # Number of GPUs to request and specify the GPU type 9 | #SBATCH --time=24:00:00 # Maximum execution time (HH:MM:SS) 10 | #SBATCH --mem=64G # Memory per node 11 | 12 | set -x -e 13 | 14 | # Some info 15 | echo "start time: $(date)" 16 | echo "SLURM_JOBID="$SLURM_JOBID 17 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 18 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 19 | echo "SLURM_NNODES"=$SLURM_NNODES 20 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 21 | echo "SLURM_SUBMIT_DIR"s=$SLURM_SUBMIT_DIR 22 | 23 | # Training setup 24 | MAIN_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 25 | MAIN_PORT=5562 26 | NNODES=$SLURM_NNODES 27 | NODE_RANK=$SLURM_PROCID 28 | WORLD_SIZE=$(($SLURM_GPUS_ON_NODE * $NNODES)) 29 | 30 | echo "MAIN_ADDR"=$MAIN_ADDR 31 | echo "NNODES"=$NNODES 32 | echo "NODE_RANK"=$NODE_RANK 33 | 34 | 35 | export WANDB_CACHE_DIR=/path/to/wandb 36 | 37 | export NCCL_DEBUG=INFO 38 | 39 | source ~/.bashrc 40 | conda activate pdm 41 | 42 | cd /path/to/diffusion_pruning/scripts/aptp || exit 43 | 44 | CMD=" \ 45 | prune.py \ 46 | --base_config_path \ 47 | /path/to/diffusion_pruning/configs/pruning/sd-2-1_coco2014.yaml \ 48 | --cache_dir \ 49 | /path/to/.cache/huggingface/ \ 50 | --wandb_run_name \ 51 | sd-2-1_coco \ 52 | " 53 | 54 | LAUNCHER="accelerate launch \ 55 | --num_machines $NNODES \ 56 | --num_processes $WORLD_SIZE \ 57 | --main_process_ip $MAIN_ADDR \ 58 | --main_process_port $MAIN_PORT \ 59 | --machine_rank \$SLURM_PROCID \ 60 | --role $SLURMD_NODENAME: \ 61 | --rdzv_conf rdzv_backend=c10d \ 62 | --max_restarts 0 \ 63 | --tee 3 \ 64 | " 65 | 66 | SRUN_ARGS=" \ 67 | --wait=60 \ 68 | --kill-on-bad-exit=1 \ 69 | " 70 | 71 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 72 | 73 | echo "END TIME: $(date)" 74 | -------------------------------------------------------------------------------- /cmmd-pytorch/cmmd_utils.py: -------------------------------------------------------------------------------- 1 | # credits: https://github.com/sayakpaul/cmmd-pytorch 2 | # coding=utf-8 3 | # Copyright 2024 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """The main entry point for the CMMD calculation.""" 18 | import os 19 | 20 | from absl import app 21 | from absl import flags 22 | import distance 23 | import embedding 24 | import io_util 25 | import numpy as np 26 | 27 | 28 | _BATCH_SIZE = flags.DEFINE_integer("batch_size", 32, "Batch size for embedding generation.") 29 | _MAX_COUNT = flags.DEFINE_integer("max_count", -1, "Maximum number of images to read from each directory.") 30 | _REF_EMBED_FILE = flags.DEFINE_string( 31 | "ref_embed_file", None, "Path to the pre-computed embedding file for the reference images." 32 | ) 33 | 34 | 35 | def compute_cmmd(ref_dir, eval_dir, ref_embed_file=None, batch_size=32, max_count=-1): 36 | """Calculates the CMMD distance between reference and eval image sets. 
37 | 38 | Args: 39 | ref_dir: Path to the directory containing reference images. 40 | eval_dir: Path to the directory containing images to be evaluated. 41 | ref_embed_file: Path to the pre-computed embedding file for the reference images. 42 | batch_size: Batch size used in the CLIP embedding calculation. 43 | max_count: Maximum number of images to use from each directory. A 44 | non-positive value reads all images available except for the images 45 | dropped due to batching. 46 | 47 | Returns: 48 | The CMMD value between the image sets. 49 | """ 50 | if ref_dir and ref_embed_file: 51 | raise ValueError("`ref_dir` and `ref_embed_file` both cannot be set at the same time.") 52 | embedding_model = embedding.ClipEmbeddingModel() 53 | if ref_embed_file is not None: 54 | ref_embs = np.load(ref_embed_file).astype("float32") 55 | else: 56 | ref_embs = io_util.compute_embeddings_for_dir(ref_dir, embedding_model, batch_size, max_count).astype( 57 | "float32" 58 | ) 59 | eval_embs = io_util.compute_embeddings_for_dir(eval_dir, embedding_model, batch_size, max_count).astype("float32") 60 | val = distance.mmd(ref_embs, eval_embs) 61 | return val.numpy() 62 | 63 | 64 | def save_ref_embeds(ref_dir, batch_size=32, max_count=-1): 65 | """Computes and saves the embeddings for the reference images. 66 | 67 | Args: 68 | ref_dir: Path to the directory containing reference images. 69 | batch_size: Batch size used in the CLIP embedding calculation. 70 | max_count: Maximum number of images to use from each directory. A 71 | non-positive value reads all images available except for the images 72 | dropped due to batching. 73 | """ 74 | embedding_model = embedding.ClipEmbeddingModel() 75 | ref_embs = io_util.compute_embeddings_for_dir(ref_dir, embedding_model, batch_size, max_count).astype("float32") 76 | 77 | # Save the embeddings to a file in the parent directory of the reference images. 
78 | save_dir = os.path.dirname(ref_dir) 79 | name = os.path.basename(ref_dir) 80 | save_path = os.path.join(save_dir, f"cmmd_embs_{name}.npy") 81 | np.save(save_path, ref_embs) 82 | -------------------------------------------------------------------------------- /cmmd-pytorch/compute_cmmd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from cmmd_utils import compute_cmmd 3 | 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="Compute and save embeddings for reference images.") 7 | parser.add_argument("--ref_file", type=str, help="Path to the file containing reference images embeddings.") 8 | parser.add_argument("--eval_dir", type=str, help="Path to the directory containing images to be evaluated.") 9 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size used in the CLIP embedding calculation.") 10 | parser.add_argument("--max_count", type=int, default=-1, 11 | help="Maximum number of images to use from each directory.") 12 | parser.add_argument("--result_file", type=str, help="Path to the file to save the computed CMMD value.") 13 | return parser.parse_args() 14 | 15 | 16 | if __name__ == "__main__": 17 | args = parse_args() 18 | cmmd = compute_cmmd(ref_dir=None, eval_dir=args.eval_dir, ref_embed_file=args.ref_file, batch_size=args.batch_size, 19 | max_count=args.max_count) 20 | with open(args.result_file, "a") as f: 21 | f.write(f"{args.eval_dir}: {cmmd}\n") 22 | -------------------------------------------------------------------------------- /cmmd-pytorch/distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Memory-efficient MMD implementation in JAX.""" 17 | 18 | import torch 19 | 20 | # The bandwidth parameter for the Gaussian RBF kernel. See the paper for more 21 | # details. 22 | _SIGMA = 10 23 | # The following is used to make the metric more human readable. See the paper 24 | # for more details. 25 | _SCALE = 1000 26 | 27 | 28 | def mmd(x, y): 29 | """Memory-efficient MMD implementation in JAX. 30 | 31 | This implements the minimum-variance/biased version of the estimator described 32 | in Eq.(5) of 33 | https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf. 34 | As described in Lemma 6's proof in that paper, the unbiased estimate and the 35 | minimum-variance estimate for MMD are almost identical. 36 | 37 | Note that the first invocation of this function will be considerably slow due 38 | to JAX JIT compilation. 39 | 40 | Args: 41 | x: The first set of embeddings of shape (n, embedding_dim). 42 | y: The second set of embeddings of shape (n, embedding_dim). 43 | 44 | Returns: 45 | The MMD distance between x and y embedding sets. 
46 | """ 47 | x = torch.from_numpy(x) 48 | y = torch.from_numpy(y) 49 | 50 | x_sqnorms = torch.diag(torch.matmul(x, x.T)) 51 | y_sqnorms = torch.diag(torch.matmul(y, y.T)) 52 | 53 | gamma = 1 / (2 * _SIGMA**2) 54 | k_xx = torch.mean( 55 | torch.exp(-gamma * (-2 * torch.matmul(x, x.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(x_sqnorms, 0))) 56 | ) 57 | k_xy = torch.mean( 58 | torch.exp(-gamma * (-2 * torch.matmul(x, y.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0))) 59 | ) 60 | k_yy = torch.mean( 61 | torch.exp(-gamma * (-2 * torch.matmul(y, y.T) + torch.unsqueeze(y_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0))) 62 | ) 63 | 64 | return _SCALE * (k_xx + k_yy - 2 * k_xy) 65 | -------------------------------------------------------------------------------- /cmmd-pytorch/embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Embedding models used in the CMMD calculation.""" 17 | 18 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 19 | import torch 20 | import numpy as np 21 | 22 | _CLIP_MODEL_NAME = "openai/clip-vit-large-patch14-336" 23 | _CUDA_AVAILABLE = torch.cuda.is_available() 24 | 25 | 26 | def _resize_bicubic(images, size): 27 | images = torch.from_numpy(images.transpose(0, 3, 1, 2)) 28 | images = torch.nn.functional.interpolate(images, size=(size, size), mode="bicubic") 29 | images = images.permute(0, 2, 3, 1).numpy() 30 | return images 31 | 32 | 33 | class ClipEmbeddingModel: 34 | """CLIP image embedding calculator.""" 35 | 36 | def __init__(self): 37 | self.image_processor = CLIPImageProcessor.from_pretrained(_CLIP_MODEL_NAME) 38 | 39 | self._model = CLIPVisionModelWithProjection.from_pretrained(_CLIP_MODEL_NAME).eval() 40 | if _CUDA_AVAILABLE: 41 | self._model = self._model.cuda() 42 | 43 | self.input_image_size = self.image_processor.crop_size["height"] 44 | 45 | @torch.no_grad() 46 | def embed(self, images): 47 | """Computes CLIP embeddings for the given images. 48 | 49 | Args: 50 | images: An image array of shape (batch_size, height, width, 3). Values are 51 | in range [0, 1]. 52 | 53 | Returns: 54 | Embedding array of shape (batch_size, embedding_width). 
55 | """ 56 | 57 | images = _resize_bicubic(images, self.input_image_size) 58 | inputs = self.image_processor( 59 | images=images, 60 | do_normalize=True, 61 | do_center_crop=False, 62 | do_resize=False, 63 | do_rescale=False, 64 | return_tensors="pt", 65 | ) 66 | if _CUDA_AVAILABLE: 67 | inputs = {k: v.to("cuda") for k, v in inputs.items()} 68 | 69 | image_embs = self._model(**inputs).image_embeds.cpu() 70 | image_embs /= torch.linalg.norm(image_embs, axis=-1, keepdims=True) 71 | return image_embs 72 | -------------------------------------------------------------------------------- /cmmd-pytorch/generate_images.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline 2 | from concurrent.futures import ThreadPoolExecutor 3 | import pandas as pd 4 | import argparse 5 | import torch 6 | import os 7 | 8 | 9 | ALL_CKPTS = [ 10 | "runwayml/stable-diffusion-v1-5", 11 | "segmind/SSD-1B", 12 | "PixArt-alpha/PixArt-XL-2-1024-MS", 13 | "stabilityai/stable-diffusion-xl-base-1.0", 14 | "stabilityai/sdxl-turbo", 15 | ] 16 | SEED = 2024 17 | 18 | 19 | def load_dataframe(): 20 | dataframe = pd.read_csv( 21 | "https://huggingface.co/datasets/sayakpaul/sample-datasets/raw/main/coco_30k_randomly_sampled_2014_val.csv" 22 | ) 23 | return dataframe 24 | 25 | 26 | def load_pipeline(args): 27 | if "runway" in args.pipeline_id: 28 | pipeline = DiffusionPipeline.from_pretrained( 29 | args.pipeline_id, torch_dtype=torch.float16, safety_checker=None 30 | ).to("cuda") 31 | else: 32 | pipeline = DiffusionPipeline.from_pretrained(args.pipeline_id, torch_dtype=torch.float16).to("cuda") 33 | pipeline.set_progress_bar_config(disable=True) 34 | return pipeline 35 | 36 | 37 | def generate_images(args, dataframe, pipeline): 38 | all_images = [] 39 | for i in range(0, len(dataframe), args.chunk_size): 40 | if "sdxl-turbo" not in args.pipeline_id: 41 | images = pipeline( 42 | dataframe.iloc[i : i + args.chunk_size]["caption"].tolist(), 43 | num_inference_steps=args.num_inference_steps, 44 | generator=torch.manual_seed(SEED), 45 | ).images 46 | else: 47 | images = pipeline( 48 | dataframe.iloc[i : i + args.chunk_size]["caption"].tolist(), 49 | num_inference_steps=args.num_inference_steps, 50 | generator=torch.manual_seed(SEED), 51 | guidance_scale=0.0, 52 | ).images 53 | all_images.extend(images) 54 | return all_images 55 | 56 | 57 | def serialize_image(image, path): 58 | image.save(path) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--pipeline_id", default="runwayml/stable-diffusion-v1-5", type=str, choices=ALL_CKPTS) 64 | parser.add_argument("--num_inference_steps", default=30, type=int) 65 | parser.add_argument("--chunk_size", default=2, type=int) 66 | parser.add_argument("--root_img_path", default="sdv15", type=str) 67 | parser.add_argument("--num_workers", type=int, default=4) 68 | args = parser.parse_args() 69 | 70 | dataset = load_dataframe() 71 | pipeline = load_pipeline(args) 72 | images = generate_images(args, dataset, pipeline) 73 | image_paths = [os.path.join(args.root_img_path, f"{i}.jpg") for i in range(len(images))] 74 | 75 | if not os.path.exists(args.root_img_path): 76 | os.makedirs(args.root_img_path) 77 | 78 | with ThreadPoolExecutor(max_workers=args.num_workers) as executor: 79 | executor.map(serialize_image, images, image_paths) 80 | -------------------------------------------------------------------------------- /cmmd-pytorch/io_util.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """IO utilities.""" 17 | 18 | import glob 19 | from torch.utils.data import Dataset, DataLoader 20 | import numpy as np 21 | from PIL import Image 22 | import tqdm 23 | 24 | 25 | class CMMDDataset(Dataset): 26 | def __init__(self, path, reshape_to, max_count=-1): 27 | self.path = path 28 | self.reshape_to = reshape_to 29 | 30 | self.max_count = max_count 31 | img_path_list = self._get_image_list() 32 | if max_count > 0: 33 | img_path_list = img_path_list[:max_count] 34 | self.img_path_list = img_path_list 35 | 36 | def __len__(self): 37 | return len(self.img_path_list) 38 | 39 | def _get_image_list(self): 40 | ext_list = ["png", "jpg", "jpeg", "npy"] 41 | image_list = [] 42 | for ext in ext_list: 43 | image_list.extend(glob.glob(f"{self.path}/*{ext}")) 44 | image_list.extend(glob.glob(f"{self.path}/*.{ext.upper()}")) 45 | # Sort the list to ensure a deterministic output. 46 | image_list.sort() 47 | return image_list 48 | 49 | def _center_crop_and_resize(self, im, size): 50 | w, h = im.size 51 | l = min(w, h) 52 | top = (h - l) // 2 53 | left = (w - l) // 2 54 | box = (left, top, left + l, top + l) 55 | im = im.crop(box) 56 | # Note that the following performs anti-aliasing as well. 57 | return im.resize((size, size), resample=Image.BICUBIC) # pytype: disable=module-attr 58 | 59 | def _read_image(self, path, size): 60 | if path.endswith(".npy"): 61 | im = np.load(path) 62 | im = Image.fromarray(im) 63 | else: 64 | im = Image.open(path) 65 | if size > 0: 66 | im = self._center_crop_and_resize(im, size) 67 | return np.asarray(im).astype(np.float32) 68 | 69 | def __getitem__(self, idx): 70 | img_path = self.img_path_list[idx] 71 | 72 | x = self._read_image(img_path, self.reshape_to) 73 | if x.ndim == 3: 74 | return x 75 | elif x.ndim == 2: 76 | # Convert grayscale to RGB by duplicating the channel dimension. 77 | return np.tile(x[Ellipsis, np.newaxis], (1, 1, 3)) 78 | 79 | 80 | def compute_embeddings_for_dir( 81 | img_dir, 82 | embedding_model, 83 | batch_size, 84 | max_count=-1, 85 | ): 86 | """Computes embeddings for the images in the given directory. 87 | 88 | This drops the remainder of the images after batching with the provided 89 | batch_size to enable efficient computation on TPUs. This usually does not 90 | affect results assuming we have a large number of images in the directory. 91 | 92 | Args: 93 | img_dir: Directory containing .jpg or .png image files. 94 | embedding_model: The embedding model to use. 95 | batch_size: Batch size for the embedding model inference. 96 | max_count: Max number of images in the directory to use. 97 | 98 | Returns: 99 | Computed embeddings of shape (num_images, embedding_dim). 
100 | """ 101 | dataset = CMMDDataset(img_dir, reshape_to=embedding_model.input_image_size, max_count=max_count) 102 | count = len(dataset) 103 | print(f"Calculating embeddings for {count} images from {img_dir}.") 104 | 105 | dataloader = DataLoader(dataset, batch_size=batch_size) 106 | 107 | all_embs = [] 108 | for batch in tqdm.tqdm(dataloader, total=count // batch_size): 109 | image_batch = batch.numpy() 110 | 111 | # Normalize to the [0, 1] range. 112 | image_batch = image_batch / 255.0 113 | 114 | if np.min(image_batch) < 0 or np.max(image_batch) > 1: 115 | raise ValueError( 116 | "Image values are expected to be in [0, 1]. Found:" f" [{np.min(image_batch)}, {np.max(image_batch)}]." 117 | ) 118 | 119 | # Compute the embeddings using a pmapped function. 120 | embs = np.asarray( 121 | embedding_model.embed(image_batch) 122 | ) # The output has shape (num_devices, batch_size, embedding_dim). 123 | all_embs.append(embs) 124 | 125 | all_embs = np.concatenate(all_embs, axis=0) 126 | 127 | return all_embs 128 | -------------------------------------------------------------------------------- /cmmd-pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | accelerate 3 | Pillow 4 | tqdm 5 | numpy 6 | absl-py -------------------------------------------------------------------------------- /cmmd-pytorch/save_refs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from cmmd_utils import save_ref_embeds 3 | 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="Compute and save embeddings for reference images.") 7 | parser.add_argument("--ref_dir", type=str, help="Path to the directory containing reference images.") 8 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size used in the CLIP embedding calculation.") 9 | parser.add_argument("--max_count", type=int, default=-1, help="Maximum number of images to use from each directory.") 10 | return parser.parse_args() 11 | 12 | 13 | if __name__ == "__main__": 14 | args = parse_args() 15 | save_ref_embeds(args.ref_dir, args.batch_size, args.max_count) 16 | -------------------------------------------------------------------------------- /configs/baselines/finetuning/sd-2-1_cc3m_magnitude.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | 12 | gated_ff: true 13 | ff_gate_width: 32 14 | 15 | data: 16 | dataset_name: cc3m 17 | data_files: null 18 | dataset_config_name: null 19 | data_dir: "/path/to/conceptual_captions" 20 | 21 | train_data_dir: "training" 22 | train_data_file: "Train_GCC-training.tsv" 23 | max_train_samples: null 24 | 25 | validation_data_dir: "validation" 26 | validation_data_file: "Validation_GCC-1.1.0-Validation.tsv" 27 | max_validation_samples: 1000 28 | 29 | image_column: "image" 30 | caption_column: "caption" 31 | 32 | prompts: null 33 | max_generated_samples: 16 34 | 35 | dataloader: 36 | dataloader_num_workers: 0 37 | train_batch_size: 128 38 | validation_batch_size: 32 39 | image_generation_batch_size: 8 40 | center_crop: false 41 | random_flip: true 42 | 43 | training: 44 | pruning_target: 0.6 45 | pruning_method: "magnitude" 46 | 47 | num_train_epochs: null 48 | 
max_train_steps: 50000 49 | validation_steps: 1000 50 | image_logging_steps: 1000 51 | num_inference_steps: 50 # number of scheduler steps to run for image generation 52 | 53 | mixed_precision: null 54 | gradient_accumulation_steps: 1 55 | gradient_checkpointing: false 56 | local_rank: -1 57 | allow_tf32: false 58 | enable_xformers_memory_efficient_attention: false 59 | 60 | losses: 61 | diffusion_loss: 62 | snr_gamma: 5.0 63 | weight: 0.01 64 | 65 | distillation_loss: 66 | weight: 0.0 67 | 68 | block_loss: 69 | weight: 0.0 70 | 71 | 72 | optim: 73 | unet_learning_rate: 1e-5 74 | unet_weight_decay: 0.00 75 | 76 | use_8bit_adam: false 77 | adam_beta1: 0.9 78 | adam_beta2: 0.999 79 | adam_epsilon: 1e-08 80 | 81 | scale_lr: true 82 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 83 | lr_warmup_steps: 250 84 | 85 | 86 | hf_hub: 87 | push_to_hub: false 88 | hub_token: null 89 | hub_model_id: null 90 | 91 | logging: 92 | logging_dir: "path/to/logs/" 93 | 94 | report_to: "wandb" 95 | tracker_project_name: "text2image-dynamic-pruning" 96 | wandb_log_dir: "path/to/wandb" 97 | 98 | checkpoints_total_limit: 1 99 | auto_checkpoint_step: false 100 | resume_from_checkpoint: latest # or null 101 | 102 | -------------------------------------------------------------------------------- /configs/baselines/img_generation/sd-2-1_cc3m_magnitude.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | data: 29 | dataset_name: cc3m 30 | data_files: null 31 | dataset_config_name: null 32 | data_dir: "/path/to/conceptual_captions" 33 | 34 | train_data_dir: "training" 35 | train_data_file: "Train_GCC-training.tsv" 36 | max_train_samples: null 37 | 38 | validation_data_dir: "validation" 39 | validation_data_file: "Validation_GCC-1.1.0-Validation.tsv" 40 | max_validation_samples: null 41 | 42 | image_column: "image" 43 | caption_column: "caption" 44 | 45 | dataloader: 46 | dataloader_num_workers: 0 47 | image_generation_batch_size: 8 48 | 49 | training: 50 | pruning_target: 0.6 51 | pruning_method: "magnitude" 52 | 53 | num_inference_steps: 25 # number of scheduler steps to run for image generation 54 | local_rank: -1 55 | enable_xformers_memory_efficient_attention: false 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/baselines/pruning/sd-2-1_coco2014_ot.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - 
CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | hypernet: 29 | weight_norm: false 30 | linear_bias: true 31 | single_arch_param: false # if true, train a single experts for all prompts (used as a baseline in the paper) 32 | 33 | quantizer: 34 | quantizer_T: 0.4 35 | quantizer_base: 3 36 | num_arch_vq_codebook_embeddings: 8 37 | arch_vq_beta: 0.25 38 | depth_order: [-1, -2, 0, 1, -3, -4, 2, 3, -5, -6, 4, 5, -7, 6] 39 | non_zero_width: true 40 | resource_aware_normalization: false 41 | optimal_transport: true 42 | 43 | data: 44 | dataset_name: coco 45 | data_files: null 46 | dataset_config_name: null 47 | data_dir: "path/to/coco" 48 | 49 | max_train_samples: null 50 | max_validation_samples: null 51 | year: 2014 # 2014 or 2017. 2014 is used in the paper. 52 | 53 | image_column: "image" 54 | caption_column: "caption" 55 | 56 | prompts: null 57 | max_generated_samples: 2 58 | 59 | dataloader: 60 | dataloader_num_workers: 0 61 | train_batch_size: 64 62 | validation_batch_size: 16 63 | image_generation_batch_size: 4 64 | center_crop: false 65 | random_flip: true 66 | 67 | training: 68 | num_train_epochs: null 69 | max_train_steps: 5000 70 | hypernet_pretraining_steps: 500 71 | validation_steps: 1000 72 | image_logging_steps: 1000 73 | num_inference_steps: 50 # number of scheduler steps to run for image generation 74 | 75 | mixed_precision: null 76 | gradient_accumulation_steps: 1 77 | gradient_checkpointing: false 78 | local_rank: -1 79 | allow_tf32: false 80 | enable_xformers_memory_efficient_attention: false 81 | 82 | losses: 83 | diffusion_loss: 84 | snr_gamma: 5.0 85 | weight: 1.0 86 | 87 | resource_loss: 88 | type: log 89 | weight: 2.0 90 | pruning_target: 0.6 91 | 92 | contrastive_loss: 93 | arch_vector_temperature: 0.03 94 | prompt_embedding_temperature: 0.03 95 | weight: 100.0 96 | 97 | distillation_loss: 98 | weight: 0.0 99 | 100 | block_loss: 101 | weight: 0.0 102 | 103 | std_loss: 104 | weight: 0.0 105 | 106 | max_loss: 107 | weight: 0.0 108 | 109 | 110 | optim: 111 | hypernet_learning_rate: 2e-4 112 | quantizer_learning_rate: 2e-4 113 | unet_learning_rate: 5e-5 114 | 115 | quantizer_weight_decay: 0.00 116 | hypernet_weight_decay: 0.00 117 | unet_weight_decay: 0.00 118 | 119 | use_8bit_adam: false 120 | adam_beta1: 0.9 121 | adam_beta2: 0.999 122 | adam_epsilon: 1e-08 123 | 124 | scale_lr: true 125 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 126 | lr_warmup_steps: 100 127 | 128 | 129 | hf_hub: 130 | push_to_hub: false 131 | hub_token: null 132 | hub_model_id: null 133 | 134 | logging: 135 | logging_dir: "path/to/logs/" 136 | 137 | report_to: "wandb" 138 | tracker_project_name: "text2image-dynamic-pruning" 139 | wandb_log_dir: "path/to/wandb" 140 | 141 | checkpoints_total_limit: 1 142 | auto_checkpoint_step: false 143 | resume_from_checkpoint: null # or latest 144 | 145 | -------------------------------------------------------------------------------- /configs/baselines/pruning/sd-2-1_coco2014_single_arch.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | 
noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | hypernet: 29 | weight_norm: false 30 | linear_bias: true 31 | single_arch_param: true # if true, train a single experts for all prompts (used as a baseline in the paper) 32 | 33 | quantizer: 34 | quantizer_T: 0.4 35 | quantizer_base: 3 36 | num_arch_vq_codebook_embeddings: 8 37 | arch_vq_beta: 0.25 38 | depth_order: [-1, -2, 0, 1, -3, -4, 2, 3, -5, -6, 4, 5, -7, 6] 39 | non_zero_width: true 40 | resource_aware_normalization: false 41 | optimal_transport: false 42 | 43 | data: 44 | dataset_name: coco 45 | data_files: null 46 | dataset_config_name: null 47 | data_dir: "path/to/coco" 48 | 49 | max_train_samples: null 50 | max_validation_samples: null 51 | year: 2014 # 2014 or 2017. 2014 is used in the paper. 52 | 53 | image_column: "image" 54 | caption_column: "caption" 55 | 56 | prompts: null 57 | max_generated_samples: 2 58 | 59 | dataloader: 60 | dataloader_num_workers: 0 61 | train_batch_size: 64 62 | validation_batch_size: 16 63 | image_generation_batch_size: 4 64 | center_crop: false 65 | random_flip: true 66 | 67 | training: 68 | num_train_epochs: null 69 | max_train_steps: 5000 70 | hypernet_pretraining_steps: 5000 71 | validation_steps: 1000 72 | image_logging_steps: 1000 73 | num_inference_steps: 50 # number of scheduler steps to run for image generation 74 | 75 | mixed_precision: null 76 | gradient_accumulation_steps: 1 77 | gradient_checkpointing: false 78 | local_rank: -1 79 | allow_tf32: false 80 | enable_xformers_memory_efficient_attention: false 81 | 82 | losses: 83 | diffusion_loss: 84 | snr_gamma: 5.0 85 | weight: 1.0 86 | 87 | resource_loss: 88 | type: log 89 | weight: 2.0 90 | pruning_target: 0.6 91 | 92 | contrastive_loss: 93 | arch_vector_temperature: 0.03 94 | prompt_embedding_temperature: 0.03 95 | weight: 0.0 96 | 97 | distillation_loss: 98 | weight: 0.0 99 | 100 | block_loss: 101 | weight: 0.0 102 | 103 | std_loss: 104 | weight: 0.0 105 | 106 | max_loss: 107 | weight: 0.0 108 | 109 | 110 | optim: 111 | hypernet_learning_rate: 2e-4 112 | quantizer_learning_rate: 2e-4 113 | unet_learning_rate: 5e-5 114 | 115 | quantizer_weight_decay: 0.00 116 | hypernet_weight_decay: 0.00 117 | unet_weight_decay: 0.00 118 | 119 | use_8bit_adam: false 120 | adam_beta1: 0.9 121 | adam_beta2: 0.999 122 | adam_epsilon: 1e-08 123 | 124 | scale_lr: true 125 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 126 | lr_warmup_steps: 100 127 | 128 | 129 | hf_hub: 130 | push_to_hub: false 131 | hub_token: null 132 | hub_model_id: null 133 | 134 | logging: 135 | logging_dir: "path/to/logs/" 136 | 137 | report_to: "wandb" 138 | tracker_project_name: "text2image-dynamic-pruning" 139 | wandb_log_dir: "path/to/wandb" 140 | 141 | checkpoints_total_limit: 1 142 | auto_checkpoint_step: false 143 | resume_from_checkpoint: null # or latest 144 | 145 | -------------------------------------------------------------------------------- /configs/baselines/pruning/sd-2-1_coco2014_without_ot.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | hypernet: 29 | weight_norm: false 30 | linear_bias: true 31 | single_arch_param: false # if true, train a single experts for all prompts (used as a baseline in the paper) 32 | 33 | quantizer: 34 | quantizer_T: 0.4 35 | quantizer_base: 3 36 | num_arch_vq_codebook_embeddings: 8 37 | arch_vq_beta: 0.25 38 | depth_order: [-1, -2, 0, 1, -3, -4, 2, 3, -5, -6, 4, 5, -7, 6] 39 | non_zero_width: true 40 | resource_aware_normalization: false 41 | optimal_transport: false 42 | 43 | data: 44 | dataset_name: coco 45 | data_files: null 46 | dataset_config_name: null 47 | data_dir: "path/to/coco" 48 | 49 | max_train_samples: null 50 | max_validation_samples: null 51 | year: 2014 # 2014 or 2017. 2014 is used in the paper. 52 | 53 | image_column: "image" 54 | caption_column: "caption" 55 | 56 | prompts: null 57 | max_generated_samples: 2 58 | 59 | dataloader: 60 | dataloader_num_workers: 0 61 | train_batch_size: 64 62 | validation_batch_size: 16 63 | image_generation_batch_size: 4 64 | center_crop: false 65 | random_flip: true 66 | 67 | training: 68 | num_train_epochs: null 69 | max_train_steps: 5000 70 | hypernet_pretraining_steps: 500 71 | validation_steps: 1000 72 | image_logging_steps: 1000 73 | num_inference_steps: 50 # number of scheduler steps to run for image generation 74 | 75 | mixed_precision: null 76 | gradient_accumulation_steps: 1 77 | gradient_checkpointing: false 78 | local_rank: -1 79 | allow_tf32: false 80 | enable_xformers_memory_efficient_attention: false 81 | 82 | losses: 83 | diffusion_loss: 84 | snr_gamma: 5.0 85 | weight: 1.0 86 | 87 | resource_loss: 88 | type: log 89 | weight: 2.0 90 | pruning_target: 0.6 91 | 92 | contrastive_loss: 93 | arch_vector_temperature: 0.03 94 | prompt_embedding_temperature: 0.03 95 | weight: 100.0 96 | 97 | distillation_loss: 98 | weight: 0.0 99 | 100 | block_loss: 101 | weight: 0.0 102 | 103 | std_loss: 104 | weight: 0.0 105 | 106 | max_loss: 107 | weight: 0.0 108 | 109 | 110 | optim: 111 | hypernet_learning_rate: 2e-4 112 | quantizer_learning_rate: 2e-4 113 | unet_learning_rate: 5e-5 114 | 115 | quantizer_weight_decay: 0.00 116 | hypernet_weight_decay: 0.00 117 | unet_weight_decay: 0.00 118 | 119 | use_8bit_adam: false 120 | adam_beta1: 0.9 121 | adam_beta2: 0.999 122 | adam_epsilon: 1e-08 123 | 124 | scale_lr: true 125 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 126 | lr_warmup_steps: 100 127 | 128 | 129 | hf_hub: 130 | push_to_hub: false 131 | hub_token: null 132 | hub_model_id: null 133 | 134 | logging: 135 | logging_dir: "path/to/logs/" 136 | 137 | report_to: "wandb" 138 | tracker_project_name: "text2image-dynamic-pruning" 139 | wandb_log_dir: "path/to/wandb" 140 | 141 | checkpoints_total_limit: 1 142 | auto_checkpoint_step: false 143 | 
resume_from_checkpoint: null # or latest 144 | 145 | -------------------------------------------------------------------------------- /configs/filtering/sd-2-1_cc3m.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_name: cc3m 3 | data_files: null 4 | dataset_config_name: null 5 | data_dir: "/path/to/conceptual_captions" 6 | 7 | train_data_dir: "training" 8 | train_data_file: "Train_GCC-training.tsv" 9 | max_train_samples: null 10 | 11 | validation_data_dir: "validation" 12 | validation_data_file: "Validation_GCC-1.1.0-Validation.tsv" 13 | max_validation_samples: 1000 14 | 15 | image_column: "image" 16 | caption_column: "caption" 17 | 18 | prompts: null 19 | max_generated_samples: 16 -------------------------------------------------------------------------------- /configs/filtering/sd-2-1_coco2014.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_name: coco 3 | data_files: null 4 | dataset_config_name: null 5 | data_dir: "path/to/coco" 6 | 7 | max_train_samples: null 8 | max_validation_samples: 10 9 | year: 2014 # 2014 or 2017. 2014 is used in the paper. 10 | 11 | image_column: "image" 12 | caption_column: "caption" 13 | 14 | prompts: null 15 | max_generated_samples: 2 -------------------------------------------------------------------------------- /configs/finetuning/sd-2-1_cc3m.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | hypernet: 29 | weight_norm: false 30 | linear_bias: true 31 | single_arch_param: false # if true, train a single experts for all prompts (used as a baseline in the paper) 32 | 33 | quantizer: 34 | quantizer_T: 0.4 35 | quantizer_base: 3 36 | num_arch_vq_codebook_embeddings: 8 37 | arch_vq_beta: 0.25 38 | depth_order: [-1, -2, 0, 1, -3, -4, 2, 3, -5, -6, 4, 5, -7, 6] 39 | non_zero_width: true 40 | resource_aware_normalization: false 41 | optimal_transport: true 42 | 43 | data: 44 | dataset_name: cc3m 45 | data_files: null 46 | dataset_config_name: null 47 | data_dir: "/path/to/conceptual_captions" 48 | 49 | train_data_dir: "training" 50 | train_data_file: "Train_GCC-training.tsv" 51 | max_train_samples: null 52 | 53 | validation_data_dir: "validation" 54 | validation_data_file: "Validation_GCC-1.1.0-Validation.tsv" 55 | max_validation_samples: 1000 56 | 57 | image_column: "image" 58 | caption_column: "caption" 59 | 60 | prompts: null 61 | max_generated_samples: 16 62 | 63 | dataloader: 64 | dataloader_num_workers: 0 65 | train_batch_size: 128 66 | validation_batch_size: 32 67 | image_generation_batch_size: 8 68 | center_crop: false 69 | random_flip: true 70 | 71 | training: 72 | num_train_epochs: null 73 | max_train_steps: 30000 74 | hypernet_pretraining_steps: 500 75 | validation_steps: 1000 76 | image_logging_steps: 1000 77 | 
num_inference_steps: 50 # number of scheduler steps to run for image generation 78 | 79 | mixed_precision: null 80 | gradient_accumulation_steps: 1 81 | gradient_checkpointing: false 82 | local_rank: -1 83 | allow_tf32: false 84 | enable_xformers_memory_efficient_attention: false 85 | 86 | losses: 87 | diffusion_loss: 88 | snr_gamma: 5.0 89 | weight: 0.01 90 | 91 | distillation_loss: 92 | weight: 0.5 93 | 94 | block_loss: 95 | weight: 0.5 96 | 97 | 98 | optim: 99 | unet_learning_rate: 1e-5 100 | unet_weight_decay: 0.00 101 | 102 | use_8bit_adam: false 103 | adam_beta1: 0.9 104 | adam_beta2: 0.999 105 | adam_epsilon: 1e-08 106 | 107 | scale_lr: true 108 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 109 | lr_warmup_steps: 250 110 | 111 | 112 | hf_hub: 113 | push_to_hub: false 114 | hub_token: null 115 | hub_model_id: null 116 | 117 | logging: 118 | logging_dir: "path/to/logs/" 119 | 120 | report_to: "wandb" 121 | tracker_project_name: "text2image-dynamic-pruning" 122 | wandb_log_dir: "path/to/wandb" 123 | 124 | checkpoints_total_limit: 1 125 | auto_checkpoint_step: false 126 | resume_from_checkpoint: latest # or null 127 | 128 | -------------------------------------------------------------------------------- /configs/finetuning/sd-2-1_coco2014.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | data: 29 | dataset_name: coco 30 | data_files: null 31 | dataset_config_name: null 32 | data_dir: "path/to/coco" 33 | 34 | max_train_samples: null 35 | max_validation_samples: 10 36 | year: 2014 # 2014 or 2017. 2014 is used in the paper. 
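  # Note (based on pdm/datasets/coco.py): for 2014 the loader expects files named
  # COCO_<images_dir basename>_%012d.jpg (e.g. COCO_train2014_*.jpg); for 2017 it
  # expects plain %012d.jpg names.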
37 | 38 | image_column: "image" 39 | caption_column: "caption" 40 | 41 | prompts: null 42 | max_generated_samples: 2 43 | 44 | dataloader: 45 | dataloader_num_workers: 0 46 | train_batch_size: 128 47 | validation_batch_size: 32 48 | image_generation_batch_size: 8 49 | center_crop: false 50 | random_flip: true 51 | 52 | training: 53 | num_train_epochs: null 54 | max_train_steps: 30000 55 | hypernet_pretraining_steps: 500 56 | validation_steps: 1000 57 | image_logging_steps: 1000 58 | num_inference_steps: 50 # number of scheduler steps to run for image generation 59 | 60 | mixed_precision: null 61 | gradient_accumulation_steps: 1 62 | gradient_checkpointing: false 63 | local_rank: -1 64 | allow_tf32: false 65 | enable_xformers_memory_efficient_attention: false 66 | 67 | losses: 68 | diffusion_loss: 69 | snr_gamma: 5.0 70 | weight: 1.0 71 | 72 | resource_loss: 73 | type: log 74 | weight: 2.0 75 | pruning_target: 0.6 76 | 77 | contrastive_loss: 78 | arch_vector_temperature: 0.03 79 | prompt_embedding_temperature: 0.03 80 | weight: 100.0 81 | 82 | distillation_loss: 83 | weight: 0.2 84 | 85 | block_loss: 86 | weight: 0.2 87 | 88 | std_loss: 89 | weight: 0.1 90 | 91 | max_loss: 92 | weight: 0.1 93 | 94 | 95 | optim: 96 | unet_learning_rate: 1e-5 97 | unet_weight_decay: 0.00 98 | 99 | use_8bit_adam: false 100 | adam_beta1: 0.9 101 | adam_beta2: 0.999 102 | adam_epsilon: 1e-08 103 | 104 | scale_lr: true 105 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 106 | lr_warmup_steps: 250 107 | 108 | 109 | hf_hub: 110 | push_to_hub: false 111 | hub_token: null 112 | hub_model_id: null 113 | 114 | logging: 115 | logging_dir: "path/to/logs/" 116 | 117 | report_to: "wandb" 118 | tracker_project_name: "text2image-dynamic-pruning" 119 | wandb_log_dir: "path/to/wandb" 120 | 121 | checkpoints_total_limit: 1 122 | auto_checkpoint_step: false 123 | resume_from_checkpoint: latest # or null 124 | 125 | -------------------------------------------------------------------------------- /configs/img_generation/sd-2-1_cc3m.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | data: 29 | dataset_name: cc3m 30 | data_files: null 31 | dataset_config_name: null 32 | data_dir: "/path/to/conceptual_captions" 33 | 34 | train_data_dir: "training" 35 | train_data_file: "Train_GCC-training.tsv" 36 | max_train_samples: null 37 | 38 | validation_data_dir: "validation" 39 | validation_data_file: "Validation_GCC-1.1.0-Validation.tsv" 40 | max_validation_samples: null 41 | 42 | image_column: "image" 43 | caption_column: "caption" 44 | 45 | dataloader: 46 | dataloader_num_workers: 0 47 | image_generation_batch_size: 8 48 | 49 | training: 50 | num_inference_steps: 25 # number of scheduler steps to run for image generation 51 | local_rank: -1 52 | 
enable_xformers_memory_efficient_attention: false 53 | 54 | 55 | -------------------------------------------------------------------------------- /configs/pruning/sd-2-1_cc3m.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | hypernet: 29 | weight_norm: false 30 | linear_bias: true 31 | single_arch_param: false # if true, train a single experts for all prompts (used as a baseline in the paper) 32 | 33 | quantizer: 34 | quantizer_T: 0.4 35 | quantizer_base: 3 36 | num_arch_vq_codebook_embeddings: 8 # number of experts 37 | arch_vq_beta: 0.25 38 | depth_order: [-1, -2, 0, 1, -3, -4, 2, 3, -5, -6, 4, 5, -7, 6] 39 | non_zero_width: true 40 | resource_aware_normalization: false 41 | optimal_transport: true 42 | 43 | data: 44 | dataset_name: cc3m 45 | data_files: null 46 | dataset_config_name: null 47 | data_dir: "/path/to/conceptual_captions" 48 | 49 | train_data_dir: "training" 50 | train_data_file: "Train_GCC-training.tsv" 51 | max_train_samples: null 52 | 53 | validation_data_dir: "validation" 54 | validation_data_file: "Validation_GCC-1.1.0-Validation.tsv" 55 | max_validation_samples: 1000 56 | 57 | image_column: "image" 58 | caption_column: "caption" 59 | 60 | prompts: null 61 | max_generated_samples: 16 62 | 63 | dataloader: 64 | dataloader_num_workers: 0 65 | train_batch_size: 64 66 | validation_batch_size: 16 67 | image_generation_batch_size: 4 68 | center_crop: false 69 | random_flip: true 70 | 71 | training: 72 | num_train_epochs: null 73 | max_train_steps: 5000 74 | hypernet_pretraining_steps: 500 75 | validation_steps: 1000 76 | image_logging_steps: 1000 77 | num_inference_steps: 50 # number of scheduler steps to run for image generation 78 | 79 | mixed_precision: null 80 | gradient_accumulation_steps: 1 81 | gradient_checkpointing: false 82 | local_rank: -1 83 | allow_tf32: false 84 | enable_xformers_memory_efficient_attention: false 85 | 86 | losses: 87 | diffusion_loss: 88 | snr_gamma: 5.0 89 | weight: 1.0 90 | 91 | resource_loss: 92 | type: log 93 | weight: 2.0 94 | pruning_target: 0.6 95 | 96 | contrastive_loss: 97 | arch_vector_temperature: 0.03 98 | prompt_embedding_temperature: 0.03 99 | weight: 100.0 100 | 101 | distillation_loss: 102 | weight: 0.2 103 | 104 | block_loss: 105 | weight: 0.2 106 | 107 | std_loss: 108 | weight: 0.1 109 | 110 | max_loss: 111 | weight: 0.1 112 | 113 | 114 | optim: 115 | hypernet_learning_rate: 2e-4 116 | quantizer_learning_rate: 2e-4 117 | unet_learning_rate: 5e-5 118 | 119 | quantizer_weight_decay: 0.00 120 | hypernet_weight_decay: 0.00 121 | unet_weight_decay: 0.00 122 | 123 | use_8bit_adam: false 124 | adam_beta1: 0.9 125 | adam_beta2: 0.999 126 | adam_epsilon: 1e-08 127 | 128 | scale_lr: true 129 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 130 | lr_warmup_steps: 100 131 | 132 | 133 | 
hf_hub: 134 | push_to_hub: false 135 | hub_token: null 136 | hub_model_id: null 137 | 138 | logging: 139 | logging_dir: "path/to/logs/" 140 | 141 | report_to: "wandb" 142 | tracker_project_name: "text2image-dynamic-pruning" 143 | wandb_log_dir: "path/to/wandb" 144 | 145 | checkpoints_total_limit: 1 146 | auto_checkpoint_step: false 147 | resume_from_checkpoint: null # or latest 148 | 149 | -------------------------------------------------------------------------------- /configs/pruning/sd-2-1_coco2014.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | unet: 3 | pretrained_model_name_or_path: stabilityai/stable-unet-2-1 4 | input_perturbation: 0.0 5 | revision: null 6 | resolution: 256 7 | use_ema: false 8 | noise_offset: 0.0 9 | prediction_type: v_prediction 10 | max_scheduler_steps: null 11 | unet_down_blocks: 12 | - CrossAttnDownBlock2DHalfGated 13 | - CrossAttnDownBlock2DHalfGated 14 | - CrossAttnDownBlock2DHalfGated 15 | - DownBlock2DHalfGated 16 | 17 | unet_mid_block: UNetMidBlock2DCrossAttnWidthGated 18 | 19 | unet_up_blocks: 20 | - UpBlock2DHalfGated 21 | - CrossAttnUpBlock2DHalfGated 22 | - CrossAttnUpBlock2DHalfGated 23 | - CrossAttnUpBlock2DHalfGated 24 | 25 | gated_ff: true 26 | ff_gate_width: 32 27 | 28 | hypernet: 29 | weight_norm: false 30 | linear_bias: true 31 | single_arch_param: false # if true, train a single experts for all prompts (used as a baseline in the paper) 32 | 33 | quantizer: 34 | quantizer_T: 0.4 35 | quantizer_base: 3 36 | num_arch_vq_codebook_embeddings: 8 # number of experts 37 | arch_vq_beta: 0.25 38 | depth_order: [-1, -2, 0, 1, -3, -4, 2, 3, -5, -6, 4, 5, -7, 6] 39 | non_zero_width: true 40 | resource_aware_normalization: false 41 | optimal_transport: true 42 | 43 | data: 44 | dataset_name: coco 45 | data_files: null 46 | dataset_config_name: null 47 | data_dir: "path/to/coco" 48 | 49 | max_train_samples: null 50 | max_validation_samples: 10 51 | year: 2014 # 2014 or 2017. 2014 is used in the paper. 
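  # Illustrative layout only (an assumption, not verified against the dataset scripts):
  #   <data_dir>/train2014/COCO_train2014_000000000009.jpg
  #   <data_dir>/annotations/captions_train2014.json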
52 | 53 | image_column: "image" 54 | caption_column: "caption" 55 | 56 | prompts: null 57 | max_generated_samples: 2 58 | 59 | dataloader: 60 | dataloader_num_workers: 0 61 | train_batch_size: 64 62 | validation_batch_size: 16 63 | image_generation_batch_size: 4 64 | center_crop: false 65 | random_flip: true 66 | 67 | training: 68 | num_train_epochs: null 69 | max_train_steps: 5000 70 | hypernet_pretraining_steps: 500 71 | validation_steps: 1000 72 | image_logging_steps: 1000 73 | num_inference_steps: 50 # number of scheduler steps to run for image generation 74 | 75 | mixed_precision: null 76 | gradient_accumulation_steps: 1 77 | gradient_checkpointing: false 78 | local_rank: -1 79 | allow_tf32: false 80 | enable_xformers_memory_efficient_attention: false 81 | 82 | losses: 83 | diffusion_loss: 84 | snr_gamma: 5.0 85 | weight: 1.0 86 | 87 | resource_loss: 88 | type: log 89 | weight: 2.0 90 | pruning_target: 0.6 91 | 92 | contrastive_loss: 93 | arch_vector_temperature: 0.03 94 | prompt_embedding_temperature: 0.03 95 | weight: 100.0 96 | 97 | distillation_loss: 98 | weight: 0.2 99 | 100 | block_loss: 101 | weight: 0.2 102 | 103 | std_loss: 104 | weight: 0.1 105 | 106 | max_loss: 107 | weight: 0.1 108 | 109 | 110 | optim: 111 | hypernet_learning_rate: 2e-4 112 | quantizer_learning_rate: 2e-4 113 | unet_learning_rate: 5e-5 114 | 115 | quantizer_weight_decay: 0.00 116 | hypernet_weight_decay: 0.00 117 | unet_weight_decay: 0.00 118 | 119 | use_8bit_adam: false 120 | adam_beta1: 0.9 121 | adam_beta2: 0.999 122 | adam_epsilon: 1e-08 123 | 124 | scale_lr: true 125 | lr_scheduler: "constant_with_warmup" # see pdm.utils.arg_utils for available options 126 | lr_warmup_steps: 100 127 | 128 | 129 | hf_hub: 130 | push_to_hub: false 131 | hub_token: null 132 | hub_model_id: null 133 | 134 | logging: 135 | logging_dir: "path/to/logs/" 136 | 137 | report_to: "wandb" 138 | tracker_project_name: "text2image-dynamic-pruning" 139 | wandb_log_dir: "path/to/wandb" 140 | 141 | checkpoints_total_limit: 1 142 | auto_checkpoint_step: false 143 | resume_from_checkpoint: null # or latest 144 | 145 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: pdm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotli-python=1.0.9=py311h6a678d5_7 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2023.12.12=h06a4308_0 13 | - certifi=2023.11.17=py311h06a4308_0 14 | - cffi=1.16.0=py311h5eee18b_0 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cryptography=41.0.7=py311hdda0065_0 17 | - cuda-cudart=11.8.89=0 18 | - cuda-cupti=11.8.87=0 19 | - cuda-libraries=11.8.0=0 20 | - cuda-nvrtc=11.8.89=0 21 | - cuda-nvtx=11.8.86=0 22 | - cuda-runtime=11.8.0=0 23 | - ffmpeg=4.3=hf484d3e_0 24 | - filelock=3.13.1=py311h06a4308_0 25 | - freetype=2.12.1=h4a9f257_0 26 | - giflib=5.2.1=h5eee18b_3 27 | - gmp=6.2.1=h295c915_3 28 | - gmpy2=2.1.2=py311hc9b5ff0_0 29 | - gnutls=3.6.15=he1e5248_0 30 | - idna=3.4=py311h06a4308_0 31 | - intel-openmp=2023.1.0=hdb19cb5_46306 32 | - jinja2=3.1.2=py311h06a4308_0 33 | - jpeg=9e=h5eee18b_1 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.38=h1181459_1 37 | - lerc=3.0=h295c915_0 38 | - libcublas=11.11.3.6=0 39 | - libcufft=10.9.0.58=0 40 | - libcufile=1.8.1.2=0 41 | - libcurand=10.3.4.101=0 42 | - libcusolver=11.4.1.48=0 43 | - 
libcusparse=11.7.5.86=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_0 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libiconv=1.16=h7f8727e_2 49 | - libidn2=2.3.4=h5eee18b_0 50 | - libjpeg-turbo=2.0.0=h9bf148f_0 51 | - libnpp=11.8.0.86=0 52 | - libnvjpeg=11.9.0.86=0 53 | - libpng=1.6.39=h5eee18b_0 54 | - libstdcxx-ng=11.2.0=h1234567_1 55 | - libtasn1=4.19.0=h5eee18b_0 56 | - libtiff=4.5.1=h6a678d5_0 57 | - libunistring=0.9.10=h27cfd23_0 58 | - libuuid=1.41.5=h5eee18b_0 59 | - libwebp=1.3.2=h11a3e52_0 60 | - libwebp-base=1.3.2=h5eee18b_0 61 | - llvm-openmp=14.0.6=h9e868ea_0 62 | - lz4-c=1.9.4=h6a678d5_0 63 | - markupsafe=2.1.1=py311h5eee18b_0 64 | - mkl=2023.1.0=h213fc3f_46344 65 | - mkl-service=2.4.0=py311h5eee18b_1 66 | - mkl_fft=1.3.8=py311h5eee18b_0 67 | - mkl_random=1.2.4=py311hdb19cb5_0 68 | - mpc=1.1.0=h10f8cd9_1 69 | - mpfr=4.0.2=hb69a4c5_1 70 | - mpmath=1.3.0=py311h06a4308_0 71 | - ncurses=6.4=h6a678d5_0 72 | - nettle=3.7.3=hbbd107a_1 73 | - networkx=3.1=py311h06a4308_0 74 | - numpy=1.26.2=py311h08b1b3b_0 75 | - numpy-base=1.26.2=py311hf175353_0 76 | - openh264=2.1.1=h4ff587b_0 77 | - openjpeg=2.4.0=h3ad879b_0 78 | - openssl=3.0.12=h7f8727e_0 79 | - pip=23.3.1=py311h06a4308_0 80 | - pycparser=2.21=pyhd3eb1b0_0 81 | - pyopenssl=23.2.0=py311h06a4308_0 82 | - pysocks=1.7.1=py311h06a4308_0 83 | - python=3.11.5=h955ad1f_0 84 | - pytorch-cuda=11.8=h7e8668a_5 85 | - pytorch-mutex=1.0=cuda 86 | - pyyaml=6.0.1=py311h5eee18b_0 87 | - readline=8.2=h5eee18b_0 88 | - requests=2.31.0=py311h06a4308_0 89 | - setuptools=68.2.2=py311h06a4308_0 90 | - sqlite=3.41.2=h5eee18b_0 91 | - sympy=1.12=py311h06a4308_0 92 | - tbb=2021.8.0=hdb19cb5_0 93 | - tk=8.6.12=h1ccaba5_0 94 | - typing_extensions=4.7.1=py311h06a4308_0 95 | - urllib3=1.26.18=py311h06a4308_0 96 | - wheel=0.41.2=py311h06a4308_0 97 | - xz=5.4.5=h5eee18b_0 98 | - yaml=0.2.5=h7b6447c_0 99 | - zlib=1.2.13=h5eee18b_0 100 | - zstd=1.5.5=hc292b87_0 101 | - pip: 102 | - accelerate==0.24.1 103 | - aiohttp==3.9.1 104 | - aiosignal==1.3.1 105 | - antlr4-python3-runtime==4.9.3 106 | - appdirs==1.4.4 107 | - attrs==23.1.0 108 | - bitsandbytes==0.41.3.post2 109 | - click==8.1.7 110 | - cmake==3.25.0 111 | - contourpy==1.2.0 112 | - cycler==0.12.1 113 | - datasets==2.14.5 114 | - diffusers==0.23.1 115 | - dill==0.3.7 116 | - docker-pycreds==0.4.0 117 | - fonttools==4.50.0 118 | - frozenlist==1.4.1 119 | - fsspec==2023.6.0 120 | - gitdb==4.0.11 121 | - gitpython==3.1.40 122 | - huggingface-hub==0.17.3 123 | - importlib-metadata==7.0.1 124 | - kiwisolver==1.4.5 125 | - lit==15.0.7 126 | - matplotlib==3.8.3 127 | - multidict==6.0.4 128 | - multiprocess==0.70.15 129 | - nvidia-cublas-cu12==12.1.3.1 130 | - nvidia-cuda-cupti-cu12==12.1.105 131 | - nvidia-cuda-nvrtc-cu12==12.1.105 132 | - nvidia-cuda-runtime-cu12==12.1.105 133 | - nvidia-cudnn-cu12==8.9.2.26 134 | - nvidia-cufft-cu12==11.0.2.54 135 | - nvidia-curand-cu12==10.3.2.106 136 | - nvidia-cusolver-cu12==11.4.5.107 137 | - nvidia-cusparse-cu12==12.1.0.106 138 | - nvidia-nccl-cu12==2.18.1 139 | - nvidia-nvjitlink-cu12==12.3.101 140 | - nvidia-nvtx-cu12==12.1.105 141 | - omegaconf==2.3.0 142 | - packaging==23.2 143 | - pandas==2.1.4 144 | - pdm==0.1.0 145 | - pillow==10.1.0 146 | - protobuf==4.25.1 147 | - psutil==5.9.7 148 | - pyarrow==14.0.2 149 | - pyparsing==3.1.2 150 | - python-dateutil==2.8.2 151 | - python-magic==0.4.27 152 | - pytz==2023.3.post1 153 | - regex==2023.12.25 154 | - safetensors==0.4.1 155 | - seaborn==0.13.2 156 | - sentry-sdk==1.39.1 157 
| - setproctitle==1.3.3 158 | - six==1.16.0 159 | - smmap==5.0.1 160 | - tokenizers==0.14.1 161 | - torch==2.1.0 162 | - torchaudio==2.1.0 163 | - torchvision==0.16.0 164 | - tqdm==4.66.1 165 | - transformers==4.34.1 166 | - triton==2.1.0 167 | - tzdata==2023.3 168 | - wandb==0.16.1 169 | - xformers==0.0.22.post7 170 | - xxhash==3.4.1 171 | - yarl==1.9.4 172 | - zipp==3.17.0 -------------------------------------------------------------------------------- /pdm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.2.0" 2 | -------------------------------------------------------------------------------- /pdm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cc3m import load_cc3m_dataset 2 | from .coco import load_coco_dataset 3 | -------------------------------------------------------------------------------- /pdm/datasets/cc3m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pandas as pd 4 | from PIL import ImageFile 5 | from datasets import Dataset 6 | 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def load_cc3m_dataset(data_dir, split="train", split_file="Train_GCC-training.tsv", 11 | split_dir="training"): 12 | captions = pd.read_csv(os.path.join(data_dir, split_file), 13 | sep="\t", header=None, names=["caption", "link"], 14 | dtype={"caption": str, "link": str}) 15 | 16 | names_file = os.path.join(os.getcwd(), "../data", f"{split}_cc3m_names.pkl") 17 | if os.path.exists(names_file): 18 | with open(names_file, 'rb') as file: 19 | images = pickle.load(file) 20 | 21 | else: 22 | images = os.listdir(os.path.join(data_dir, split_dir)) 23 | with open(names_file, 'wb') as file: 24 | pickle.dump(images, file) 25 | 26 | images = [os.path.join(data_dir, split_dir, image) for image in images] 27 | 28 | image_indices = [int(os.path.basename(image).split("_")[0]) for image in images] 29 | captions = captions.iloc[image_indices].caption.values.tolist() 30 | dataset = Dataset.from_dict({"image": images, "caption": captions}) 31 | return dataset 32 | -------------------------------------------------------------------------------- /pdm/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from datasets import Dataset 5 | 6 | 7 | def load_coco_dataset(images_dir, annotations_file): 8 | captions_file = json.load(open(annotations_file)) 9 | images = [] 10 | captions = [] 11 | for capt in captions_file['annotations']: 12 | if '2014' in images_dir: 13 | split_name = os.path.basename(images_dir) 14 | image_path = os.path.join(images_dir, f"COCO_{split_name}_%012d.jpg" % capt['image_id']) 15 | else: 16 | image_path = os.path.join(images_dir, "%012d.jpg" % capt['image_id']) 17 | caption = capt['caption'] 18 | images.append(image_path) 19 | captions.append(caption) 20 | dataset = Dataset.from_dict({"image": images, "caption": captions}) 21 | return dataset -------------------------------------------------------------------------------- /pdm/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .contrastive_loss import ContrastiveLoss 2 | from .resource_loss import ResourceLoss 3 | -------------------------------------------------------------------------------- /pdm/losses/contrastive_loss.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ContrastiveLoss(nn.Module): 6 | def __init__(self, arch_vector_temperature=1.0, prompt_embedding_temperature=1.0): 7 | super().__init__() 8 | self.arch_vector_temperature = arch_vector_temperature 9 | self.prompt_embedding_temperature = prompt_embedding_temperature 10 | 11 | def forward(self, prompt_embeddings, arch_vectors, return_similarity=False): 12 | arch_vectors_normalized = arch_vectors / arch_vectors.norm(dim=1, keepdim=True) 13 | prompt_embeddings = prompt_embeddings / prompt_embeddings.norm(dim=1, keepdim=True) 14 | arch_vectors_similarity = F.softmax( 15 | (arch_vectors_normalized @ arch_vectors_normalized.T) / self.arch_vector_temperature, dim=-1) 16 | texts_similarity = F.softmax((prompt_embeddings @ prompt_embeddings.T) / self.prompt_embedding_temperature, 17 | dim=-1) 18 | loss = F.binary_cross_entropy(arch_vectors_similarity.T, texts_similarity.T, reduction='mean') 19 | if return_similarity: 20 | return loss, arch_vectors_similarity.detach().cpu().numpy() 21 | else: 22 | return loss 23 | -------------------------------------------------------------------------------- /pdm/losses/resource_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResourceLoss(nn.Module): 6 | def __init__(self, p=0.9, loss_type='log'): 7 | super().__init__() 8 | assert loss_type in ["log", "mae", "mse"], f"Unknown loss type {loss_type}" 9 | self.p = p 10 | self.loss_type = loss_type 11 | 12 | def forward(self, resource_ratio): 13 | if self.loss_type == "log": 14 | if resource_ratio > self.p: 15 | resource_loss = torch.log(resource_ratio / self.p) 16 | else: 17 | resource_loss = torch.log(self.p / resource_ratio) 18 | elif self.loss_type == "mae": 19 | resource_loss = torch.abs(resource_ratio - self.p) 20 | else: 21 | resource_loss = (resource_ratio - self.p) ** 2 22 | 23 | return resource_loss 24 | -------------------------------------------------------------------------------- /pdm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet.unet_2d_conditional import UNet2DConditionModelGated, UNet2DConditionModelPruned 2 | from .hypernet.hypernet import HyperStructure 3 | from .vq.quantizer import StructureVectorQuantizer 4 | -------------------------------------------------------------------------------- /pdm/models/hypernet/__init__.py: -------------------------------------------------------------------------------- 1 | from pdm.models.hypernet.hypernet import HyperStructure 2 | -------------------------------------------------------------------------------- /pdm/models/hypernet/hypernet.py: -------------------------------------------------------------------------------- 1 | """ 2 | credit to: 3 | https://github.com/Alii-Ganjj/InterpretationsSteeredPruning/blob/96e5be3c721714bda76a4ab814ff5fe0ddf4417d/Models/hypernet.py (thanks!) 
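HyperStructure maps prompt embeddings to a flat architecture vector holding the width- and
depth-gate logits of the gated U-Net. Minimal sketch (shapes and the contents of `structure`
are assumptions; in this repo the dict is derived from the gated U-Net):

    hypernet = HyperStructure(structure, input_dim=768)
    arch_vectors = hypernet(prompt_embeds)          # (batch, sum(widths) + sum(depths))
    gates = hypernet.transform_structure_vector(arch_vectors)  # {"width": [...], "depth": [...]}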
4 | """ 5 | 6 | from __future__ import absolute_import 7 | 8 | import os 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.utils.parametrizations import weight_norm 14 | from diffusers import ModelMixin, ConfigMixin 15 | from diffusers.configuration_utils import register_to_config 16 | from pdm.utils.estimation_utils import hard_concrete 17 | 18 | 19 | class SimpleGate(nn.Module): 20 | def __init__(self, width): 21 | super(SimpleGate, self).__init__() 22 | self.weight = nn.Parameter(torch.randn(width)) 23 | 24 | def forward(self): 25 | return self.weight 26 | 27 | 28 | class HyperStructure(ModelMixin, ConfigMixin): 29 | @register_to_config 30 | def __init__(self, structure, input_dim=768, wn_flag=True, linear_bias=False, single_arch_param=False): 31 | super(HyperStructure, self).__init__() 32 | 33 | self.structure = structure 34 | self.input_dim = input_dim 35 | self.linear_bias = linear_bias 36 | self.wn_flag = wn_flag 37 | 38 | self.width_list = [w for sub_width_list in self.structure['width'] for w in sub_width_list] 39 | self.depth_list = [d for sub_depth_list in self.structure['depth'] for d in sub_depth_list] 40 | 41 | self.single_arch_param = single_arch_param 42 | if self.single_arch_param: 43 | self.arch = nn.Parameter(torch.randn(1, sum(self.width_list) + sum(self.depth_list))) 44 | self.arch_gs = torch.zeros(1, sum(self.width_list) + sum(self.depth_list)) 45 | else: 46 | width_linear_list = [nn.Linear(self.input_dim, self.width_list[i], bias=self.linear_bias) for i in 47 | range(len(self.width_list))] 48 | depth_linear_layer = nn.Linear(self.input_dim, sum(self.depth_list), bias=self.linear_bias) 49 | 50 | linear_list = width_linear_list + [depth_linear_layer] 51 | 52 | if self.wn_flag: 53 | linear_list = [weight_norm(linear) for linear in linear_list] 54 | 55 | self.mh_fc = torch.nn.ModuleList(linear_list) 56 | self.initialize_weights() 57 | 58 | def initialize_weights(self): 59 | for name, param in self.named_parameters(): 60 | if "weight" in name: 61 | nn.init.orthogonal_(param) 62 | elif "bias" in name: 63 | nn.init.zeros_(param) 64 | 65 | def forward(self, x): 66 | if self.single_arch_param: 67 | # repeat the same architecture for all samples in the batch 68 | return self.arch 69 | else: 70 | return self._forward(x) 71 | 72 | def _forward(self, x): 73 | # x: B * L * D 74 | if self.mh_fc[0].weight.is_cuda: 75 | x = x.cuda() 76 | outputs = [self.mh_fc[i](x) for i in range(len(self.mh_fc))] 77 | out = torch.cat(outputs, dim=1) 78 | 79 | return out 80 | 81 | def print_param_stats(self): 82 | for name, param in self.named_parameters(): 83 | if "weight" in name: 84 | print(f"{name}: {param.mean()}, {param.std()}") 85 | 86 | def transform_structure_vector(self, inputs): 87 | assert inputs.shape[1] == (sum(self.width_list) + sum(self.depth_list)) 88 | width_list = [] 89 | depth_list = [] 90 | width_vectors = inputs[:, :sum(self.width_list)] 91 | depth_vectors = inputs[:, sum(self.width_list):] 92 | start = 0 93 | for i in range(len(self.width_list)): 94 | end = start + self.width_list[i] 95 | width_list.append(width_vectors[:, start:end]) 96 | start = end 97 | 98 | for i in range(sum(self.depth_list)): 99 | depth_list.append(depth_vectors[:, i]) 100 | 101 | return {"width": width_list, "depth": depth_list} 102 | 103 | @classmethod 104 | def transform_arch_vector(cls, inputs, structure, force_width_non_zero=False): 105 | width_list = [w for sub_width_list in structure['width'] for w in sub_width_list] 106 | depth_list = [d for 
sub_depth_list in structure['depth'] for d in sub_depth_list] 107 | assert inputs.shape[1] == (sum(width_list) + sum(depth_list)) 108 | width_vectors = inputs[:, :sum(width_list)] 109 | depth_vectors = inputs[:, sum(width_list):] 110 | start = 0 111 | w_list = [] 112 | d_list = [] 113 | for i in range(len(width_list)): 114 | end = start + width_list[i] 115 | w_sub_list = width_vectors[:, start:end] 116 | # This shouldn't be necessary, but just in case 117 | if force_width_non_zero: 118 | w_sub_list_sum = hard_concrete(w_sub_list).sum(dim=1) 119 | if not w_sub_list_sum.all(): 120 | ind = (w_sub_list_sum == 0) 121 | w_sub_list = w_sub_list.clone() 122 | w_sub_list[ind, 0] = w_sub_list[ind, 0] + 0.5 123 | w_list.append(w_sub_list) 124 | start = end 125 | 126 | for i in range(sum(depth_list)): 127 | d_list.append(depth_vectors[:, i]) 128 | 129 | return {"width": w_list, "depth": d_list} 130 | 131 | @classmethod 132 | def get_random_arch_vector(cls, target_ratio, structure): 133 | # randomly generate the width and depth vectors so each sublist has target_ratio of elements greater than 0.5 134 | width_list = [w for sub_width_list in structure['width'] for w in sub_width_list] 135 | depth_list = [d for sub_depth_list in structure['depth'] for d in sub_depth_list] 136 | arch_vectors = [] 137 | start = 0 138 | for i in range(len(width_list)): 139 | end = start + width_list[i] 140 | w_sub_list = torch.zeros(1, width_list[i]) 141 | num_non_zero = int(target_ratio * width_list[i]) 142 | # randomly select num_non_zero indices to set to 1 143 | non_zero_indices = torch.randperm(width_list[i])[:num_non_zero] 144 | w_sub_list[0, non_zero_indices] = 0.9 145 | 146 | arch_vectors.append(w_sub_list) 147 | start = end 148 | 149 | for i in range(sum(depth_list)): 150 | arch_vectors.append(torch.tensor([[0.9]])) 151 | 152 | arch_vectors = torch.cat(arch_vectors, dim=1) 153 | return arch_vectors 154 | -------------------------------------------------------------------------------- /pdm/models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_2d_conditional import UNet2DConditionModelGated, UNet2DConditionModelPruned, UNet2DConditionModelMagnitudePruned 2 | from .blocks import * 3 | from .gates import DepthGate, WidthGate 4 | -------------------------------------------------------------------------------- /pdm/models/unet/gates.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script contains the implementations of gate functions and their gradient calculation. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.autograd.function 7 | 8 | 9 | class VirtualGate(nn.Module): 10 | def __init__(self, width, bs=1): 11 | super(VirtualGate, self).__init__() 12 | self.width = width 13 | self.gate_f = torch.ones(bs, width) 14 | 15 | def forward(self, x): 16 | mask = self.gate_f.repeat_interleave(x.shape[1] // self.width, dim=1).unsqueeze(-1).unsqueeze(-1) 17 | # to handle cfg where actual batch size is double the value of the batch size used to create the gate. 
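        # Illustrative shapes: with gate_f of shape (2, 320) and x of shape (4, 320, 32, 32)
        # under classifier-free guidance, mask starts as (2, 320, 1, 1) and is tiled below to
        # (4, 320, 1, 1) before being broadcast over x.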
18 | if mask.shape[0] != x.shape[0]: 19 | mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1, 1) 20 | x = mask.expand_as(x) * x 21 | return x 22 | 23 | def set_structure_value(self, value): 24 | self.gate_f = value 25 | 26 | 27 | class WidthGate(VirtualGate): 28 | def __init__(self, width): 29 | super(WidthGate, self).__init__(width) 30 | 31 | 32 | class DepthGate(VirtualGate): 33 | def __init__(self, width): 34 | super(DepthGate, self).__init__(width) 35 | 36 | def forward(self, x): 37 | input_hidden_states, output_tensor = x 38 | mask = self.gate_f.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 39 | if mask.shape[0] != output_tensor.shape[0]: 40 | mask = mask.repeat(output_tensor.shape[0] // mask.shape[0], 1, 1, 1) 41 | output = (1 - mask) * input_hidden_states + mask * output_tensor 42 | return output 43 | 44 | 45 | class LinearWidthGate(WidthGate): 46 | def __init__(self, width): 47 | super(LinearWidthGate, self).__init__(width) 48 | 49 | def forward(self, x): 50 | mask = self.gate_f.repeat_interleave(x.shape[-1] // self.width, dim=1).unsqueeze(1) 51 | # to handle cfg where actual batch size is double the value of the batch size used to create the gate. 52 | if mask.shape[0] != x.shape[0]: 53 | mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1) 54 | x = mask.expand_as(x) * x 55 | return x 56 | -------------------------------------------------------------------------------- /pdm/models/vq/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantizer import StructureVectorQuantizer -------------------------------------------------------------------------------- /pdm/models/vq/quantizer.py: -------------------------------------------------------------------------------- 1 | # credit to taming-transformers: 2 | # https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L112 3 | import os 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from diffusers.configuration_utils import register_to_config, ConfigMixin 9 | from torch import nn, Tensor 10 | from pdm.utils.estimation_utils import gumbel_softmax_sample, hard_concrete, importance_gumbel_softmax_sample 11 | from diffusers import ModelMixin 12 | import torch.distributed as dist 13 | 14 | 15 | class StructureVectorQuantizer(ModelMixin, ConfigMixin): 16 | """ 17 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix 18 | multiplications and allows for post-hoc remapping of indices. 19 | """ 20 | 21 | # NOTE: due to a bug the beta term was applied to the wrong term. for 22 | # backwards compatibility we use the buggy version by default, but you can 23 | # specify legacy=False to fix it. 
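    # In the pruning configs, `n_e` corresponds to `num_arch_vq_codebook_embeddings` (the
    # number of architecture "experts") and `structure` is the same width/depth dictionary
    # used by HyperStructure; each codebook row is one candidate sub-architecture.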
24 | @register_to_config 25 | def __init__( 26 | self, 27 | n_e: int, 28 | structure: dict, 29 | beta: float = 0.25, 30 | remap=None, 31 | unknown_index: str = "random", 32 | sane_index_shape: bool = True, 33 | temperature: float = 0.4, 34 | base: int = 2, 35 | depth_order: list = None, 36 | non_zero_width: bool = True, 37 | sinkhorn_epsilon: float = 0.05, 38 | sinkhorn_iterations: int = 3, 39 | resource_aware_normalization: bool = True, 40 | optimal_transport: bool = True, 41 | ): 42 | super().__init__() 43 | 44 | vq_embed_dim = 0 45 | for w_config, d_config in zip(structure['width'], structure['depth']): 46 | vq_embed_dim += sum(w_config) 47 | if d_config == [1]: 48 | # depth_indices.append(vq_embed_dim) 49 | vq_embed_dim += 1 50 | 51 | self.n_e = n_e 52 | self.vq_embed_dim = vq_embed_dim 53 | self.beta = beta 54 | 55 | self.structure = structure 56 | 57 | self.width_list = [w for sub_width_list in self.structure['width'] for w in sub_width_list] 58 | self.width_list_sum = [sum(sub_width_list) for sub_width_list in self.structure['width']] 59 | width_indices = [0] + np.cumsum(self.width_list_sum).tolist() 60 | self.width_intervals = [(width_indices[i], width_indices[i + 1]) for i in range(len(width_indices) - 1)] 61 | 62 | self.depth_list = [d for sub_depth_list in self.structure['depth'] for d in sub_depth_list] 63 | widths_sum = sum(self.width_list) - 1 64 | self.depth_indices = (widths_sum + np.cumsum(self.depth_list)).tolist() 65 | 66 | num_depth_block = sum(self.depth_list) 67 | if depth_order is None: 68 | depth_order = list(range(num_depth_block)) 69 | self.input_depth_order = depth_order 70 | self.depth_order = [i % num_depth_block for i in depth_order] 71 | 72 | template = torch.tensor(self.width_list + [d for d in self.depth_list if d != 0]) 73 | template = torch.repeat_interleave(template, template).type(torch.float32) 74 | self.template = (1.0 / template).requires_grad_(False) 75 | self.prunable_macs_template = None 76 | if resource_aware_normalization is None: 77 | resource_aware_normalization = True 78 | self.resource_effect_normalization = resource_aware_normalization 79 | 80 | self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) 81 | nn.init.orthogonal_(self.embedding.weight) 82 | self.embedding_gs = nn.Parameter(self.embedding.weight.detach().clone(), requires_grad=False) 83 | 84 | self.remap = remap 85 | if self.remap is not None: 86 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 87 | self.used: torch.Tensor 88 | self.re_embed = self.used.shape[0] 89 | self.unknown_index = unknown_index # "random" or "extra" or integer 90 | if self.unknown_index == "extra": 91 | self.unknown_index = self.re_embed 92 | self.re_embed = self.re_embed + 1 93 | print( 94 | f"Remapping {self.n_e} indices to {self.re_embed} indices." 95 | f"Using {self.unknown_index} for unknown indices." 96 | ) 97 | else: 98 | self.re_embed = n_e 99 | 100 | self.sane_index_shape = sane_index_shape 101 | 102 | self.temperature = temperature 103 | self.base = base 104 | 105 | # Avoid cases that width gets zero but we don't actually want to remove the whole block. 
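        # When enabled, gumbel_sigmoid_trick() below forwards this flag to
        # gumbel_softmax_sample() as force_width_non_zero, analogous to
        # HyperStructure.transform_arch_vector keeping at least one channel per width group.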
106 | self.non_zero_width = non_zero_width 107 | 108 | self.optimal_transport = optimal_transport 109 | self.sinkhorn_epsilon = sinkhorn_epsilon 110 | self.sinkhorn_iterations = sinkhorn_iterations 111 | 112 | def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: 113 | ishape = inds.shape 114 | assert len(ishape) > 1 115 | inds = inds.reshape(ishape[0], -1) 116 | used = self.used.to(inds) 117 | match = (inds[:, :, None] == used[None, None, ...]).long() 118 | new = match.argmax(-1) 119 | unknown = match.sum(2) < 1 120 | if self.unknown_index == "random": 121 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) 122 | else: 123 | new[unknown] = self.unknown_index 124 | return new.reshape(ishape) 125 | 126 | def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: 127 | ishape = inds.shape 128 | assert len(ishape) > 1 129 | inds = inds.reshape(ishape[0], -1) 130 | used = self.used.to(inds) 131 | if self.re_embed > self.used.shape[0]: # extra token 132 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 133 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 134 | return back.reshape(ishape) 135 | 136 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: 137 | z = z.contiguous() 138 | z_flattened = z.view(-1, self.vq_embed_dim) 139 | 140 | if self.training: 141 | embedding_gs = self.gumbel_sigmoid_trick(self.embedding.weight) 142 | self.embedding_gs.data = embedding_gs.detach() 143 | if self.optimal_transport: 144 | min_encoding_indices = self.get_optimal_transport_min_encoding_indices(z_flattened) 145 | else: 146 | min_encoding_indices = self.get_cosine_sim_min_encoding_indices(z_flattened) 147 | else: 148 | embedding_gs = self.embedding_gs.detach() 149 | min_encoding_indices = self.get_cosine_sim_min_encoding_indices(z_flattened) 150 | 151 | z_q = embedding_gs[min_encoding_indices].view(z.shape) 152 | 153 | z_q_out = z_q.contiguous() 154 | 155 | perplexity = None 156 | min_encodings = None 157 | 158 | if self.remap is not None: 159 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis 160 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 161 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten 162 | 163 | if self.sane_index_shape: 164 | min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0]) 165 | 166 | if not self.training: 167 | z_q_out = hard_concrete(z_q) 168 | 169 | return z_q_out, (perplexity, min_encodings, min_encoding_indices) 170 | 171 | def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...] = None) -> torch.Tensor: 172 | # shape specifying (batch, dim) 173 | if self.remap is not None: 174 | indices = indices.reshape(shape[0], -1) # add batch axis 175 | indices = self.unmap_to_all(indices) 176 | indices = indices.reshape(-1) # flatten again 177 | 178 | # get quantized latent vectors 179 | z_q: torch.Tensor = self.embedding(indices) 180 | 181 | if shape is not None: 182 | z_q = z_q.view(shape) 183 | # reshape back to match original input shape 184 | z_q = z_q.contiguous() 185 | 186 | return z_q 187 | 188 | def get_codebook_entry_gumbel_sigmoid(self, indices: torch.LongTensor, shape: Tuple[int, ...] 
= None, 189 | hard=False) -> torch.Tensor: 190 | z_q = self.get_codebook_entry(indices, shape).contiguous() 191 | if hard: 192 | return hard_concrete(self.gumbel_sigmoid_trick(z_q)) 193 | else: 194 | return self.gumbel_sigmoid_trick(z_q) 195 | 196 | def gumbel_sigmoid_trick(self, z_q: torch.Tensor): 197 | num_width = sum(self.width_list) 198 | z_q_width = z_q[:, :num_width] 199 | z_q_depth = z_q[:, num_width:] 200 | 201 | fixed_seed = not self.training 202 | 203 | z_q_depth_b_ = importance_gumbel_softmax_sample(z_q_depth, temperature=self.temperature, offset=self.base, 204 | fixed_seed=fixed_seed) 205 | z_q_depth_b = torch.zeros_like(z_q_depth_b_, device=z_q_depth_b_.device) 206 | z_q_depth_b[:, self.depth_order] = z_q_depth_b_ 207 | 208 | z_q_width_list = self._transform_width_vector(z_q_width) 209 | z_q_width_b_list = [gumbel_softmax_sample(zw, temperature=self.temperature, offset=self.base, 210 | force_width_non_zero=self.non_zero_width, fixed_seed=fixed_seed) for 211 | zw in z_q_width_list] 212 | z_q_width_b = torch.cat(z_q_width_b_list, dim=1) 213 | 214 | z_q_out = torch.cat([z_q_width_b, z_q_depth_b], dim=1) 215 | return z_q_out 216 | 217 | def print_param_stats(self): 218 | for name, param in self.named_parameters(): 219 | if "weight" in name: 220 | print(f"{name}: {param.mean()}, {param.std()}") 221 | 222 | def _transform_width_vector(self, inputs): 223 | assert inputs.shape[1] == sum(self.width_list) 224 | arch_vector = [] 225 | start = 0 226 | for i in range(len(self.width_list)): 227 | end = start + self.width_list[i] 228 | arch_vector.append(inputs[:, start:end]) 229 | start = end 230 | 231 | return arch_vector 232 | 233 | def width_depth_normalize(self, inputs): 234 | self.template = self.template.to(inputs.device) 235 | if self.resource_effect_normalization: 236 | self.prunable_macs_template = self.prunable_macs_template.to(inputs.device) 237 | # Multiply the slice of the arch_vectors defined by the start and end index of the width of the block with the 238 | # corresponding depth element of the arch_vectors. 
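        # The gated vector is then rescaled entry-wise by sqrt(1 / group_size) taken
        # from self.template and, when resource_effect_normalization is enabled, by
        # the per-group prunable-MACs weights in self.prunable_macs_template, before
        # it is used for the cosine / optimal-transport codebook matching below.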
239 | inputs_clone = hard_concrete(inputs.clone()) 240 | for i, elem in enumerate(self.depth_list): 241 | if elem != 0: 242 | inputs_clone[:, self.width_intervals[i][0]:self.width_intervals[i][1]] = ( 243 | inputs[:, self.width_intervals[i][0]:self.width_intervals[i][1]] * 244 | inputs[:, self.depth_indices[i]:(self.depth_indices[i] + 1)]) 245 | 246 | outputs = inputs_clone * (torch.sqrt(self.template).detach()) 247 | if self.resource_effect_normalization: 248 | outputs = outputs * self.prunable_macs_template.detach() 249 | 250 | return outputs 251 | 252 | def set_prunable_macs_template(self, prunable_macs_list): 253 | depth_template = [] 254 | for i, elem in enumerate(self.depth_list): 255 | if elem == 1: 256 | depth_template.append([sum(prunable_macs_list[i])]) 257 | prunable_macs_list += depth_template 258 | prunable_macs_list = [item for sublist in prunable_macs_list for item in sublist] 259 | self.prunable_macs_template = torch.repeat_interleave(torch.tensor(prunable_macs_list), 260 | torch.tensor(self.width_list + [1 for _ in range( 261 | len(depth_template))])) 262 | 263 | @torch.no_grad() 264 | def get_cosine_sim_min_encoding_indices(self, z: torch.Tensor) -> torch.Tensor: 265 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 266 | u = self.width_depth_normalize(self.gumbel_sigmoid_trick(z)) 267 | u = u / u.norm(dim=-1, keepdim=True) 268 | v = self.width_depth_normalize(self.embedding_gs) 269 | v = v / v.norm(dim=-1, keepdim=True) 270 | min_encoding_indices = torch.argmax(u @ v.t(), dim=-1) 271 | return min_encoding_indices 272 | 273 | @torch.no_grad() 274 | def get_optimal_transport_min_encoding_indices(self, a: torch.Tensor) -> torch.Tensor: 275 | # credit to 276 | # https://github.com/facebookresearch/swav/blob/06b1b7cbaf6ba2a792300d79c7299db98b93b7f9/main_swav.py#L354 277 | @torch.no_grad() 278 | def distributed_sinkhorn(out): 279 | Q = torch.exp(out / self.sinkhorn_epsilon).t() # Q is K-by-B for consistency with notations from the paper 280 | B = Q.shape[1] * dist.get_world_size() 281 | K = Q.shape[0] 282 | 283 | # make the matrix sums to 1 284 | sum_Q = torch.sum(Q) 285 | dist.all_reduce(sum_Q) 286 | Q /= sum_Q 287 | 288 | for it in range(self.sinkhorn_iterations): 289 | # normalize each row: total weight per prototype must be 1/K 290 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 291 | dist.all_reduce(sum_of_rows) 292 | Q /= sum_of_rows 293 | Q /= K 294 | 295 | # normalize each column: total weight per sample must be 1/B 296 | Q /= torch.sum(Q, dim=0, keepdim=True) 297 | Q /= B 298 | 299 | Q *= B # the columns must sum to 1 so that Q is an assignment 300 | return Q.t() 301 | 302 | @torch.no_grad() 303 | def sinkhorn(out): 304 | Q = torch.exp(out / self.sinkhorn_epsilon).t() # Q is K-by-B for consistency with notations from the paper 305 | B = Q.shape[1] 306 | K = Q.shape[0] 307 | 308 | # make the matrix sums to 1 309 | sum_Q = torch.sum(Q) 310 | Q /= sum_Q 311 | 312 | for it in range(self.sinkhorn_iterations): 313 | # normalize each row: total weight per prototype must be 1/K 314 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 315 | Q /= sum_of_rows 316 | Q /= K 317 | 318 | # normalize each column: total weight per sample must be 1/B 319 | Q /= torch.sum(Q, dim=0, keepdim=True) 320 | Q /= B 321 | 322 | Q *= B # the colomns must sum to 1 so that Q is an assignment 323 | return Q.t() 324 | 325 | a = self.gumbel_sigmoid_trick(a) 326 | a = self.width_depth_normalize(a) 327 | a = a / a.norm(dim=-1, keepdim=True) 328 | 329 | codes = self.embedding_gs 330 
| codes = self.width_depth_normalize(codes) 331 | codes = codes / codes.norm(dim=-1, keepdim=True) 332 | 333 | out = a @ codes.t() 334 | 335 | if dist.is_initialized(): 336 | Q = distributed_sinkhorn(out) 337 | else: 338 | Q = sinkhorn(out) 339 | min_encoding_indices = torch.argmax(Q, dim=-1) 340 | return min_encoding_indices 341 | -------------------------------------------------------------------------------- /pdm/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pruning_pipelines import StableDiffusionPruningPipeline 2 | -------------------------------------------------------------------------------- /pdm/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rezashkv/diffusion_pruning/d0709c3a750c6b9ebb25a2901284cffdca98d769/pdm/training/__init__.py -------------------------------------------------------------------------------- /pdm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .arg_utils import parse_args 2 | from .metric_utils import compute_snr 3 | -------------------------------------------------------------------------------- /pdm/utils/arg_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="Dynamic Pruning of StableDiffusion-2.1") 7 | parser.add_argument( 8 | "--pretrained_model_name_or_path", 9 | type=str, 10 | default="stabilityai/stable-diffusion-2-1", 11 | required=False, 12 | help="Path to pretrained model or model identifier from huggingface.co/models.", 13 | ) 14 | parser.add_argument( 15 | "--clip_model_name_or_path", 16 | type=str, 17 | default="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", 18 | required=False, 19 | help="Path to pretrained clip model or model identifier from huggingface.co/models.", 20 | ) 21 | parser.add_argument("--prompt_encoder_model_name_or_path", 22 | type=str, 23 | default="sentence-transformers/all-mpnet-base-v2", 24 | required=False, 25 | help="Path to pretrained prompt encoder model or model identifier from huggingface.co/models.") 26 | parser.add_argument( 27 | "--base_config_path", 28 | type=str, 29 | required=True, 30 | help="Path to the model/data/training config file.", 31 | ) 32 | parser.add_argument( 33 | "--cache_dir", 34 | type=str, 35 | default=None, 36 | help="Path to the model/data/training config file.", 37 | ) 38 | parser.add_argument( 39 | "--pruning_ckpt_dir", 40 | type=str, 41 | default=None, 42 | help="Path to the saved pruning checkpoint dir. used for finetuning.", 43 | ) 44 | parser.add_argument( 45 | "--finetuning_ckpt_dir", 46 | type=str, 47 | default=None, 48 | help="Path to the saved finetuning checkpoint dir. used for image generation.", 49 | ) 50 | parser.add_argument( 51 | "--use_ema", 52 | action="store_true", 53 | help="Whether to use EMA model.", 54 | ) 55 | parser.add_argument( 56 | "--non_ema_revision", 57 | type=str, 58 | default=None, 59 | required=False, 60 | help="Revision of pretrained model identifier from huggingface.co/models.", 61 | ) 62 | parser.add_argument( 63 | "--revision", 64 | type=str, 65 | default=None, 66 | required=False, 67 | help="Revision of pretrained model identifier from huggingface.co/models.", 68 | ) 69 | parser.add_argument( 70 | "--seed", 71 | type=int, 72 | default=43, 73 | help="A seed for reproducible training." 
74 | ) 75 | parser.add_argument( 76 | "--mixed_precision", 77 | type=str, 78 | default=None, 79 | choices=["no", "fp16", "bf16"], 80 | help=( 81 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 82 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 83 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 84 | ), 85 | ) 86 | parser.add_argument( 87 | "--tracker_project_name", 88 | type=str, 89 | default="text2image-dynamic-pruning", 90 | help=( 91 | "The `project_name` argument passed to Accelerator.init_trackers for more information see " 92 | "https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 93 | ), 94 | ) 95 | parser.add_argument( 96 | "--expert_id", 97 | type=int, 98 | default=None, 99 | help="Index of the expert to finetune", 100 | ) 101 | parser.add_argument("--pruning_type", 102 | type=str, 103 | default="multi-expert", 104 | choices=["multi-expert", "single-expert"], 105 | help="Type of pruning to perform. Used in calculate_pruning_ratio.py") 106 | parser.add_argument( 107 | "--wandb_run_name", 108 | type=str, 109 | default=None, 110 | help="The `run_name` argument passed to Accelerator.init_trackers" 111 | ) 112 | parser.add_argument( 113 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 114 | ) 115 | parser.add_argument("--push_to_hub", action="store_true", 116 | help="Whether or not to push the model to the Hub.") 117 | parser.add_argument("--local_rank", type=int, default=-1, 118 | help="For distributed training: local_rank") 119 | 120 | args = parser.parse_args() 121 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 122 | if env_local_rank != -1 and env_local_rank != args.local_rank: 123 | args.local_rank = env_local_rank 124 | 125 | # default to using the same revision for the non-ema model if not specified 126 | if args.non_ema_revision is None: 127 | args.non_ema_revision = args.revision 128 | 129 | # return parser 130 | return args 131 | -------------------------------------------------------------------------------- /pdm/utils/clip_utils.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/Taited/clip-score/blob/master/src/clip_score/clip_score.py 2 | """Calculates the CLIP Scores 3 | 4 | The CLIP model is a contrasitively learned language-image model. There is 5 | an image encoder and a text encoder. It is believed that the CLIP model could 6 | measure the similarity of cross modalities. Please find more information from 7 | https://github.com/openai/CLIP. 8 | 9 | The CLIP Score measures the Cosine Similarity between two embedded features. 10 | This repository utilizes the pretrained CLIP Model to calculate 11 | the mean average of cosine similarities. 12 | 13 | See --help to see further details. 14 | 15 | Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP. 16 | 17 | Copyright 2023 The Hong Kong Polytechnic University 18 | 19 | Licensed under the Apache License, Version 2.0 (the "License"); 20 | you may not use this file except in compliance with the License. 
21 | You may obtain a copy of the License at 22 | 23 | http://www.apache.org/licenses/LICENSE-2.0 24 | 25 | Unless required by applicable law or agreed to in writing, software 26 | distributed under the License is distributed on an "AS IS" BASIS, 27 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 28 | See the License for the specific language governing permissions and 29 | limitations under the License. 30 | """ 31 | import os 32 | import os.path as osp 33 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 34 | 35 | import clip 36 | import torch 37 | import numpy as np 38 | from PIL import Image 39 | from torch.utils.data import Dataset, DataLoader 40 | 41 | try: 42 | from tqdm import tqdm 43 | except ImportError: 44 | # If tqdm is not available, provide a mock version of it 45 | def tqdm(x): 46 | return x 47 | 48 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 49 | 'tif', 'tiff', 'webp'} 50 | 51 | TEXT_EXTENSIONS = {'txt'} 52 | 53 | 54 | class DummyDataset(Dataset): 55 | FLAGS = ['img', 'txt', 'npy'] 56 | 57 | def __init__(self, real_path, fake_path, 58 | real_flag: str = 'img', 59 | fake_flag: str = 'img', 60 | transform=None, 61 | tokenizer=None, 62 | mode="orig") -> None: 63 | super().__init__() 64 | assert real_flag in self.FLAGS and fake_flag in self.FLAGS, \ 65 | "Got unexpected modality flag: {} or {}".format(real_flag, fake_flag) 66 | self.real_folder = self._combine_without_prefix(real_path) 67 | self.real_flag = real_flag 68 | self.fake_foler = self._combine_without_prefix(fake_path) 69 | self.fake_flag = fake_flag 70 | self.transform = transform 71 | self.tokenizer = tokenizer 72 | self.mode = mode 73 | # assert self._check() 74 | 75 | def __len__(self): 76 | return len(self.real_folder) 77 | 78 | def __getitem__(self, index): 79 | if index >= len(self): 80 | raise IndexError 81 | real_path = self.real_folder[index] 82 | fake_path = self.fake_foler[index] 83 | real_data = self._load_modality(real_path, self.real_flag, mode=self.mode) 84 | fake_data = self._load_modality(fake_path, self.fake_flag, mode="orig") 85 | 86 | sample = dict(real=real_data, fake=fake_data, name=os.path.basename(real_path)) 87 | return sample 88 | 89 | def _load_modality(self, path, modality, mode="orig"): 90 | if modality == 'img': 91 | data = self._load_img(path) 92 | elif modality == 'txt': 93 | data = self._load_txt(path) 94 | elif modality == 'npy': 95 | data = self._load_npy(path, mode=mode) 96 | else: 97 | raise TypeError("Got unexpected modality: {}".format(modality)) 98 | return data 99 | 100 | def _load_img(self, path): 101 | img = Image.open(path) 102 | if self.transform is not None: 103 | img = self.transform(img) 104 | return img 105 | 106 | def _load_npy(self, path, mode="orig"): 107 | data = np.load(path) 108 | if mode == 'stats': 109 | return data 110 | img = Image.fromarray(data) 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | return img 114 | 115 | def _load_txt(self, path): 116 | with open(path, 'r') as fp: 117 | data = fp.read() 118 | fp.close() 119 | if self.tokenizer is not None: 120 | data = self.tokenizer(data).squeeze() 121 | return data 122 | 123 | def _check(self): 124 | for idx in range(len(self)): 125 | real_name = self.real_folder[idx].split('.') 126 | fake_name = self.fake_folder[idx].split('.') 127 | if fake_name != real_name: 128 | return False 129 | return True 130 | 131 | def _combine_without_prefix(self, folder_path, prefix='.'): 132 | folder = [] 133 | for name in 
os.listdir(folder_path): 134 | if name[0] == prefix: 135 | continue 136 | folder.append(osp.join(folder_path, name)) 137 | folder.sort() 138 | return folder 139 | 140 | 141 | @torch.no_grad() 142 | def calculate_clip_score(dataloader, model, real_flag, fake_flag): 143 | score_acc = 0. 144 | sample_num = 0. 145 | logit_scale = model.logit_scale.exp() 146 | 147 | for batch_data in tqdm(dataloader): 148 | real = batch_data['real'] 149 | if real_flag == 'txt': 150 | real_features = forward_modality(model, real, real_flag) 151 | else: 152 | real_features = batch_data['real'] 153 | device = next(model.parameters()).device 154 | real_features = real_features.to(device) 155 | 156 | fake = batch_data['fake'] 157 | fake_features = forward_modality(model, fake, fake_flag) 158 | 159 | # normalize features 160 | real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32) 161 | fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32) 162 | 163 | # calculate scores 164 | # score = logit_scale * real_features @ fake_features.t() 165 | # score_acc += torch.diag(score).sum() 166 | score = logit_scale * (fake_features * real_features).sum() 167 | score_acc += score 168 | sample_num += real.shape[0] 169 | 170 | return score_acc / sample_num 171 | 172 | 173 | @torch.no_grad() 174 | def get_clip_features(dataloader, model, flag): 175 | features = [] 176 | names = [] 177 | for batch_data in tqdm(dataloader): 178 | data = batch_data['real'] 179 | names.extend(batch_data['name']) 180 | feature = forward_modality(model, data, flag) 181 | feature = feature / feature.norm(dim=1, keepdim=True).to(torch.float32) 182 | features.append(feature) 183 | return torch.cat(features, dim=0), names 184 | 185 | 186 | def forward_modality(model, data, flag): 187 | device = next(model.parameters()).device 188 | if flag == 'img': 189 | features = model.encode_image(data.to(device)) 190 | elif flag == 'txt': 191 | features = model.encode_text(data.to(device)) 192 | else: 193 | raise TypeError 194 | return features 195 | 196 | 197 | def clip_score(real_path, fake_path, clip_model='ViT-B/32', num_workers=None, batch_size=64): 198 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 199 | 200 | if num_workers is None: 201 | try: 202 | num_cpus = len(os.sched_getaffinity(0)) 203 | except AttributeError: 204 | num_cpus = os.cpu_count() 205 | 206 | num_workers = min(num_cpus, 8) if num_cpus is not None else 0 207 | 208 | print('Loading CLIP model: {}'.format(clip_model)) 209 | model, preprocess = clip.load(clip_model, device=device) 210 | 211 | dataset = DummyDataset(real_path, fake_path, 212 | "npy", "npy", 213 | transform=preprocess, tokenizer=clip.tokenize, mode="stats") 214 | dataloader = DataLoader(dataset, batch_size, 215 | num_workers=num_workers, pin_memory=True) 216 | 217 | print('Calculating CLIP Score:') 218 | score = calculate_clip_score(dataloader, model, 219 | "npy", "img") 220 | print(f'{clip_model.replace("/", "-")} CLIP Score: {score:.4f}') 221 | return score 222 | 223 | 224 | def clip_features(dataset_path, clip_model='ViT-B/32', num_workers=None, batch_size=64): 225 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 226 | if num_workers is None: 227 | try: 228 | num_cpus = len(os.sched_getaffinity(0)) 229 | except AttributeError: 230 | num_cpus = os.cpu_count() 231 | 232 | num_workers = min(num_cpus, 8) if num_cpus is not None else 0 233 | 234 | print('Loading CLIP model: {}'.format(clip_model)) 235 | model, 
preprocess = clip.load(clip_model, device=device) 236 | 237 | # find the extension of the dataset files 238 | extension = os.listdir(dataset_path)[0].split('.')[-1] 239 | if extension in IMAGE_EXTENSIONS: 240 | flag = 'img' 241 | elif extension in TEXT_EXTENSIONS: 242 | flag = 'txt' 243 | elif extension == 'npy': 244 | flag = 'npy' 245 | else: 246 | raise TypeError("Got unexpected extension: {}".format(extension)) 247 | 248 | dataset = DummyDataset(dataset_path, dataset_path, 249 | flag, flag, 250 | transform=preprocess, tokenizer=clip.tokenize) 251 | 252 | dataloader = DataLoader(dataset, batch_size, 253 | num_workers=num_workers, pin_memory=True) 254 | 255 | print('Calculating CLIP Features:') 256 | features, names = get_clip_features(dataloader, model, 'txt') 257 | features = features.cpu().numpy() 258 | save_path = os.path.join(os.path.dirname(dataset_path), f'{clip_model.replace("/", "-")}_clip_features') 259 | os.makedirs(save_path, exist_ok=True) 260 | 261 | for i, name in enumerate(names): 262 | np.save(os.path.join(save_path, f'{name[:-4]}.npy'), features[i]) 263 | print('CLIP Features saved successfully!') 264 | -------------------------------------------------------------------------------- /pdm/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from pdm.datasets import load_cc3m_dataset, load_coco_dataset 3 | import os 4 | import random 5 | import numpy as np 6 | import requests 7 | from torchvision import transforms 8 | import torch 9 | import PIL 10 | 11 | 12 | def get_dataset(config): 13 | # Get the datasets: you can either provide your own training and evaluation files (see below) 14 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 15 | 16 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 17 | # download the dataset. 
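    # Fields read from the config object passed in (the branch is chosen by whether
    # "conceptual_captions" or "coco" appears in data_dir; values are illustrative):
    #   data_dir                                    - dataset root directory
    #   train_data_file / train_data_dir            - CC3M split file and image dir
    #   validation_data_file / validation_data_dir  - optional CC3M validation split
    #   year                                        - COCO release year, e.g. 2014
    #   cache_dir                                   - used by the imagefolder fallback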
18 | data_dir = getattr(config, "data_dir", None) 19 | 20 | train_data_dir = getattr(config, "train_data_dir", None) 21 | train_data_file = getattr(config, "train_data_file", None) 22 | 23 | validation_data_dir = getattr(config, "validation_data_dir", None) 24 | validation_data_file = getattr(config, "validation_data_file", None) 25 | 26 | if "conceptual_captions" in data_dir: 27 | dataset = {"train": load_cc3m_dataset(data_dir, 28 | split="train", 29 | split_file=train_data_file, 30 | split_dir=train_data_dir)} 31 | if validation_data_dir is not None: 32 | dataset["validation"] = load_cc3m_dataset(data_dir, 33 | split="validation", 34 | split_file=validation_data_file, 35 | split_dir=validation_data_dir) 36 | 37 | elif "coco" in data_dir: 38 | year = config.year 39 | dataset = {"train": load_coco_dataset(os.path.join(data_dir, "images", f"train{year}"), 40 | os.path.join(data_dir, "annotations", f"captions_train{year}.json")), 41 | "validation": load_coco_dataset(os.path.join(data_dir, "images", f"val{year}"), 42 | os.path.join(data_dir, "annotations", 43 | f"captions_val{year}.json"))} 44 | 45 | else: 46 | data_files = {} 47 | if config.data_dir is not None: 48 | data_files["train"] = os.path.join(config.data_dir, "**") 49 | dataset = load_dataset( 50 | "imagefolder", 51 | data_files=data_files, 52 | cache_dir=config.cache_dir, 53 | ) 54 | # See more about loading custom images at 55 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 56 | 57 | return dataset 58 | 59 | 60 | def get_transforms(config): 61 | train_transform = transforms.Compose( 62 | [ 63 | transforms.Resize(config.model.unet.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 64 | transforms.CenterCrop( 65 | config.model.unet.resolution) if config.data.dataloader.center_crop else transforms.RandomCrop( 66 | config.model.unet.resolution), 67 | transforms.RandomHorizontalFlip() if config.data.dataloader.random_flip else transforms.Lambda(lambda x: x), 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 70 | ] 71 | ) 72 | 73 | validation_transform = transforms.Compose( 74 | [ 75 | transforms.Resize(config.model.unet.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 76 | transforms.CenterCrop( 77 | config.model.unet.resolution) if config.data.dataloader.center_crop else transforms.RandomCrop( 78 | config.model.unet.resolution), 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 81 | ] 82 | ) 83 | 84 | return train_transform, validation_transform 85 | 86 | 87 | def download_images_if_missing(samples, image_column): 88 | if isinstance(samples[image_column][0], str): 89 | if not os.path.exists(samples[image_column][0]): 90 | downloaded_images = [] 91 | for image in samples[image_column]: 92 | try: 93 | # download image and convert it to a PIL image 94 | downloaded_images.append(PIL.Image.open(requests.get(image, stream=True).raw)) 95 | except: 96 | # remove the caption if the image is not found 97 | downloaded_images.append(None) 98 | samples[image_column] = downloaded_images 99 | else: 100 | imgs = [] 101 | for image in samples[image_column]: 102 | try: 103 | imgs.append(PIL.Image.open(image)) 104 | except: 105 | samples[image_column] = None 106 | continue 107 | samples[image_column] = imgs 108 | return samples 109 | 110 | 111 | def tokenize_captions(samples, tokenizer, is_train=True): 112 | captions = [] 113 | for caption in samples: 114 | if isinstance(caption, str): 115 | captions.append(caption) 116 
| elif isinstance(caption, (list, np.ndarray)): 117 | # take a random caption if there are multiple 118 | captions.append(random.choice(caption) if is_train else caption[0]) 119 | else: 120 | raise ValueError( 121 | f"Caption column should contain either strings or lists of strings." 122 | ) 123 | 124 | inputs = tokenizer( 125 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 126 | ) 127 | return inputs.input_ids 128 | 129 | 130 | def get_mpnet_embeddings(capts, mpnet_model, mpnet_tokenizer, is_train=True): 131 | # Mean Pooling - Take attention mask into account for correct averaging 132 | def mean_pooling(model_output, attention_mask): 133 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings 134 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 135 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), 136 | min=1e-9) 137 | 138 | captions = [] 139 | for caption in capts: 140 | if isinstance(caption, str): 141 | captions.append(caption) 142 | elif isinstance(caption, (list, np.ndarray)): 143 | # take a random caption if there are multiple 144 | captions.append(random.choice(caption) if is_train else caption[0]) 145 | else: 146 | raise ValueError( 147 | f"Caption column should contain either strings or lists of strings." 148 | ) 149 | 150 | encoded_input = mpnet_tokenizer(captions, padding=True, truncation=True, return_tensors="pt") 151 | with torch.no_grad(): 152 | encoded_input = encoded_input.to(mpnet_model.device) 153 | model_output = mpnet_model(**encoded_input) 154 | sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) 155 | return sentence_embeddings 156 | 157 | 158 | def preprocess_samples(samples, tokenizer, mpnet_model, mpnet_tokenizer, transform, image_column="image", 159 | caption_column="caption", is_train=True): 160 | samples = download_images_if_missing(samples, image_column) 161 | images = [image.convert("RGB") if image is not None else image for image in samples[image_column]] 162 | samples["pixel_values"] = [transform(image) if image is not None else image for image in images] 163 | samples["input_ids"] = tokenize_captions(samples[caption_column], tokenizer=tokenizer, is_train=is_train) 164 | samples["mpnet_embeddings"] = get_mpnet_embeddings(samples[caption_column], mpnet_model=mpnet_model, 165 | mpnet_tokenizer=mpnet_tokenizer, is_train=is_train) 166 | return samples 167 | 168 | 169 | def preprocess_prompts(samples, mpnet_model, mpnet_tokenizer): 170 | samples["mpnet_embeddings"] = get_mpnet_embeddings(samples["prompts"], mpnet_model=mpnet_model, 171 | mpnet_tokenizer=mpnet_tokenizer, is_train=False) 172 | return samples 173 | 174 | 175 | def collate_fn(samples): 176 | samples = [sample for sample in samples if sample["pixel_values"] is not None] 177 | if len(samples) == 0: 178 | return {"pixel_values": torch.tensor([]), "input_ids": torch.tensor([]), 179 | "mpnet_embeddings": torch.tensor([])} 180 | pixel_values = torch.stack([sample["pixel_values"] for sample in samples]) 181 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 182 | input_ids = torch.stack([sample["input_ids"] for sample in samples]) 183 | mpnet_embeddings = torch.stack([sample["mpnet_embeddings"] for sample in samples]) 184 | mpnet_embeddings = mpnet_embeddings.to(memory_format=torch.contiguous_format).float() 185 | return {"pixel_values": 
pixel_values, "input_ids": input_ids, "mpnet_embeddings": mpnet_embeddings} 186 | 187 | 188 | def prompts_collator(samples): 189 | prompts = [sample["prompts"] for sample in samples] 190 | prompt_embdeddings = torch.stack([sample["mpnet_embeddings"] for sample in samples]) 191 | prompt_embdeddings = prompt_embdeddings.to(memory_format=torch.contiguous_format).float() 192 | return {"prompts": prompts, "mpnet_embeddings": prompt_embdeddings} 193 | 194 | 195 | def filter_dataset(dataset, hyper_net, quantizer, mpnet_model, mpnet_tokenizer, caption_column="caption"): 196 | train_captions = dataset["train"][caption_column] 197 | validation_captions = dataset["validation"][caption_column] 198 | train_filtering_dataloader = torch.utils.data.DataLoader(train_captions, batch_size=2048, shuffle=False) 199 | validation_filtering_dataloader = torch.utils.data.DataLoader(validation_captions, batch_size=2048, 200 | shuffle=False) 201 | device = "cuda" if torch.cuda.is_available() else "cpu" 202 | hyper_net.to(device) 203 | quantizer.to(device) 204 | mpnet_model.to(device) 205 | hyper_net.eval() 206 | quantizer.eval() 207 | train_indices = [] 208 | validation_indices = [] 209 | with torch.no_grad(): 210 | for batch in train_filtering_dataloader: 211 | batch = get_mpnet_embeddings(batch, mpnet_model, mpnet_tokenizer, is_train=True) 212 | arch_v = hyper_net(batch) 213 | indices = quantizer.get_cosine_sim_min_encoding_indices(arch_v) 214 | train_indices.append(indices) 215 | for batch in validation_filtering_dataloader: 216 | batch = get_mpnet_embeddings(batch, mpnet_model, mpnet_tokenizer, is_train=False) 217 | arch_v = hyper_net(batch) 218 | indices = quantizer.get_cosine_sim_min_encoding_indices(arch_v) 219 | validation_indices.append(indices) 220 | train_indices = torch.cat(train_indices, dim=0) 221 | validation_indices = torch.cat(validation_indices, dim=0) 222 | 223 | return train_indices, validation_indices 224 | -------------------------------------------------------------------------------- /pdm/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | def deepspeed_zero_init_disabled_context_manager(): 2 | import accelerate 3 | from accelerate.state import AcceleratorState 4 | """ 5 | returns either a context list that includes one that will disable zero.Init or an empty context list 6 | """ 7 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None 8 | if deepspeed_plugin is None: 9 | return [] 10 | 11 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)] 12 | 13 | 14 | def nodesplitter(src, group=None): 15 | import torch 16 | if torch.distributed.is_initialized(): 17 | if group is None: 18 | group = torch.distributed.group.WORLD 19 | rank = torch.distributed.get_rank(group=group) 20 | size = torch.distributed.get_world_size(group=group) 21 | print(f"nodesplitter: rank={rank} size={size}") 22 | count = 0 23 | for i, item in enumerate(src): 24 | if i % size == rank: 25 | yield item 26 | count += 1 27 | print(f"nodesplitter: rank={rank} size={size} count={count} DONE") 28 | else: 29 | yield from src 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /pdm/utils/estimation_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def sample_gumbel(shape, eps=1e-20, fixed_seed=False): 6 | if fixed_seed: 7 | u = torch.rand(shape, 
generator=torch.Generator().manual_seed(0)) 8 | else: 9 | u = torch.rand(shape) 10 | return -torch.log(-torch.log(u + eps) + eps) 11 | 12 | 13 | def vector_gumbel_softmax(logits, temperature, offset=0, force_width_non_zero=False, fixed_seed=False): 14 | gumbel_sample = sample_gumbel(logits.size(), fixed_seed=fixed_seed) 15 | if logits.is_cuda: 16 | gumbel_sample = gumbel_sample.cuda() 17 | 18 | y = logits + gumbel_sample + offset 19 | y_out = F.sigmoid(y / temperature) 20 | if not force_width_non_zero: 21 | return y_out 22 | else: 23 | y_out_h = hard_concrete(y_out).sum(dim=1) 24 | if (y_out_h > 0).all(): 25 | return y_out 26 | else: 27 | y_out_h = hard_concrete(y_out).sum(dim=1) 28 | ind = (y_out_h == 0) 29 | new_y_out = y_out.clone() 30 | new_y_out[ind, 0] = y_out[ind, 0] + 0.5 31 | return new_y_out 32 | 33 | 34 | def gumbel_softmax_sample(logits, temperature, offset=0, force_width_non_zero=False, fixed_seed=False): 35 | if not force_width_non_zero: 36 | gumbel_sample = sample_gumbel(logits.size(), fixed_seed=fixed_seed) 37 | if logits.is_cuda: 38 | gumbel_sample = gumbel_sample.cuda() 39 | 40 | y = logits + gumbel_sample + offset 41 | return F.sigmoid(y / temperature) 42 | 43 | else: 44 | y_out = vector_gumbel_softmax(logits=logits, temperature=temperature, offset=offset, force_width_non_zero=True, 45 | fixed_seed=fixed_seed) 46 | return y_out 47 | 48 | 49 | def importance_gumbel_softmax_sample(logits, temperature, offset=0, fixed_seed=False): 50 | x = torch.softmax(logits, dim=1) 51 | x = torch.cumsum(x, dim=1) 52 | x = torch.flip(x, dims=[1]) 53 | 54 | eps = 1e-6 55 | # inverse sigmoid function. add eps to avoid numerical instability. 56 | x = torch.log(x + eps) - torch.log1p(-(x - eps)) 57 | 58 | gumbel_sample = sample_gumbel(x.size(), fixed_seed=fixed_seed) 59 | if logits.is_cuda: 60 | gumbel_sample = gumbel_sample.cuda() 61 | x = x.cuda() 62 | 63 | y = x + gumbel_sample + offset 64 | return F.sigmoid(y / temperature) 65 | 66 | 67 | def hard_concrete(out): 68 | out_hard = torch.zeros(out.size()) 69 | out_hard[out >= 0.5] = 1 70 | out_hard[out < 0.5] = 0 71 | if out.is_cuda: 72 | out_hard = out_hard.cuda() 73 | # Straight through estimation 74 | out_hard = (out_hard - out).detach() + out 75 | return out_hard 76 | -------------------------------------------------------------------------------- /pdm/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from PIL import Image, ImageDraw, ImageFont 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def create_image_grid_from_indices(indices, grid_size=(5, 5), image_size=(256, 256), font_size=40): 9 | # Create a white background image 10 | grid_width = grid_size[0] * image_size[0] 11 | grid_height = grid_size[1] * image_size[1] 12 | background = Image.new('RGB', (grid_width, grid_height), 'white') 13 | 14 | # Create a draw object 15 | draw = ImageDraw.Draw(background) 16 | 17 | # Use a larger font 18 | try: 19 | font = ImageFont.truetype("Arial.ttf", font_size) 20 | except: 21 | font = ImageFont.load_default() 22 | 23 | # Iterate through indices and place them on the grid 24 | for i, index in enumerate(indices): 25 | row = i // grid_size[0] 26 | col = i % grid_size[0] 27 | 28 | # Calculate the position to place the text 29 | x = col * image_size[0] + (image_size[0] - font_size) // 2 30 | y = row * image_size[1] + (image_size[1] - font_size) // 2 31 | 32 | # Draw the index on the image 33 | draw.text((x, y), str(index), font=font, 
fill='black') 34 | 35 | # Save or display the resulting image 36 | return background 37 | 38 | 39 | def create_heatmap(data, n_rows, n_cols): 40 | plt.figure() 41 | data = data.reshape(n_rows, n_cols) 42 | fig = sns.heatmap(data, cmap='Blues', linewidth=0.5, xticklabels=False, yticklabels=False).get_figure() 43 | return fig 44 | 45 | 46 | def init_logging(config): 47 | config.training.logging.logging_dir = os.path.join(config.training.logging.logging_dir, 48 | os.getcwd().split('/')[-2], 49 | config.base_config_path.split('/')[-2], 50 | config.base_config_path.split('/')[-1].split('.')[0], 51 | config.wandb_run_name 52 | ) 53 | 54 | os.makedirs(config.training.logging.logging_dir, exist_ok=True) 55 | 56 | # Make one log on every process with the configuration for debugging. 57 | logging.basicConfig( 58 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 59 | datefmt="%m/%d/%Y %H:%M:%S", 60 | level=logging.INFO, 61 | ) 62 | -------------------------------------------------------------------------------- /pdm/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def compute_snr(noise_scheduler, timesteps): 4 | """ 5 | Computes SNR as per 6 | https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 7 | """ 8 | alphas_cumprod = noise_scheduler.alphas_cumprod 9 | sqrt_alphas_cumprod = alphas_cumprod ** 0.5 10 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 11 | 12 | # Expand the tensors. 13 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 14 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 15 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 16 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 17 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 18 | 19 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 20 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 21 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 22 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 23 | 24 | # Compute SNR. 25 | snr = (alpha / sigma) ** 2 26 | return snr 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "pdm" 7 | dynamic = ["version"] 8 | description = "Pruning Diffusion Models" 9 | readme = "README.md" 10 | license-files = { paths = ["LICENSE"] } 11 | requires-python = ">=3.8" 12 | 13 | [project.urls] 14 | Homepage = "https://github.com/rezashkv/diffusion_pruning" 15 | 16 | [tool.hatch.version] 17 | path = "pdm/__init__.py" 18 | 19 | [tool.hatch.build] 20 | # This needs to be explicitly set so the configuration files 21 | # grafted into the `sgm` directory get included in the wheel's 22 | # RECORD file. 
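# In this project the directory included in the wheel is `pdm`, listed below.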
23 | include = [ 24 | "pdm", 25 | ] 26 | 27 | 28 | [tool.hatch.envs.ci] 29 | skip-install = false 30 | 31 | dependencies = [ 32 | "pytest" 33 | ] 34 | 35 | [tool.hatch.envs.ci.scripts] 36 | 37 | -------------------------------------------------------------------------------- /scripts/aptp/filter_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from omegaconf import OmegaConf 5 | 6 | from PIL import Image 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | 11 | from accelerate.logging import get_logger 12 | from accelerate.utils import set_seed 13 | 14 | from diffusers.utils import check_min_version 15 | 16 | from transformers import AutoModel, AutoTokenizer 17 | 18 | from pdm.models import HyperStructure, StructureVectorQuantizer 19 | from pdm.utils.arg_utils import parse_args 20 | from pdm.utils.data_utils import get_dataset, filter_dataset 21 | 22 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 23 | check_min_version("0.22.0.dev0") 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | def main(): 29 | Image.MAX_IMAGE_PIXELS = 933120000 30 | torch.autograd.set_detect_anomaly(True) 31 | 32 | args = parse_args() 33 | config = OmegaConf.load(args.base_config_path) 34 | # add args to config 35 | config.update(vars(args)) 36 | 37 | assert config.pruning_ckpt_dir is not None, "Please provide a path to the pruning checkpoint directory." 38 | 39 | if config.seed is not None: 40 | set_seed(config.seed) 41 | 42 | # Make one log on every process with the configuration for debugging. 43 | logging.basicConfig( 44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 45 | datefmt="%m/%d/%Y %H:%M:%S", 46 | level=logging.INFO, 47 | ) 48 | 49 | # ########################### Hypernet and Quantizer for Dataset Preprocessing ##################################### 50 | 51 | hyper_net = HyperStructure.from_pretrained(config.pruning_ckpt_dir, subfolder="hypernet") 52 | quantizer = StructureVectorQuantizer.from_pretrained(config.pruning_ckpt_dir, subfolder="quantizer") 53 | 54 | mpnet_tokenizer = AutoTokenizer.from_pretrained(config.prompt_encoder_model_name_or_path) 55 | mpnet_model = AutoModel.from_pretrained(config.prompt_encoder_model_name_or_path) 56 | 57 | # #################################################### Datasets #################################################### 58 | 59 | logging.info("Loading datasets...") 60 | # Get the datasets: you can either provide your own training and evaluation files (see below) 61 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 
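    # A hypothetical invocation (flags come from pdm/utils/arg_utils.py; paths are
    # placeholders):
    #
    #   python scripts/aptp/filter_dataset.py \
    #       --base_config_path configs/filtering/sd-2-1_coco2014.yaml \
    #       --pruning_ckpt_dir /path/to/pruning/checkpoint
    #
    # Each caption is embedded with the MPNet prompt encoder, mapped through the
    # hypernetwork and quantizer to an expert index, and the resulting train and
    # validation index tensors are saved inside pruning_ckpt_dir.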
62 | dataset = get_dataset(config.data) 63 | dataset_name = config.data.dataset_name 64 | column_names = dataset["train"].column_names 65 | 66 | caption_column = config.data.caption_column 67 | if caption_column not in column_names: 68 | raise ValueError( 69 | f"--caption_column' value '{config.data.caption_column}' needs to be one of: {', '.join(column_names)}" 70 | ) 71 | 72 | if not (os.path.exists(os.path.join(config.pruning_ckpt_dir, f"{dataset_name}_train_mapped_indices.pt")) and 73 | os.path.exists(os.path.join(config.pruning_ckpt_dir, f"{dataset_name}_validation_mapped_indices.pt"))): 74 | tr_indices, val_indices = filter_dataset(dataset, hyper_net, quantizer, mpnet_model, mpnet_tokenizer, 75 | caption_column=caption_column) 76 | torch.save(tr_indices, os.path.join(config.pruning_ckpt_dir, f"{dataset_name}_train_mapped_indices.pt")) 77 | torch.save(val_indices, os.path.join(config.pruning_ckpt_dir, f"{dataset_name}_validation_mapped_indices.pt")) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /scripts/aptp/finetune.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL.Image 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import FineTuner 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def main(): 20 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 21 | torch.autograd.set_detect_anomaly(True) 22 | 23 | args = parse_args() 24 | config = OmegaConf.load(args.base_config_path) 25 | config.update(vars(args)) 26 | 27 | assert config.pruning_ckpt_dir is not None, "Please provide a path to the pruning checkpoint directory." 
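    # The expert index selects which pruned expert architecture (a quantizer codebook
    # entry from the pruning stage) this fine-tuning run targets; captions are mapped
    # to expert indices beforehand by scripts/aptp/filter_dataset.py.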
28 | assert config.expert_id is not None, "Please provide an expert index to finetune" 29 | 30 | if config.seed is not None: 31 | set_seed(config.seed) 32 | 33 | init_logging(config) 34 | 35 | # Enable TF32 for faster training on Ampere GPUs, 36 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 37 | if config.training.allow_tf32: 38 | torch.backends.cuda.matmul.allow_tf32 = True 39 | 40 | trainer = FineTuner(config=config) 41 | trainer.train() 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/aptp/prune.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import Pruner 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def main(): 19 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 20 | torch.autograd.set_detect_anomaly(True) 21 | 22 | args = parse_args() 23 | config = OmegaConf.load(args.base_config_path) 24 | config.update(vars(args)) 25 | 26 | if config.seed is not None: 27 | set_seed(config.seed) 28 | 29 | init_logging(config) 30 | 31 | # Enable TF32 for faster training on Ampere GPUs, 32 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 33 | if config.training.allow_tf32: 34 | torch.backends.cuda.matmul.allow_tf32 = True 35 | 36 | pruner = Pruner(config=config) 37 | 38 | pruner.train() 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /scripts/baselines/magnitude/finetune_magnitude.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL.Image 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import BaselineFineTuner 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def main(): 19 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 20 | torch.autograd.set_detect_anomaly(True) 21 | 22 | args = parse_args() 23 | config = OmegaConf.load(args.base_config_path) 24 | config.update(vars(args)) 25 | 26 | assert config.training.pruning_method is not None, "Please provide a method for norm pruning" 27 | assert config.training.pruning_target is not None, "Please provide a target for norm pruning" 28 | 29 | if config.seed is not None: 30 | set_seed(config.seed) 31 | 32 | init_logging(config) 33 | 34 | # Enable TF32 for faster training on Ampere GPUs, 35 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 36 | if config.training.allow_tf32: 37 | torch.backends.cuda.matmul.allow_tf32 = True 38 | 39 | trainer = BaselineFineTuner(config=config, pruning_type="magnitude") 40 | trainer.train() 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/baselines/magnitude/generate_images.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from omegaconf import OmegaConf 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | import torch 10 | import torch.utils.checkpoint 11 | 12 | import accelerate 13 | from accelerate.logging import get_logger 14 | from accelerate.utils import set_seed 15 | 16 | from diffusers import PNDMScheduler, UNet2DConditionModel 17 | from diffusers.utils import check_min_version 18 | from diffusers import StableDiffusionPipeline 19 | 20 | import safetensors 21 | 22 | from transformers import CLIPTextModel 23 | 24 | from pdm.models.unet import UNet2DConditionModelMagnitudePruned 25 | from pdm.utils.arg_utils import parse_args 26 | from pdm.utils.data_utils import get_dataset 27 | 28 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 29 | check_min_version("0.22.0.dev0") 30 | 31 | logger = get_logger(__name__) 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | config = OmegaConf.load(args.base_config_path) 37 | # add args to config 38 | config.update(vars(args)) 39 | 40 | if config.seed is not None: 41 | set_seed(config.seed) 42 | 43 | # #################################################### Accelerator ################################################# 44 | accelerator = accelerate.Accelerator() 45 | 46 | # #################################################### Datasets #################################################### 47 | 48 | logger.info("Loading datasets...") 49 | # Get the datasets: you can either provide your own training and evaluation files (see below) 50 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 51 | 52 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 53 | # download the dataset. 54 | dataset_name = getattr(config.data, "dataset_name", None) 55 | img_col = getattr(config.data, "image_column", "image") 56 | capt_col = getattr(config.data, "caption_column", "caption") 57 | 58 | assert config.expert_id is not None, "expert index must be provided" 59 | assert config.finetuning_ckpt_dir is not None, "finetuning checkpoint directory must be provided" 60 | 61 | def collate_fn(examples, caption_column="caption", image_column="image"): 62 | captions = [example[caption_column] for example in examples] 63 | images = [example[image_column] for example in examples] 64 | return {"image": images, "caption": captions} 65 | 66 | dataset = get_dataset(config.data) 67 | 68 | dataset = dataset["validation"] 69 | fid_val_indices_path = os.path.abspath( 70 | os.path.join(config.finetuning_ckpt_dir, "..", "..", f"{dataset_name}_validation_mapped_indices.pt")) 71 | assert os.path.exists(fid_val_indices_path), \ 72 | (f"{dataset_name}_validation_mapped_indices.pt must be present in two upper directory of the checkpoint" 73 | f" directory {config.finetuning_ckpt_dir}") 74 | val_indices = torch.load(fid_val_indices_path, map_location="cpu") 75 | 76 | dataset = dataset.select(torch.where(val_indices == config.expert_id)[0]) 77 | logger.info("Dataset of size %d loaded." 
% len(dataset)) 78 | 79 | dataloader = torch.utils.data.DataLoader( 80 | dataset, 81 | shuffle=False, 82 | batch_size=config.data.dataloader.image_generation_batch_size * accelerator.num_processes, 83 | num_workers=config.data.dataloader.dataloader_num_workers, 84 | collate_fn=partial(collate_fn, caption_column=capt_col, image_column=img_col), 85 | ) 86 | 87 | dataloader = accelerator.prepare(dataloader) 88 | 89 | # #################################################### Models #################################################### 90 | text_encoder = CLIPTextModel.from_pretrained( 91 | config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision 92 | ) 93 | 94 | teacher_unet = UNet2DConditionModel.from_pretrained( 95 | config.pretrained_model_name_or_path, 96 | subfolder="unet", 97 | revision=config.revision, 98 | ) 99 | 100 | sample_inputs = {'sample': torch.randn(1, teacher_unet.config.in_channels, teacher_unet.config.sample_size, teacher_unet.config.sample_size), 101 | 'timestep': torch.ones((1,)).long(), 102 | 'encoder_hidden_states': text_encoder(torch.tensor([[100]]))[0], 103 | } 104 | 105 | unet = UNet2DConditionModelMagnitudePruned.from_pretrained( 106 | config.pretrained_model_name_or_path, 107 | subfolder="unet", 108 | revision=config.non_ema_revision, 109 | target_pruning_rate=config.training.pruning_target, 110 | pruning_method=config.training.pruning_method, 111 | sample_inputs=sample_inputs 112 | ) 113 | 114 | state_dict = safetensors.torch.load_file(os.path.join(config.finetuning_ckpt_dir, "unet", 115 | "diffusion_pytorch_model.safetensors")) 116 | unet.load_state_dict(state_dict) 117 | 118 | noise_scheduler = PNDMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder="scheduler") 119 | pipeline = StableDiffusionPipeline.from_pretrained( 120 | config.pretrained_model_name_or_path, 121 | unet=unet, 122 | scheduler=noise_scheduler, 123 | ) 124 | 125 | if config.enable_xformers_memory_efficient_attention: 126 | pipeline.enable_xformers_memory_efficient_attention() 127 | 128 | pipeline.set_progress_bar_config(disable=not accelerator.is_main_process) 129 | 130 | pipeline.to(accelerator.device) 131 | 132 | image_output_dir = os.path.join(config.finetuning_ckpt_dir, f"{dataset_name}_fid_images") 133 | os.makedirs(image_output_dir, exist_ok=True) 134 | 135 | for batch in dataloader: 136 | if config.seed is None: 137 | generator = None 138 | else: 139 | generator = torch.Generator(device=accelerator.device).manual_seed(config.seed) 140 | gen_images = pipeline(batch[capt_col], num_inference_steps=config.training.num_inference_steps, 141 | generator=generator, output_type="np" 142 | ).images 143 | 144 | for idx, caption in enumerate(batch[capt_col]): 145 | image_name = batch["image"][idx].split("/")[-1] 146 | if image_name.endswith(".jpg"): 147 | image_name = image_name[:-4] 148 | 149 | image_path = os.path.join(image_output_dir, f"{image_name}.npy") 150 | img = gen_images[idx] 151 | img = img * 255 152 | img = img.astype(np.uint8) 153 | img = cv2.resize(img, (256, 256)) 154 | np.save(image_path, img) 155 | 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /scripts/baselines/random/finetune_random_arch.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL.Image 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from 
accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import BaselineFineTuner 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def main(): 19 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 20 | torch.autograd.set_detect_anomaly(True) 21 | 22 | args = parse_args() 23 | config = OmegaConf.load(args.base_config_path) 24 | config.update(vars(args)) 25 | 26 | config.training.random_pruning_ratio = 0.5 27 | 28 | assert config.training.random_pruning_ratio is not None, "Please provide a ratio for random pruning" 29 | 30 | if config.seed is not None: 31 | set_seed(config.seed) 32 | 33 | init_logging(config) 34 | 35 | # Enable TF32 for faster training on Ampere GPUs, 36 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 37 | if config.training.allow_tf32: 38 | torch.backends.cuda.matmul.allow_tf32 = True 39 | 40 | trainer = BaselineFineTuner(config=config, pruning_type="random") 41 | trainer.train() 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/baselines/sd/finetune_sd.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL.Image 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import BaselineFineTuner 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def main(): 19 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 20 | torch.autograd.set_detect_anomaly(True) 21 | 22 | args = parse_args() 23 | config = OmegaConf.load(args.base_config_path) 24 | config.update(vars(args)) 25 | 26 | assert config.training.pruning_method is not None, "Please provide a method for norm pruning" 27 | 28 | if config.seed is not None: 29 | set_seed(config.seed) 30 | 31 | init_logging(config) 32 | 33 | # Enable TF32 for faster training on Ampere GPUs, 34 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 35 | if config.training.allow_tf32: 36 | torch.backends.cuda.matmul.allow_tf32 = True 37 | 38 | trainer = BaselineFineTuner(config=config, pruning_type="no-pruning") 39 | trainer.train() 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /scripts/baselines/sd/generate_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from omegaconf import OmegaConf 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | import torch 10 | import torch.utils.checkpoint 11 | 12 | import accelerate 13 | from accelerate.logging import get_logger 14 | from accelerate.utils import set_seed 15 | 16 | from diffusers import PNDMScheduler, UNet2DConditionModel 17 | from diffusers.utils import check_min_version 18 | from diffusers import StableDiffusionPipeline 19 | 20 | import safetensors 21 | 22 | from transformers import CLIPTextModel 23 | 24 | from pdm.models.unet import UNet2DConditionModelMagnitudePruned 25 | from pdm.utils.arg_utils import parse_args 26 | from pdm.utils.data_utils import get_dataset 27 | 28 | # Will error if the minimal version of 
diffusers is not installed. Remove at your own risks. 29 | check_min_version("0.22.0.dev0") 30 | 31 | logger = get_logger(__name__) 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | config = OmegaConf.load(args.base_config_path) 37 | # add args to config 38 | config.update(vars(args)) 39 | 40 | if config.seed is not None: 41 | set_seed(config.seed) 42 | 43 | # #################################################### Accelerator ################################################# 44 | accelerator = accelerate.Accelerator() 45 | 46 | # #################################################### Datasets #################################################### 47 | 48 | logger.info("Loading datasets...") 49 | # Get the datasets: you can either provide your own training and evaluation files (see below) 50 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 51 | 52 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 53 | # download the dataset. 54 | dataset_name = getattr(config.data, "dataset_name", None) 55 | img_col = getattr(config.data, "image_column", "image") 56 | capt_col = getattr(config.data, "caption_column", "caption") 57 | 58 | assert config.finetuning_ckpt_dir is not None, "finetuning checkpoint directory must be provided" 59 | 60 | def collate_fn(examples, caption_column="caption", image_column="image"): 61 | captions = [example[caption_column] for example in examples] 62 | images = [example[image_column] for example in examples] 63 | return {"image": images, "caption": captions} 64 | 65 | dataset = get_dataset(config.data) 66 | 67 | dataset = dataset["validation"] 68 | logger.info("Dataset of size %d loaded." % len(dataset)) 69 | 70 | dataloader = torch.utils.data.DataLoader( 71 | dataset, 72 | shuffle=False, 73 | batch_size=config.data.dataloader.image_generation_batch_size * accelerator.num_processes, 74 | num_workers=config.data.dataloader.dataloader_num_workers, 75 | collate_fn=partial(collate_fn, caption_column=capt_col, image_column=img_col), 76 | ) 77 | 78 | dataloader = accelerator.prepare(dataloader) 79 | 80 | # #################################################### Models #################################################### 81 | # unet = UNet2DConditionModel.from_pretrained( 82 | # config.pretrained_model_name_or_path, 83 | # subfolder="unet", 84 | # revision=config.revision, 85 | # ) 86 | # 87 | # state_dict = safetensors.torch.load_file(os.path.join(config.finetuning_ckpt_dir, "unet", 88 | # "diffusion_pytorch_model.safetensors")) 89 | # unet.load_state_dict(state_dict) 90 | 91 | noise_scheduler = PNDMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder="scheduler") 92 | pipeline = StableDiffusionPipeline.from_pretrained( 93 | config.pretrained_model_name_or_path, 94 | scheduler=noise_scheduler, 95 | ) 96 | 97 | pipeline.load_lora_weights(config.finetuning_ckpt_dir) 98 | 99 | if config.enable_xformers_memory_efficient_attention: 100 | pipeline.enable_xformers_memory_efficient_attention() 101 | 102 | pipeline.set_progress_bar_config(disable=not accelerator.is_main_process) 103 | 104 | pipeline.to(accelerator.device) 105 | 106 | image_output_dir = os.path.join(config.finetuning_ckpt_dir, f"{dataset_name}_fid_images") 107 | os.makedirs(image_output_dir, exist_ok=True) 108 | 109 | for batch in dataloader: 110 | if config.seed is None: 111 | generator = None 112 | else: 113 | generator = 
torch.Generator(device=accelerator.device).manual_seed(config.seed) 114 | gen_images = pipeline(batch[capt_col], num_inference_steps=config.training.num_inference_steps, 115 | generator=generator, output_type="np" 116 | ).images 117 | 118 | for idx, caption in enumerate(batch[capt_col]): 119 | image_name = batch["image"][idx].split("/")[-1] 120 | if image_name.endswith(".jpg"): 121 | image_name = image_name[:-4] 122 | 123 | image_path = os.path.join(image_output_dir, f"{image_name}.npy") 124 | img = gen_images[idx] 125 | img = img * 255 126 | img = img.astype(np.uint8) 127 | img = cv2.resize(img, (256, 256)) 128 | np.save(image_path, img) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /scripts/baselines/structural/finetune_structural.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL.Image 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import BaselineFineTuner 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def main(): 19 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 20 | torch.autograd.set_detect_anomaly(True) 21 | 22 | args = parse_args() 23 | config = OmegaConf.load(args.base_config_path) 24 | config.update(vars(args)) 25 | 26 | assert config.training.pruning_method is not None, "Please provide a method for norm pruning" 27 | assert config.training.pruning_target is not None, "Please provide a target for norm pruning" 28 | 29 | if config.seed is not None: 30 | set_seed(config.seed) 31 | 32 | init_logging(config) 33 | 34 | # Enable TF32 for faster training on Ampere GPUs, 35 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 36 | if config.training.allow_tf32: 37 | torch.backends.cuda.matmul.allow_tf32 = True 38 | 39 | trainer = BaselineFineTuner(config=config, pruning_type="structural") 40 | trainer.train() 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/baselines/structural/generate_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from omegaconf import OmegaConf 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | import torch 10 | import torch.utils.checkpoint 11 | 12 | import accelerate 13 | from accelerate.logging import get_logger 14 | from accelerate.utils import set_seed 15 | 16 | from diffusers import PNDMScheduler, UNet2DConditionModel 17 | from diffusers.utils import check_min_version 18 | from diffusers import StableDiffusionPipeline 19 | 20 | import safetensors 21 | 22 | from transformers import CLIPTextModel 23 | 24 | from pdm.models.unet import UNet2DConditionModelMagnitudePruned 25 | from pdm.utils.arg_utils import parse_args 26 | from pdm.utils.data_utils import get_dataset 27 | 28 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
29 | check_min_version("0.22.0.dev0") 30 | 31 | logger = get_logger(__name__) 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | config = OmegaConf.load(args.base_config_path) 37 | # add args to config 38 | config.update(vars(args)) 39 | 40 | if config.seed is not None: 41 | set_seed(config.seed) 42 | 43 | # #################################################### Accelerator ################################################# 44 | accelerator = accelerate.Accelerator() 45 | 46 | # #################################################### Datasets #################################################### 47 | 48 | logger.info("Loading datasets...") 49 | # Get the datasets: you can either provide your own training and evaluation files (see below) 50 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 51 | 52 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 53 | # download the dataset. 54 | dataset_name = getattr(config.data, "dataset_name", None) 55 | img_col = getattr(config.data, "image_column", "image") 56 | capt_col = getattr(config.data, "caption_column", "caption") 57 | 58 | assert config.pruning_ckpt_dir is not None, "pruning checkpoint directory must be provided" 59 | assert config.finetuning_ckpt_dir is not None, "finetuning checkpoint directory must be provided" 60 | 61 | def collate_fn(examples, caption_column="caption", image_column="image"): 62 | captions = [example[caption_column] for example in examples] 63 | images = [example[image_column] for example in examples] 64 | return {"image": images, "caption": captions} 65 | 66 | dataset = get_dataset(config.data) 67 | 68 | dataset = dataset["validation"] 69 | 70 | logger.info("Dataset of size %d loaded." 
% len(dataset)) 71 | 72 | dataloader = torch.utils.data.DataLoader( 73 | dataset, 74 | shuffle=False, 75 | batch_size=config.data.dataloader.image_generation_batch_size * accelerator.num_processes, 76 | num_workers=config.data.dataloader.dataloader_num_workers, 77 | collate_fn=partial(collate_fn, caption_column=capt_col, image_column=img_col), 78 | ) 79 | 80 | dataloader = accelerator.prepare(dataloader) 81 | 82 | # #################################################### Models #################################################### 83 | 84 | unet = torch.load(os.path.join(config.pruning_ckpt_dir, "unet_pruned.pth"), map_location="cpu") 85 | 86 | state_dict = safetensors.torch.load_file(os.path.join(config.finetuning_ckpt_dir, "unet", 87 | "diffusion_pytorch_model.safetensors")) 88 | unet.load_state_dict(state_dict) 89 | 90 | noise_scheduler = PNDMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder="scheduler") 91 | pipeline = StableDiffusionPipeline.from_pretrained( 92 | config.pretrained_model_name_or_path, 93 | unet=unet, 94 | scheduler=noise_scheduler, 95 | ) 96 | 97 | if config.enable_xformers_memory_efficient_attention: 98 | pipeline.enable_xformers_memory_efficient_attention() 99 | 100 | pipeline.set_progress_bar_config(disable=not accelerator.is_main_process) 101 | 102 | pipeline.to(accelerator.device) 103 | 104 | image_output_dir = os.path.join(config.finetuning_ckpt_dir, f"{dataset_name}_fid_images") 105 | os.makedirs(image_output_dir, exist_ok=True) 106 | 107 | for batch in dataloader: 108 | if config.seed is None: 109 | generator = None 110 | else: 111 | generator = torch.Generator(device=accelerator.device).manual_seed(config.seed) 112 | gen_images = pipeline(batch[capt_col], num_inference_steps=config.training.num_inference_steps, 113 | generator=generator, output_type="np" 114 | ).images 115 | 116 | for idx, caption in enumerate(batch[capt_col]): 117 | image_name = batch["image"][idx].split("/")[-1] 118 | if image_name.endswith(".jpg"): 119 | image_name = image_name[:-4] 120 | 121 | image_path = os.path.join(image_output_dir, f"{image_name}.npy") 122 | img = gen_images[idx] 123 | img = img * 255 124 | img = img.astype(np.uint8) 125 | img = cv2.resize(img, (256, 256)) 126 | np.save(image_path, img) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /scripts/baselines/uni_arch/finetune_baseline_param.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL.Image 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import SingleArchFinetuner 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def main(): 20 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 21 | torch.autograd.set_detect_anomaly(True) 22 | 23 | args = parse_args() 24 | config = OmegaConf.load(args.base_config_path) 25 | config.update(vars(args)) 26 | 27 | assert config.pruning_ckpt_dir is not None, "Please provide a path to the pruning checkpoint directory." 
28 | 29 | if config.seed is not None: 30 | set_seed(config.seed) 31 | 32 | init_logging(config) 33 | 34 | # Enable TF32 for faster training on Ampere GPUs, 35 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 36 | if config.training.allow_tf32: 37 | torch.backends.cuda.matmul.allow_tf32 = True 38 | 39 | trainer = SingleArchFinetuner(config=config) 40 | trainer.train() 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/metrics/clip_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pdm.utils.clip_utils import clip_features 3 | import logging 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset_path', type=str, required=True) 11 | parser.add_argument('--clip_model', type=str, default="ViT-B/32") 12 | parser.add_argument('--num_workers', type=int, default=None) 13 | parser.add_argument('--batch_size', type=int, default=64) 14 | return parser.parse_args() 15 | 16 | 17 | def main(): 18 | args = parse_args() 19 | clip_features(args.dataset_path, clip_model=args.clip_model, num_workers=args.num_workers, 20 | batch_size=args.batch_size) 21 | 22 | 23 | if __name__ == '__main__': 24 | main() 25 | -------------------------------------------------------------------------------- /scripts/metrics/clip_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pdm.utils.clip_utils import clip_score 4 | import logging 5 | 6 | logging.basicConfig(level=logging.INFO) 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gen_images_dir', type=str, required=True) 12 | parser.add_argument('--text_features_dir', type=str, required=True) 13 | parser.add_argument('--clip_model', type=str, default="ViT-B/32") 14 | parser.add_argument('--num_workers', type=int, default=None) 15 | parser.add_argument('--batch_size', type=int, default=64) 16 | parser.add_argument('--result_dir', type=str, required=True, help="Directory to save the results") 17 | parser.add_argument('--dataset_name', type=str, required=True, help="Dataset name") 18 | return parser.parse_args() 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | logging.info(f"Calculating CLIP score for {args.gen_images_dir} using {args.text_features_dir} as text features.") 24 | score = clip_score(args.text_features_dir, args.gen_images_dir, clip_model=args.clip_model, num_workers=args.num_workers, 25 | batch_size=args.batch_size) 26 | logging.info(f"CLIP score: {score}") 27 | 28 | os.makedirs(args.result_dir, exist_ok=True) 29 | 30 | with open(f"{args.result_dir}/clip_score_{args.dataset_name}.txt", "a") as f: 31 | f.write(f"{args.gen_images_dir} {score}\n") 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /scripts/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from cleanfid import fid 5 | import logging 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--gen_dir', type=str, required=True) 13 | parser.add_argument('--dataset', type=str, default="coco-30k") 14 | 
parser.add_argument('--mode', type=str, default="legacy_pytorch") 15 | parser.add_argument('--result_dir', type=str, required=True, help="Directory to save the results") 16 | return parser.parse_args() 17 | 18 | 19 | def main(): 20 | args = parse_args() 21 | fid_value = fid.compute_fid(args.gen_dir, dataset_name=args.dataset, mode=args.mode, dataset_split="custom") 22 | logging.info(f"FID: {fid_value}") 23 | 24 | os.makedirs(args.result_dir, exist_ok=True) 25 | 26 | with open(f"{args.result_dir}/fid.txt", "a") as f: 27 | f.write(f"{args.gen_dir} {fid_value}\n") 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /scripts/metrics/generate_fid_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from omegaconf import OmegaConf 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | import torch 10 | import torch.utils.checkpoint 11 | 12 | import accelerate 13 | from accelerate.logging import get_logger 14 | from accelerate.utils import set_seed 15 | 16 | from diffusers import PNDMScheduler 17 | from diffusers.utils import check_min_version 18 | from diffusers import StableDiffusionPipeline 19 | 20 | import safetensors 21 | 22 | from pdm.models.unet import UNet2DConditionModelPruned 23 | from pdm.utils.arg_utils import parse_args 24 | from pdm.utils.data_utils import get_dataset 25 | 26 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 27 | check_min_version("0.22.0.dev0") 28 | 29 | logger = get_logger(__name__) 30 | 31 | 32 | def main(): 33 | args = parse_args() 34 | config = OmegaConf.load(args.base_config_path) 35 | # add args to config 36 | config.update(vars(args)) 37 | 38 | if config.seed is not None: 39 | set_seed(config.seed) 40 | 41 | # #################################################### Accelerator ################################################# 42 | accelerator = accelerate.Accelerator() 43 | 44 | # #################################################### Datasets #################################################### 45 | 46 | logger.info("Loading datasets...") 47 | # Get the datasets: you can either provide your own training and evaluation files (see below) 48 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 49 | 50 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 51 | # download the dataset. 
52 | dataset_name = getattr(config.data, "dataset_name", None) 53 | img_col = getattr(config.data, "image_column", "image") 54 | capt_col = getattr(config.data, "caption_column", "caption") 55 | 56 | assert config.expert_id is not None, "expert index must be provided" 57 | assert config.finetuning_ckpt_dir is not None, "finetuning checkpoint directory must be provided" 58 | 59 | def collate_fn(examples, caption_column="caption", image_column="image"): 60 | captions = [example[caption_column] for example in examples] 61 | images = [example[image_column] for example in examples] 62 | return {"image": images, "caption": captions} 63 | 64 | dataset = get_dataset(config.data) 65 | 66 | dataset = dataset["validation"] 67 | fid_val_indices_path = os.path.abspath( 68 | os.path.join(config.finetuning_ckpt_dir, "..", "..", f"{dataset_name}_validation_mapped_indices.pt")) 69 | assert os.path.exists(fid_val_indices_path), \ 70 | (f"{dataset_name}_validation_mapped_indices.pt must be present in two upper directory of the checkpoint" 71 | f" directory {config.finetuning_ckpt_dir}") 72 | val_indices = torch.load(fid_val_indices_path, map_location="cpu") 73 | 74 | dataset = dataset.select(torch.where(val_indices == config.expert_id)[0]) 75 | logger.info("Dataset of size %d loaded." % len(dataset)) 76 | 77 | dataloader = torch.utils.data.DataLoader( 78 | dataset, 79 | shuffle=False, 80 | batch_size=config.data.dataloader.image_generation_batch_size * accelerator.num_processes, 81 | num_workers=config.data.dataloader.dataloader_num_workers, 82 | collate_fn=partial(collate_fn, caption_column=capt_col, image_column=img_col), 83 | ) 84 | 85 | dataloader = accelerator.prepare(dataloader) 86 | 87 | # #################################################### Models #################################################### 88 | arch_v = torch.load(os.path.join(config.finetuning_ckpt_dir, "arch_vector.pt"), map_location="cpu") 89 | 90 | unet = UNet2DConditionModelPruned.from_pretrained( 91 | config.pretrained_model_name_or_path, 92 | subfolder="unet", 93 | revision=config.revision, 94 | down_block_types=config.model.unet.unet_down_blocks, 95 | mid_block_type=config.model.unet.unet_mid_block, 96 | up_block_types=config.model.unet.unet_up_blocks, 97 | arch_vector=arch_v 98 | ) 99 | 100 | state_dict = safetensors.torch.load_file(os.path.join(config.finetuning_ckpt_dir, "unet", 101 | "diffusion_pytorch_model.safetensors")) 102 | unet.load_state_dict(state_dict) 103 | 104 | noise_scheduler = PNDMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder="scheduler") 105 | pipeline = StableDiffusionPipeline.from_pretrained( 106 | config.pretrained_model_name_or_path, 107 | unet=unet, 108 | scheduler=noise_scheduler, 109 | ) 110 | 111 | if config.enable_xformers_memory_efficient_attention: 112 | pipeline.enable_xformers_memory_efficient_attention() 113 | 114 | pipeline.set_progress_bar_config(disable=not accelerator.is_main_process) 115 | 116 | pipeline.to(accelerator.device) 117 | 118 | image_output_dir = os.path.join(config.finetuning_ckpt_dir, "..", "..", f"{dataset_name}_fid_images") 119 | os.makedirs(image_output_dir, exist_ok=True) 120 | 121 | for batch in dataloader: 122 | if config.seed is None: 123 | generator = None 124 | else: 125 | generator = torch.Generator(device=accelerator.device).manual_seed(config.seed) 126 | gen_images = pipeline(batch[capt_col], num_inference_steps=config.training.num_inference_steps, 127 | generator=generator, output_type="np" 128 | ).images 129 | 130 | for idx, caption in 
enumerate(batch[capt_col]): 131 | image_name = batch["image"][idx].split("/")[-1] 132 | if image_name.endswith(".jpg"): 133 | image_name = image_name[:-4] 134 | image_path = os.path.join(image_output_dir, f"{image_name}.npy") 135 | img = gen_images[idx] 136 | img = img * 255 137 | img = img.astype(np.uint8) 138 | img = cv2.resize(img, (256, 256)) 139 | np.save(image_path, img) 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /scripts/metrics/resize_and_save_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description="Resize images in a directory") 9 | parser.add_argument("--data_dir", type=str, required=True, help="Directory containing images") 10 | parser.add_argument("--output_dir", type=str, required=True, help="Directory to save resized images") 11 | parser.add_argument("--size", type=int, nargs=2, default=[256, 256], help="Size of the resized images") 12 | return parser.parse_args() 13 | 14 | 15 | def resize_images_in_dir(data_dir, output_dir, size): 16 | os.makedirs(output_dir, exist_ok=True)  # make sure the output directory exists before saving 17 | for img_name in os.listdir(data_dir): 18 | img = Image.open(os.path.join(data_dir, img_name)).convert("RGB")  # force 3-channel RGB, as in sample_coco_30k.py 19 | img = img.resize(size) 20 | img = np.array(img) 21 | np.save(os.path.join(output_dir, img_name + ".npy"), img) 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | resize_images_in_dir(args.data_dir, args.output_dir, size=tuple(args.size)) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /scripts/metrics/sample_coco_30k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import argparse 5 | import cv2 6 | 7 | from PIL import Image 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--data_dir", type=str, required=True, 13 | help="Path to the COCO dataset directory.") 14 | parser.add_argument("--seed", type=int, default=42, help="Random seed.") 15 | parser.add_argument("--num_samples", type=int, default=30000, 16 | help="Number of samples to take from the COCO 2014 validation set.") 17 | return parser.parse_args() 18 | 19 | 20 | def main(args): 21 | annotations_path = os.path.join(args.data_dir, "annotations", "captions_val2014.json") 22 | images_dir = os.path.join(args.data_dir, "images", "val2014") 23 | output_dir = os.path.join(args.data_dir, "images", "val2014_30k") 24 | output_annotations_path = os.path.join(args.data_dir, "annotations", "captions_val2014_30k.json") 25 | 26 | os.makedirs(output_dir, exist_ok=True) 27 | 28 | annotations = json.load(open(annotations_path)) 29 | # deduplicate the annotations 30 | image_ids = set() 31 | deduplicated_annotations = [] 32 | for ann in annotations['annotations']: 33 | if ann['image_id'] not in image_ids: 34 | deduplicated_annotations.append(ann) 35 | image_ids.add(ann['image_id']) 36 | 37 | annotations['annotations'] = deduplicated_annotations 38 | np.random.seed(args.seed) 39 | indices = np.random.choice(len(annotations['annotations']), args.num_samples, replace=False) 40 | selected_annotations = [annotations['annotations'][i] for i in indices] 41 | 42 | with open(output_annotations_path, "w") as f: 43 | json.dump({"annotations": selected_annotations}, f) 44 | 45 | for i in 
indices: 46 | image_id = annotations['annotations'][i]['image_id'] 47 | 48 | image_path = os.path.join(images_dir, f"COCO_val2014_{image_id:012d}.jpg") 49 | img = Image.open(image_path).convert("RGB") 50 | data = np.asarray(img) 51 | data = cv2.resize(data, (256, 256)) 52 | output_path = os.path.join(output_dir, f"COCO_val2014_{image_id:012d}.npy") 53 | np.save(output_path, data) 54 | 55 | print(f"Saved {args.num_samples} samples to {output_dir}.") 56 | print(f"Saved annotations to {output_annotations_path}.") 57 | print("Done.") 58 | 59 | 60 | if __name__ == "__main__": 61 | args = parse_args() 62 | main(args) 63 | -------------------------------------------------------------------------------- /scripts/metrics/save_captions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pdm.datasets.cc3m import load_cc3m_webdataset 4 | 5 | 6 | def save_coco_captions(annotations_file): 7 | # annotations file's name is something like 'annotations/captions_val2014_30k.json' 8 | split_name = os.path.basename(annotations_file)[len('captions_'):-len('.json')] 9 | captions_file = json.load(open(annotations_file)) 10 | captions_dir = os.path.dirname(annotations_file) 11 | save_dir = os.path.join(captions_dir, 'clip-captions') 12 | os.makedirs(save_dir, exist_ok=True) 13 | for capt in captions_file['annotations']: 14 | if '2014' in annotations_file: 15 | image_id = f"COCO_{split_name}_%012d" % capt['image_id'] 16 | else: 17 | image_id = "%012d" % capt['image_id'] 18 | 19 | caption = capt['caption'] 20 | with open(os.path.join(save_dir, image_id + '.txt'), 'w') as f: 21 | f.write(caption) 22 | 23 | 24 | def save_cc3m_captions(data_dir): 25 | split = "validation" 26 | dataset = load_cc3m_webdataset(data_dir, split=split) 27 | save_dir = os.path.join(data_dir, 'clip-captions') 28 | os.makedirs(save_dir, exist_ok=True) 29 | for sample in dataset: 30 | image_id = sample['__key__'] 31 | caption = sample['caption'] 32 | with open(os.path.join(save_dir, image_id + '.txt'), 'w') as f: 33 | f.write(caption) 34 | 35 | 36 | if __name__ == '__main__': 37 | save_coco_captions('/home/rezashkv/scratch/research/data/coco/annotations/captions_val2014_30k.json') 38 | save_cc3m_captions('/home/rezashkv/scratch/research/data/cc3m') 39 | -------------------------------------------------------------------------------- /scripts/other/calculate_pruning_ratio.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | 9 | from accelerate.utils import set_seed 10 | from accelerate.logging import get_logger 11 | 12 | from diffusers import UNet2DConditionModel 13 | from diffusers.utils import check_min_version 14 | 15 | from transformers import CLIPTextModel 16 | from transformers.utils import ContextManagers 17 | 18 | from pdm.models.unet import UNet2DConditionModelGated 19 | from pdm.models import HyperStructure, StructureVectorQuantizer 20 | from pdm.utils.arg_utils import parse_args 21 | from pdm.utils.op_counter import count_ops_and_params 22 | from pdm.utils.dist_utils import deepspeed_zero_init_disabled_context_manager 23 | 24 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
25 | check_min_version("0.22.0.dev0") 26 | 27 | logger = get_logger(__name__) 28 | 29 | 30 | def main(): 31 | args = parse_args() 32 | config = OmegaConf.load(args.base_config_path) 33 | config.update(vars(args)) 34 | 35 | assert config.pruning_ckpt_dir is not None, "Please provide a path to the pruning checkpoint directory." 36 | 37 | if config.seed is not None: 38 | set_seed(config.seed) 39 | 40 | # Make one log on every process with the configuration for debugging. 41 | logging.basicConfig( 42 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 43 | datefmt="%m/%d/%Y %H:%M:%S", 44 | level=logging.INFO, 45 | ) 46 | 47 | # #################################################### Models #################################################### 48 | 49 | # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. 50 | # For this to work properly all models must be run through `accelerate.prepare`. But accelerate 51 | # will try to assign the same optimizer with the same weights to all models during 52 | # `deepspeed.initialize`, which of course doesn't work. 53 | # 54 | # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 55 | # frozen models from being partitioned during `zero.Init` which gets called during 56 | # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding 57 | # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 58 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()): 59 | text_encoder = CLIPTextModel.from_pretrained( 60 | config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision 61 | ) 62 | 63 | pretrained_config = UNet2DConditionModel.load_config(config.pretrained_model_name_or_path, subfolder="unet") 64 | 65 | sample_inputs = {'sample': torch.randn(1, pretrained_config["in_channels"], pretrained_config["sample_size"], 66 | pretrained_config["sample_size"]), 67 | 'timestep': torch.ones((1,)).long(), 68 | 'encoder_hidden_states': text_encoder(torch.tensor([[100]]))[0], 69 | } 70 | 71 | unet = UNet2DConditionModelGated.from_pretrained( 72 | config.pretrained_model_name_or_path, 73 | subfolder="unet", 74 | revision=config.non_ema_revision, 75 | down_block_types=tuple(config.model.unet.unet_down_blocks), 76 | mid_block_type=config.model.unet.unet_mid_block, 77 | up_block_types=tuple(config.model.unet.unet_up_blocks), 78 | gated_ff=config.model.unet.gated_ff, 79 | ff_gate_width=config.model.unet.ff_gate_width 80 | ) 81 | 82 | hyper_net = HyperStructure.from_pretrained(config.pruning_ckpt_dir, subfolder="hypernet") 83 | 84 | if config.pruning_type == "multi-expert": 85 | embeddings_gs = torch.load(os.path.join(config.pruning_ckpt_dir, "quantizer_embeddings.pt"), map_location="cpu") 86 | else: 87 | quantizer = StructureVectorQuantizer.from_pretrained(config.pruning_ckpt_dir, subfolder="quantizer") 88 | embeddings_gs = quantizer.gumbel_sigmoid_trick(hyper_net.arch) 89 | 90 | arch_vecs_separated = hyper_net.transform_structure_vector( 91 | torch.ones((1, embeddings_gs.shape[1]), device=embeddings_gs.device)) 92 | 93 | unet.set_structure(arch_vecs_separated) 94 | 95 | macs, params = count_ops_and_params(unet, sample_inputs) 96 | 97 | logging.info( 98 | "Full UNet's Params/MACs calculated by OpCounter:\tparams: {:.3f}M\t MACs: {:.3f}G".format( 99 | params / 1e6, macs / 1e9)) 100 | 101 | sanity_macs_dict = unet.calc_macs() 102 | prunable_macs_list = [[e / sanity_macs_dict['prunable_macs'] for e in 
elem] for elem in 103 | unet.get_prunable_macs()] 104 | 105 | unet.prunable_macs_list = prunable_macs_list 106 | unet.resource_info_dict = sanity_macs_dict 107 | 108 | sanity_string = "Our MACs calculation:\t" 109 | for k, v in sanity_macs_dict.items(): 110 | if isinstance(v, torch.Tensor): 111 | sanity_string += f" {k}: {v.item() / 1e9:.3f}\t" 112 | else: 113 | sanity_string += f" {k}: {v / 1e9:.3f}\t" 114 | logging.info(sanity_string) 115 | 116 | arch_vectors_separated = hyper_net.transform_structure_vector(embeddings_gs) 117 | unet.set_structure(arch_vectors_separated) 118 | 119 | macs_dict = unet.calc_macs() 120 | resource_ratios = macs_dict['cur_total_macs'] / (unet.resource_info_dict['cur_total_macs'].squeeze()) 121 | logging.info(f"Resource Ratios: {resource_ratios}") 122 | torch.save(resource_ratios, os.path.join(config.pruning_ckpt_dir, "resource_ratios.pt")) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /scripts/other/depth_analysis.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | 3 | import PIL 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | 8 | from accelerate.utils import set_seed 9 | from accelerate.logging import get_logger 10 | 11 | from pdm.utils.arg_utils import parse_args 12 | from pdm.utils.logging_utils import init_logging 13 | from pdm.training.trainer import Pruner 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def main(): 19 | PIL.Image.MAX_IMAGE_PIXELS = 933120000 20 | torch.autograd.set_detect_anomaly(True) 21 | 22 | args = parse_args() 23 | config = OmegaConf.load(args.base_config_path) 24 | config.update(vars(args)) 25 | 26 | if config.seed is not None: 27 | set_seed(config.seed) 28 | 29 | init_logging(config) 30 | 31 | # Enable TF32 for faster training on Ampere GPUs, 32 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 33 | if config.training.allow_tf32: 34 | torch.backends.cuda.matmul.allow_tf32 = True 35 | 36 | pruner = Pruner(config=config) 37 | 38 | pruner.depth_analysis(n_consecutive_blocks=config.n_blocks) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | --------------------------------------------------------------------------------