├── .gitignore ├── .gitmodules ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── assets └── teaser.png ├── augmentation ├── __init__.py ├── base_augmentation.py ├── diff_mix.py ├── real_mix.py └── ti_mix.py ├── dataset ├── __init__.py ├── base.py ├── instance │ ├── aircraft.py │ ├── car.py │ ├── cub.py │ ├── dog.py │ ├── flower.py │ ├── food.py │ ├── pascal.py │ ├── pet.py │ └── waterbird.py └── template.py ├── downstream_tasks ├── imb_utils │ ├── __init__.py │ ├── autoaug.py │ ├── moco_loader.py │ ├── randaugment.py │ └── util.py ├── losses.py ├── mixup.py ├── train_hub.py ├── train_hub_imb.py └── train_hub_waterbird.py ├── outputs ├── requirement.txt ├── scripts ├── classification.sh ├── classification_imb.sh ├── classification_waterbird.sh ├── compose_syn_data.py ├── filter_syn_data.py ├── finetune.sh ├── finetune_imb.sh ├── sample.sh ├── sample_imb.sh └── sample_mp.py ├── train_lora.py ├── utils ├── misc.py ├── network.py └── visualization.py └── visualization ├── visualize_attn_map.py ├── visualize_cases.py └── visualize_filtered_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | debug/ 7 | src/ 8 | notebook/ 9 | figures/ 10 | ckpts/* 11 | aug_samples/* 12 | _others/ 13 | pyrightconfig.json 14 | # images 15 | *.jpg 16 | *.pdf 17 | 18 | # experiment data 19 | *.out 20 | *.csv 21 | *.pt 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | .vscode/* 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "stable-diffusion"] 2 | path = stable-diffusion 3 | url = https://github.com/CompVis/stable-diffusion.git 4 | [submodule "erasing"] 5 | path = erasing 6 | url = https://github.com/brandontrabucco/erasing.git 7 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | 8 | { 9 | "name": "Python: Finetune Script", 10 | "type": "python", 11 | "request": "launch", 12 | "module": "accelerate.commands.launch", 13 | "console": "integratedTerminal", 14 | "env":{"CUDA_VISIBLE_DEVICES": "0,1"}, 15 | "args": [ 16 | "train_lora.py", 17 | "--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5", 18 | "--dataset_name=cub", // replace with your actual dataset name 19 | "--resolution=512", // 20 | "--random_flip", 21 | "--max_train_steps=35000", 22 | "--num_train_epochs=10", // 23 | "--checkpointing_steps=5000", 24 | "--learning_rate=5e-05", 25 | "--lr_scheduler=constant", 26 | "--lr_warmup_steps=0", 27 | "--seed=42", 28 | "--rank=10", 29 | "--local_files_only", 30 | "--examples_per_class=5", 31 | "--train_batch_size=1", 32 | "--output_dir=/data/zhicai/code/Diff-Mix/outputs/finetune_model/finetune_ti_db/sd-cub-5shot-model-lora-rank10", // replace with your actual output directory 33 | "--report_to=tensorboard" 34 | ] 35 | }, 36 | { 37 | "name": "Python: Finetune IMB Script", 38 | "type": "python", 39 | "request": "launch", 40 | "module": "accelerate.commands.launch", 41 | "console": "integratedTerminal", 42 | "env":{"CUDA_VISIBLE_DEVICES": "4"}, 43 | "args": [ 44 | "train_lora.py", 45 | "--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5", 46 | "--dataset_name=cub", // replace with your actual dataset name 47 | "--resolution=512", // 48 | "--random_flip", 49 | "--max_train_steps=35000", 50 | "--num_train_epochs=10", // 51 | "--checkpointing_steps=5000", 52 | "--learning_rate=5e-05", 53 | "--lr_scheduler=constant", 54 | "--lr_warmup_steps=0", 55 | "--seed=42", 56 | "--rank=10", 57 | "--local_files_only", 58 | "--train_batch_size=1", 59 | "--output_dir=ckpts/cub/IMB0.01_lora_rank10", // replace with your actual output directory 60 | "--report_to=tensorboard", 61 | "--task=imbalanced", 62 | "--imbalance_factor=0.01" 63 | 64 | ] 65 | }, 66 | { 67 | "name": "Python: Sample MP", 68 | "type": "python", 69 | "request": "launch", 70 | "program": "${workspaceFolder}/scripts/sample_mp.py", 71 | "args": [ 72 | "--output_root", "aug_samples", 73 | "--finetuned_ckpt", "ckpts/cub/shot-1-lora-rank10", 74 | "--dataset", "cub", 75 |
"--syn_dataset_mulitiplier", "5", 76 | "--strength_strategy", "fixed", 77 | "--resolution", "512", 78 | "--batch_size", "1", 79 | "--aug_strength", "0.7", 80 | "--model_path", "runwayml/stable-diffusion-v1-5", 81 | "--sample_strategy", "diff-mix", 82 | "--gpu_ids", "3", "4", 83 | ], 84 | "console": "integratedTerminal" 85 | }, 86 | { 87 | "name": "Python: Sample MP IMB", 88 | "type": "python", 89 | "request": "launch", 90 | "program": "${workspaceFolder}/scripts/sample_mp.py", 91 | "args": [ 92 | "--output_root", "aug_samples", 93 | "--finetuned_ckpt", "ckpts/cub/shot-1-lora-rank10", 94 | "--dataset", "cub", 95 | "--syn_dataset_mulitiplier", "5", 96 | "--strength_strategy", "fixed", 97 | "--resolution", "512", 98 | "--batch_size", "1", 99 | "--aug_strength", "0.7", 100 | "--model_path", "runwayml/stable-diffusion-v1-5", 101 | "--sample_strategy", "diff-mix", 102 | "--gpu_ids", "3", "4", 103 | "--task", "imbalanced", 104 | "--imbalance_factor", "0.01" 105 | ], 106 | "console": "integratedTerminal" 107 | }, 108 | { 109 | "name": "Python: Train Hub", 110 | "type": "python", 111 | "request": "launch", 112 | "program": "${workspaceFolder}/downstream_tasks/train_hub.py", 113 | "args": [ 114 | "--dataset", "cub", 115 | "--syndata_p", "0.1", 116 | "--syndata_dir", "outputs/aug_samples/cub/dreambooth-lora-mixup-Multi7-db_ti10000-Strength0.5", 117 | "--model", "resnet50", 118 | "--gamma", "0.8", 119 | "--examples_per_class", "-1", 120 | "--gpu", "1", 121 | "--amp", "2", 122 | "--note", "${env:DATE}", 123 | "--group_note", "test", 124 | "--nepoch", "60", 125 | "--res_mode", "224", 126 | "--lr", "0.05", 127 | "--seed", "0", 128 | "--weight_decay", "0.0005" 129 | ], 130 | "console": "integratedTerminal", 131 | "env": { 132 | "DATE": "${command:python.interpreterPath} -c \"import datetime; print(datetime.datetime.now().strftime('%m%d%H%M'))\"" 133 | } 134 | }, 135 | { 136 | "name": "Python: Train Hub Waterbird", 137 | "type": "python", 138 | "request": "launch", 139 | "program": "${workspaceFolder}/downstream_tasks/train_hub_waterbird.py", 140 | "args": [ 141 | "--dataset", "cub", 142 | "--syndata_p", "0.1", 143 | "--syndata_dir", "outputs/aug_samples/cub/dreambooth-lora-mixup-Multi7-db_ti10000-Strength0.5", 144 | "--model", "resnet50", 145 | "--gamma", "0.8", 146 | "--examples_per_class", "-1", 147 | "--gpu", "4", 148 | "--amp", "2", 149 | "--note", "${env:DATE}", 150 | "--group_note", "test", 151 | "--nepoch", "60", 152 | "--res_mode", "224", 153 | "--lr", "0.05", 154 | "--seed", "0", 155 | "--weight_decay", "0.0005" 156 | ], 157 | "console": "integratedTerminal", 158 | "env": { 159 | "DATE": "${command:python.interpreterPath} -c \"import datetime; print(datetime.datetime.now().strftime('%m%d%H%M'))\"" 160 | } 161 | }, 162 | { 163 | "name": "Python: Train Hub Imb", 164 | "type": "python", 165 | "request": "launch", 166 | "program": "${workspaceFolder}/downstream_tasks/train_hub_imb.py", 167 | "args": [ 168 | "--dataset", "cub", 169 | "--loss_type", "CE", 170 | "--lr", "0.005", 171 | "--epochs", "200", 172 | "--imb_factor", "0.01", 173 | "-b", "128", 174 | "--gpu", "4", 175 | "--data_aug", "vanilla", 176 | "--root_log", "outputs/results_cmo", 177 | "--syndata_dir", "outputs/aug_samples_imbalance/cub/dreambooth-lora-mixup-Multi10-db_ti_latest_imb_0.1-Strength0.7", 178 | "--syndata_p", "0.5", 179 | "--gamma", "0.8", 180 | "--use_weighted_syn" 181 | ], 182 | "console": "integratedTerminal", 183 | } 184 | ] 185 | } -------------------------------------------------------------------------------- 
/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "off" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhicai Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | # Enhance Image Classification Via Inter-Class Image Mixup With Diffusion Model 4 | 5 | Paper PDF 6 | 7 | ![Image](assets/teaser.png) 8 | 9 | 10 | 11 | 
12 | 13 | ## Introduction 👋 14 | This repository implements various **generative data augmentation** strategies using stable diffusion to create synthetic datasets, aimed at enhancing classification tasks. 15 | ## Requirements 16 | The key packages and their versions are listed below. The code is tested on a single node with 4 NVIDIA RTX3090 GPUs. 17 | ``` 18 | torch==2.0.1+cu118 19 | diffusers==0.25.1 20 | transformers==4.36.2 21 | datasets==2.16.1 22 | accelerate==0.26.1 23 | numpy==1.24.4 24 | ``` 25 | ## Datasets 26 | For convenience, well-structured datasets hosted on Hugging Face can be used. The fine-grained datasets `CUB` and `Aircraft` we experimented with can be downloaded from [Multimodal-Fatima/CUB_train](https://huggingface.co/datasets/Multimodal-Fatima/CUB_train) and [Multimodal-Fatima/FGVC_Aircraft_train](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_train), respectively. If you encounter network connection problems during training, pre-download the data from the website and point `HUG_LOCAL_IMAGE_TRAIN_DIR` in `dataset/instance/cub.py` to the saved local path. 27 | 28 | 29 | 30 | ## Fine-tune on a dataset 🔥 31 | ### Pre-trained LoRA weights 32 | We provide LoRA weights fine-tuned on the full dataset for fast reproduction on the given datasets. Download them from the links below and unzip the files into the `ckpts` directory; the file structure should look like: 33 | 34 | ``` 35 | ckpts 36 | ├── cub 37 | │ └── shot-1-lora-rank10 38 | │ ├── learned_embeds-steps-last.bin 39 | │ └── pytorch_lora_weights.safetensors 40 | └── put_finetuned_ckpts_here.txt 41 | ``` 42 | 43 | | Dataset | data | ckpts (fullshot) | 44 | |---------|---------------------------------------------------------------------|---------------------------------------------------------------------| 45 | | CUB | huggingface ([train](https://huggingface.co/datasets/Multimodal-Fatima/CUB_train)/[test](https://huggingface.co/datasets/Multimodal-Fatima/CUB_test)) | [google drive](https://drive.google.com/file/d/1AOX4TcXSPGRSmxSgB08L8P-28c5TPkxw/view?usp=sharing) | 46 | | Flower | [official website](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) | [google drive](https://drive.google.com/file/d/1hBodBaLb_GokxfMXvQyhr4OGzyBgyBm0/view?usp=sharing) | 47 | | Aircraft | huggingface ([train](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_train)/[test](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_test)) | [google drive](https://drive.google.com/file/d/19PuRbIsurv1IKeu-jx5WieocMy5rfIKg/view?usp=sharing) | 48 | 49 | 50 | 51 | ### Customized fine-tuning 52 | The `scripts/finetune.sh` script allows users to perform fine-tuning on their own datasets. By default, it implements a fine-tuning strategy combining `DreamBooth` and `Textual Inversion`. Users can customize the `examples_per_class` argument to fine-tune the model on an {examples_per_class}-shot subset of the dataset. The tuning process takes around 4 hours on 4 RTX3090 GPUs for the full-shot CUB dataset.
53 | 54 | ``` 55 | MODEL_NAME="runwayml/stable-diffusion-v1-5" 56 | DATASET='cub' 57 | SHOT=-1 # set -1 for full shot 58 | OUTPUT_DIR="ckpts/${DATASET}/shot${SHOT}_lora_rank10" 59 | 60 | accelerate launch --mixed_precision='fp16' --main_process_port 29507 \ 61 | train_lora.py \ 62 | --pretrained_model_name_or_path=$MODEL_NAME \ 63 | --dataset_name=$DATASET \ 64 | --resolution=224 \ 65 | --random_flip \ 66 | --max_train_steps=35000 \ 67 | --num_train_epochs=10 \ 68 | --checkpointing_steps=5000 \ 69 | --learning_rate=5e-05 \ 70 | --lr_scheduler='constant' \ 71 | --lr_warmup_steps=0 \ 72 | --seed=42 \ 73 | --rank=10 \ 74 | --local_files_only \ 75 | --examples_per_class $SHOT \ 76 | --train_batch_size 2 \ 77 | --output_dir=$OUTPUT_DIR \ 78 | --report_to='tensorboard' 79 | ``` 80 | 81 | ## Construct synthetic data 82 | `scripts/sample.sh` provides a script to synthesize augmented images in a multi-processing way. Each item in `GPU_IDS` denotes a process running on the GPU with that index. The simplified command for sampling a $5\times$ synthetic subset in an inter-class translation manner (`diff-mix`) with strength $s=0.7$ is: 83 | 84 | ```bash 85 | DATASET='cub' 86 | # set -1 for full shot 87 | SHOT=-1 88 | FINETUNED_CKPT="ckpts/cub/shot${SHOT}-lora-rank10" 89 | # ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug'] 90 | SAMPLE_STRATEGY='diff-mix' 91 | STRENGTH=0.7 92 | # ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9] 93 | STRENGTH_STRATEGY='fixed' 94 | # expand the dataset by 5 times 95 | MULTIPLIER=5 96 | # spawn 4 processes 97 | GPU_IDS=(0 1 2 3) 98 | 99 | python scripts/sample_mp.py \ 100 | --model-path='runwayml/stable-diffusion-v1-5' \ 101 | --output_root='outputs/aug_samples' \ 102 | --dataset=$DATASET \ 103 | --finetuned_ckpt=$FINETUNED_CKPT \ 104 | --syn_dataset_mulitiplier=$MULTIPLIER \ 105 | --strength_strategy=$STRENGTH_STRATEGY \ 106 | --sample_strategy=$SAMPLE_STRATEGY \ 107 | --examples_per_class=$SHOT \ 108 | --resolution=512 \ 109 | --batch_size=1 \ 110 | --aug_strength=$STRENGTH \ 111 | --gpu-ids=${GPU_IDS[@]} 112 | ``` 113 | The output synthetic dataset will be located at `aug_samples/cub/diff-mix_-1_fixed_0.7`. To create a 5-shot setting, set the `examples_per_class` argument to 5, and the output dir will be `aug_samples/cub/diff-mix_5_fixed_0.7`. Please ensure that the `finetuned_ckpt` is also fine-tuned under the same 5-shot setting.
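Before training on the synthetic data, it can be helpful to sanity-check what was generated. The snippet below is a minimal sketch for inspecting the sampler's metadata; it assumes the `meta.csv` layout consumed by `dataset/base.py` (columns `First Directory`, `Second Directory`, `Path`, `Strength`, with images under a `data/` subfolder), and the directory name is only an example taken from the command above.

```python
import os

import pandas as pd

# Example output directory from the sampling command above; adjust to your run.
syn_dir = "outputs/aug_samples/cub/diff-mix_-1_fixed_0.7"
meta = pd.read_csv(os.path.join(syn_dir, "meta.csv"))

# Each row records one synthetic image translated from a source class toward a target class.
print(meta[["First Directory", "Second Directory", "Strength"]].head())
print("total synthetic images:", len(meta))
print(meta["First Directory"].value_counts().head())  # images per source class

# Resolve image paths the same way dataset/base.py does before training.
image_paths = meta["Path"].apply(lambda p: os.path.join(syn_dir, "data", p))
print(image_paths.iloc[0])
```

During downstream training, `dataset/base.py` turns each row's strength $s$ into a soft label that places weight $s^{\gamma}$ on the target class and $1-s^{\gamma}$ on the source class (scaled by `soft_scaler`), which is what the `GAMMA` value in the classification script below controls.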
114 | 115 | 116 | ## Downstream classification 117 | After completing the sampling process, you can integrate the synthetic data into downstream classification and initiate training using the script `scripts/classification.sh`: 118 | ``` 119 | GPU=1 120 | DATASET="cub" 121 | SHOT=-1 122 | # "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}" 123 | SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot 124 | SYNDATA_P=0.1 125 | GAMMA=0.8 126 | 127 | python downstream_tasks/train_hub.py \ 128 | --dataset $DATASET \ 129 | --syndata_dir $SYNDATA_DIR \ 130 | --syndata_p $SYNDATA_P \ 131 | --model "resnet50" \ 132 | --gamma $GAMMA \ 133 | --examples_per_class $SHOT \ 134 | --gpu $GPU \ 135 | --amp 2 \ 136 | --note $(date +%m%d%H%M) \ 137 | --group_note "fullshot" \ 138 | --nepoch 120 \ 139 | --res_mode 224 \ 140 | --lr 0.05 \ 141 | --seed 0 \ 142 | --weight_decay 0.0005 143 | ``` 144 | 145 | We also provides the scripts for robustness test and long-tail classification in `scripts/classification_waterbird.sh` and `scripts/classification_imb.sh`, respectively. 146 | ## Acknowledgements 147 | 148 | This project is built upon the repository [Da-fusion](https://github.com/brandontrabucco/da-fusion) and [diffusers](https://github.com/huggingface/diffusers). Special thanks to the contributors. 149 | 150 | 151 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhicaiwww/Diff-Mix/a81337b1492bcc7f7a8a61921836e94191a3a0ef/assets/teaser.png -------------------------------------------------------------------------------- /augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .diff_mix import * 2 | from .real_mix import * 3 | from .ti_mix import * 4 | 5 | AUGMENT_METHODS = { 6 | "ti-mix": TextualInversionMixup, 7 | "ti_aug": TextualInversionMixup, 8 | "real-aug": DreamboothLoraMixup, 9 | "real-mix": DreamboothLoraMixup, 10 | "real-gen": RealGeneration, 11 | "diff-mix": DreamboothLoraMixup, 12 | "diff-aug": DreamboothLoraMixup, 13 | "diff-gen": DreamboothLoraGeneration, 14 | } 15 | -------------------------------------------------------------------------------- /augmentation/base_augmentation.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Tuple 3 | 4 | import torch.nn as nn 5 | from PIL import Image 6 | 7 | 8 | class GenerativeAugmentation(nn.Module, abc.ABC): 9 | 10 | @abc.abstractmethod 11 | def forward( 12 | self, image: Image.Image, label: int, metadata: dict 13 | ) -> Tuple[Image.Image, int]: 14 | 15 | return NotImplemented 16 | 17 | 18 | class GenerativeMixup(nn.Module, abc.ABC): 19 | 20 | @abc.abstractmethod 21 | def forward( 22 | self, image: Image.Image, label: int, metadata: dict, strength: float 23 | ) -> Tuple[Image.Image, int]: 24 | 25 | return NotImplemented 26 | -------------------------------------------------------------------------------- /augmentation/diff_mix.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Tuple 3 | 4 | import torch 5 | from PIL import Image 6 | from torch import autocast 7 | from transformers import CLIPTextModel, CLIPTokenizer 8 | 9 | from augmentation.base_augmentation import GenerativeMixup 10 | from 
diffusers import ( 11 | DPMSolverMultistepScheduler, 12 | StableDiffusionImg2ImgPipeline, 13 | StableDiffusionPipeline, 14 | ) 15 | from diffusers.utils import logging 16 | 17 | os.environ["WANDB_DISABLED"] = "true" 18 | ERROR_MESSAGE = "Tokenizer already contains the token {token}. \ 19 | Please pass a different `token` that is not already in the tokenizer." 20 | 21 | 22 | def format_name(name): 23 | return f"<{name.replace(' ', '_')}>" 24 | 25 | 26 | def load_diffmix_embeddings( 27 | embed_path: str, 28 | text_encoder: CLIPTextModel, 29 | tokenizer: CLIPTokenizer, 30 | device="cuda", 31 | ): 32 | 33 | embedding_ckpt = torch.load(embed_path, map_location="cpu") 34 | learned_embeds_dict = embedding_ckpt["learned_embeds_dict"] 35 | name2placeholder = embedding_ckpt["name2placeholder"] 36 | placeholder2name = embedding_ckpt["placeholder2name"] 37 | 38 | name2placeholder = { 39 | k.replace("/", " ").replace("_", " "): v for k, v in name2placeholder.items() 40 | } 41 | placeholder2name = { 42 | v: k.replace("/", " ").replace("_", " ") for k, v in name2placeholder.items() 43 | } 44 | 45 | for token, token_embedding in learned_embeds_dict.items(): 46 | 47 | # add the token in tokenizer 48 | num_added_tokens = tokenizer.add_tokens(token) 49 | assert num_added_tokens > 0, ERROR_MESSAGE.format(token=token) 50 | 51 | # resize the token embeddings 52 | text_encoder.resize_token_embeddings(len(tokenizer)) 53 | added_token_id = tokenizer.convert_tokens_to_ids(token) 54 | 55 | # get the old word embeddings 56 | embeddings = text_encoder.get_input_embeddings() 57 | 58 | # get the id for the token and assign new embeds 59 | embeddings.weight.data[added_token_id] = token_embedding.to( 60 | embeddings.weight.dtype 61 | ) 62 | 63 | return name2placeholder, placeholder2name 64 | 65 | 66 | def identity(*args): 67 | return args 68 | 69 | 70 | class IdentityMap: 71 | def __getitem__(self, key): 72 | return key 73 | 74 | 75 | class DreamboothLoraMixup(GenerativeMixup): 76 | 77 | pipe = None # global sharing is a hack to avoid OOM 78 | 79 | def __init__( 80 | self, 81 | lora_path: str, 82 | model_path: str = "runwayml/stable-diffusion-v1-5", 83 | embed_path: str = None, 84 | prompt: str = "a photo of a {name}", 85 | format_name: Callable = format_name, 86 | guidance_scale: float = 7.5, 87 | disable_safety_checker: bool = True, 88 | revision: str = None, 89 | device="cuda", 90 | **kwargs, 91 | ): 92 | 93 | super(DreamboothLoraMixup, self).__init__() 94 | 95 | if DreamboothLoraMixup.pipe is None: 96 | 97 | PipelineClass = StableDiffusionImg2ImgPipeline 98 | 99 | DreamboothLoraMixup.pipe = PipelineClass.from_pretrained( 100 | model_path, 101 | use_auth_token=True, 102 | revision=revision, 103 | local_files_only=True, 104 | torch_dtype=torch.float16, 105 | ).to(device) 106 | 107 | scheduler = DPMSolverMultistepScheduler.from_config( 108 | DreamboothLoraMixup.pipe.scheduler.config, local_files_only=True 109 | ) 110 | self.placeholder2name = {} 111 | self.name2placeholder = {} 112 | if embed_path is not None: 113 | self.name2placeholder, self.placeholder2name = load_diffmix_embeddings( 114 | embed_path, 115 | DreamboothLoraMixup.pipe.text_encoder, 116 | DreamboothLoraMixup.pipe.tokenizer, 117 | ) 118 | if lora_path is not None: 119 | DreamboothLoraMixup.pipe.load_lora_weights(lora_path) 120 | DreamboothLoraMixup.pipe.scheduler = scheduler 121 | 122 | print(f"successfuly load lora weights from {lora_path}! ! ! 
") 123 | 124 | logging.disable_progress_bar() 125 | self.pipe.set_progress_bar_config(disable=True) 126 | 127 | if disable_safety_checker: 128 | self.pipe.safety_checker = None 129 | 130 | self.prompt = prompt 131 | self.guidance_scale = guidance_scale 132 | self.format_name = format_name 133 | 134 | def forward( 135 | self, 136 | image: Image.Image, 137 | label: int, 138 | metadata: dict, 139 | strength: float = 0.5, 140 | resolution=512, 141 | ) -> Tuple[Image.Image, int]: 142 | 143 | canvas = [img.resize((resolution, resolution), Image.BILINEAR) for img in image] 144 | name = metadata.get("name", "") 145 | 146 | if self.name2placeholder is not None: 147 | name = self.name2placeholder[name] 148 | if metadata.get("super_class", None) is not None: 149 | name = name + " " + metadata.get("super_class", "") 150 | prompt = self.prompt.format(name=name) 151 | 152 | print(prompt) 153 | 154 | kwargs = dict( 155 | image=canvas, 156 | prompt=[prompt], 157 | strength=strength, 158 | guidance_scale=self.guidance_scale, 159 | num_inference_steps=25, 160 | num_images_per_prompt=len(canvas), 161 | ) 162 | 163 | has_nsfw_concept = True 164 | while has_nsfw_concept: 165 | with autocast("cuda"): 166 | outputs = self.pipe(**kwargs) 167 | 168 | has_nsfw_concept = ( 169 | self.pipe.safety_checker is not None 170 | and outputs.nsfw_content_detected[0] 171 | ) 172 | canvas = [] 173 | for orig, out in zip(image, outputs.images): 174 | canvas.append(out.resize(orig.size, Image.BILINEAR)) 175 | return canvas, label 176 | 177 | 178 | class DreamboothLoraGeneration(GenerativeMixup): 179 | 180 | pipe = None # global sharing is a hack to avoid OOM 181 | 182 | def __init__( 183 | self, 184 | lora_path: str, 185 | model_path: str = "runwayml/stable-diffusion-v1-5", 186 | embed_path: str = None, 187 | prompt: str = "a photo of a {name}", 188 | format_name: Callable = format_name, 189 | guidance_scale: float = 7.5, 190 | disable_safety_checker: bool = True, 191 | revision: str = None, 192 | device="cuda", 193 | **kwargs, 194 | ): 195 | 196 | super(DreamboothLoraGeneration, self).__init__() 197 | 198 | if DreamboothLoraGeneration.pipe is None: 199 | 200 | PipelineClass = StableDiffusionPipeline 201 | 202 | DreamboothLoraGeneration.pipe = PipelineClass.from_pretrained( 203 | model_path, 204 | use_auth_token=True, 205 | revision=revision, 206 | local_files_only=True, 207 | torch_dtype=torch.float16, 208 | ).to(device) 209 | 210 | scheduler = DPMSolverMultistepScheduler.from_config( 211 | DreamboothLoraGeneration.pipe.scheduler.config, local_files_only=True 212 | ) 213 | self.placeholder2name = None 214 | self.name2placeholder = None 215 | if embed_path is not None: 216 | self.name2placeholder, self.placeholder2name = load_diffmix_embeddings( 217 | embed_path, 218 | DreamboothLoraGeneration.pipe.text_encoder, 219 | DreamboothLoraGeneration.pipe.tokenizer, 220 | ) 221 | if lora_path is not None: 222 | DreamboothLoraGeneration.pipe.load_lora_weights(lora_path) 223 | DreamboothLoraGeneration.pipe.scheduler = scheduler 224 | 225 | print(f"successfuly load lora weights from {lora_path}! ! ! 
") 226 | 227 | logging.disable_progress_bar() 228 | self.pipe.set_progress_bar_config(disable=True) 229 | 230 | if disable_safety_checker: 231 | self.pipe.safety_checker = None 232 | 233 | self.prompt = prompt 234 | self.guidance_scale = guidance_scale 235 | self.format_name = format_name 236 | 237 | def forward( 238 | self, 239 | image: Image.Image, 240 | label: int, 241 | metadata: dict, 242 | strength: float = 0.5, 243 | resolution=512, 244 | ) -> Tuple[Image.Image, int]: 245 | 246 | name = metadata.get("name", "") 247 | 248 | if self.name2placeholder is not None: 249 | name = self.name2placeholder[name] 250 | if metadata.get("super_class", None) is not None: 251 | name = name + " " + metadata.get("super_class", "") 252 | prompt = self.prompt.format(name=name) 253 | 254 | print(prompt) 255 | 256 | kwargs = dict( 257 | prompt=[prompt], 258 | guidance_scale=self.guidance_scale, 259 | num_inference_steps=25, 260 | num_images_per_prompt=len(image), 261 | height=resolution, 262 | width=resolution, 263 | ) 264 | 265 | has_nsfw_concept = True 266 | while has_nsfw_concept: 267 | with autocast("cuda"): 268 | outputs = self.pipe(**kwargs) 269 | 270 | has_nsfw_concept = ( 271 | self.pipe.safety_checker is not None 272 | and outputs.nsfw_content_detected[0] 273 | ) 274 | 275 | canvas = [] 276 | for out in outputs.images: 277 | canvas.append(out) 278 | return canvas, label 279 | -------------------------------------------------------------------------------- /augmentation/real_mix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Any, Callable, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image, ImageOps 8 | from scipy.ndimage import maximum_filter 9 | from torch import autocast 10 | 11 | from augmentation.base_augmentation import GenerativeAugmentation, GenerativeMixup 12 | from diffusers import ( 13 | DPMSolverMultistepScheduler, 14 | StableDiffusionImg2ImgPipeline, 15 | StableDiffusionInpaintPipeline, 16 | StableDiffusionPipeline, 17 | ) 18 | from diffusers.utils import logging 19 | 20 | 21 | def format_name(name): 22 | return f"<{name.replace(' ', '_')}>" 23 | 24 | 25 | class RealGeneration(GenerativeMixup): 26 | 27 | pipe = None # global sharing is a hack to avoid OOM 28 | 29 | def __init__( 30 | self, 31 | model_path: str = "runwayml/stable-diffusion-v1-5", 32 | prompt: str = "a photo of a {name}", 33 | format_name: Callable = format_name, 34 | guidance_scale: float = 7.5, 35 | mask: bool = False, 36 | inverted: bool = False, 37 | mask_grow_radius: int = 16, 38 | disable_safety_checker: bool = True, 39 | revision: str = None, 40 | device="cuda", 41 | **kwargs, 42 | ): 43 | 44 | super(RealGeneration, self).__init__() 45 | 46 | if RealGeneration.pipe is None: 47 | 48 | PipelineClass = StableDiffusionPipeline 49 | 50 | RealGeneration.pipe = PipelineClass.from_pretrained( 51 | model_path, 52 | use_auth_token=True, 53 | revision=revision, 54 | local_files_only=True, 55 | torch_dtype=torch.float16, 56 | ).to(device) 57 | scheduler = DPMSolverMultistepScheduler.from_config( 58 | RealGeneration.pipe.scheduler.config 59 | ) 60 | RealGeneration.pipe.scheduler = scheduler 61 | logging.disable_progress_bar() 62 | self.pipe.set_progress_bar_config(disable=True) 63 | 64 | if disable_safety_checker: 65 | self.pipe.safety_checker = None 66 | 67 | self.prompt = prompt 68 | self.guidance_scale = guidance_scale 69 | self.format_name = format_name 70 | 71 | self.mask = mask 72 | self.inverted = inverted 
73 | self.mask_grow_radius = mask_grow_radius 74 | 75 | def forward( 76 | self, 77 | image: Image.Image, 78 | label: int, 79 | metadata: dict, 80 | strength: float = 0.5, 81 | resolution=512, 82 | ) -> Tuple[Image.Image, int]: 83 | 84 | name = self.format_name(metadata.get("name", "")) 85 | prompt = self.prompt.format(name=name) 86 | 87 | if self.mask: 88 | assert "mask" in metadata, "mask=True but no mask present in metadata" 89 | 90 | # word_name = metadata.get("name", "").replace(" ", "") 91 | 92 | kwargs = dict( 93 | prompt=[prompt], 94 | guidance_scale=self.guidance_scale, 95 | num_inference_steps=25, 96 | num_images_per_prompt=len(image), 97 | height=resolution, 98 | width=resolution, 99 | ) 100 | 101 | if self.mask: # use focal object mask 102 | # TODO 103 | mask_image = metadata["mask"].resize((512, 512), Image.NEAREST) 104 | mask_image = Image.fromarray( 105 | maximum_filter(np.array(mask_image), size=self.mask_grow_radius) 106 | ) 107 | 108 | if self.inverted: 109 | 110 | mask_image = ImageOps.invert(mask_image.convert("L")).convert("1") 111 | 112 | kwargs["mask_image"] = mask_image 113 | 114 | has_nsfw_concept = True 115 | while has_nsfw_concept: 116 | with autocast("cuda"): 117 | outputs = self.pipe(**kwargs) 118 | 119 | has_nsfw_concept = ( 120 | self.pipe.safety_checker is not None 121 | and outputs.nsfw_content_detected[0] 122 | ) 123 | 124 | canvas = [] 125 | for orig, out in zip(image, outputs.images): 126 | canvas.append(out.resize(orig.size, Image.BILINEAR)) 127 | 128 | return canvas, label 129 | 130 | 131 | class RealGuidance(GenerativeAugmentation): 132 | 133 | pipe = None # global sharing is a hack to avoid OOM 134 | 135 | def __init__( 136 | self, 137 | model_path: str = "runwayml/stable-diffusion-v1-5", 138 | prompt: str = "a photo of a {name}", 139 | strength: float = 0.5, 140 | guidance_scale: float = 7.5, 141 | mask: bool = False, 142 | inverted: bool = False, 143 | mask_grow_radius: int = 16, 144 | erasure_ckpt_path: str = None, 145 | disable_safety_checker: bool = True, 146 | **kwargs, 147 | ): 148 | 149 | super(RealGuidance, self).__init__() 150 | 151 | if RealGuidance.pipe is None: 152 | 153 | PipelineClass = ( 154 | StableDiffusionInpaintPipeline 155 | if mask 156 | else StableDiffusionImg2ImgPipeline 157 | ) 158 | 159 | self.pipe = PipelineClass.from_pretrained( 160 | model_path, 161 | use_auth_token=True, 162 | revision="fp16", 163 | torch_dtype=torch.float16, 164 | ).to("cuda") 165 | 166 | logging.disable_progress_bar() 167 | self.pipe.set_progress_bar_config(disable=True) 168 | 169 | if disable_safety_checker: 170 | self.pipe.safety_checker = None 171 | 172 | self.prompt = prompt 173 | self.strength = strength 174 | self.guidance_scale = guidance_scale 175 | 176 | self.mask = mask 177 | self.inverted = inverted 178 | self.mask_grow_radius = mask_grow_radius 179 | 180 | self.erasure_ckpt_path = erasure_ckpt_path 181 | self.erasure_word_name = None 182 | 183 | def forward( 184 | self, image: Image.Image, label: int, metadata: dict 185 | ) -> Tuple[Image.Image, int]: 186 | 187 | canvas = image.resize((512, 512), Image.BILINEAR) 188 | prompt = self.prompt.format(name=metadata.get("name", "")) 189 | 190 | if self.mask: 191 | assert "mask" in metadata, "mask=True but no mask present in metadata" 192 | 193 | word_name = metadata.get("name", "").replace(" ", "") 194 | 195 | if self.erasure_ckpt_path is not None and ( 196 | self.erasure_word_name is None or self.erasure_word_name != word_name 197 | ): 198 | 199 | self.erasure_word_name = word_name 200 | ckpt_name 
= "method_full-sg_3-ng_1-iter_1000-lr_1e-05" 201 | 202 | ckpt_path = os.path.join( 203 | self.erasure_ckpt_path, 204 | f"compvis-word_{word_name}-{ckpt_name}", 205 | f"diffusers-word_{word_name}-{ckpt_name}.pt", 206 | ) 207 | 208 | self.pipe.unet.load_state_dict(torch.load(ckpt_path, map_location="cuda")) 209 | 210 | kwargs = dict( 211 | image=canvas, 212 | prompt=[prompt], 213 | strength=self.strength, 214 | guidance_scale=self.guidance_scale, 215 | ) 216 | 217 | if self.mask: # use focal object mask 218 | 219 | mask_image = Image.fromarray( 220 | (np.where(metadata["mask"], 255, 0)).astype(np.uint8) 221 | ).resize((512, 512), Image.NEAREST) 222 | 223 | mask_image = Image.fromarray( 224 | maximum_filter(np.array(mask_image), size=self.mask_grow_radius) 225 | ) 226 | 227 | if self.inverted: 228 | 229 | mask_image = ImageOps.invert(mask_image.convert("L")).convert("1") 230 | 231 | kwargs["mask_image"] = mask_image 232 | 233 | has_nsfw_concept = True 234 | while has_nsfw_concept: 235 | with autocast("cuda"): 236 | outputs = self.pipe(**kwargs) 237 | 238 | has_nsfw_concept = ( 239 | self.pipe.safety_checker is not None 240 | and outputs.nsfw_content_detected[0] 241 | ) 242 | 243 | canvas = outputs.images[0].resize(image.size, Image.BILINEAR) 244 | 245 | return canvas, label 246 | -------------------------------------------------------------------------------- /augmentation/ti_mix.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image, ImageOps 6 | from scipy.ndimage import maximum_filter 7 | from torch import autocast 8 | from transformers import CLIPTextModel, CLIPTokenizer 9 | 10 | from augmentation.base_augmentation import GenerativeMixup 11 | from diffusers import ( 12 | DPMSolverMultistepScheduler, 13 | StableDiffusionImg2ImgPipeline, 14 | StableDiffusionInpaintPipeline, 15 | ) 16 | from diffusers.utils import logging 17 | 18 | ERROR_MESSAGE = "Tokenizer already contains the token {token}. \ 19 | Please pass a different `token` that is not already in the tokenizer." 
20 | 21 | 22 | def load_embeddings( 23 | embed_path: str, 24 | model_path: str = "runwayml/stable-diffusion-v1-5", 25 | revision: str = "39593d5650112b4cc580433f6b0435385882d819", 26 | ): 27 | 28 | tokenizer = CLIPTokenizer.from_pretrained( 29 | model_path, use_auth_token=True, revision=revision, subfolder="tokenizer" 30 | ) 31 | 32 | text_encoder = CLIPTextModel.from_pretrained( 33 | model_path, use_auth_token=True, revision=revision, subfolder="text_encoder" 34 | ) 35 | 36 | for token, token_embedding in torch.load(embed_path, map_location="cpu").items(): 37 | 38 | # add the token in tokenizer 39 | num_added_tokens = tokenizer.add_tokens(token) 40 | assert num_added_tokens > 0, ERROR_MESSAGE.format(token=token) 41 | 42 | # resize the token embeddings 43 | text_encoder.resize_token_embeddings(len(tokenizer)) 44 | added_token_id = tokenizer.convert_tokens_to_ids(token) 45 | 46 | # get the old word embeddings 47 | embeddings = text_encoder.get_input_embeddings() 48 | 49 | # get the id for the token and assign new embeds 50 | embeddings.weight.data[added_token_id] = token_embedding.to( 51 | embeddings.weight.dtype 52 | ) 53 | 54 | return tokenizer, text_encoder 55 | 56 | 57 | def format_name(name): 58 | return f"<{name.replace(' ', '_')}>" 59 | 60 | 61 | class TextualInversionMixup(GenerativeMixup): 62 | 63 | pipe = None # global sharing is a hack to avoid OOM 64 | 65 | def __init__( 66 | self, 67 | embed_path: str, 68 | model_path: str = "runwayml/stable-diffusion-v1-5", 69 | prompt: str = "a photo of a {name}", 70 | format_name: Callable = format_name, 71 | guidance_scale: float = 7.5, 72 | mask: bool = False, 73 | inverted: bool = False, 74 | mask_grow_radius: int = 16, 75 | disable_safety_checker: bool = True, 76 | revision: str = "39593d5650112b4cc580433f6b0435385882d819", 77 | device="cuda", 78 | **kwargs, 79 | ): 80 | 81 | super().__init__() 82 | 83 | if TextualInversionMixup.pipe is None: 84 | 85 | PipelineClass = ( 86 | StableDiffusionInpaintPipeline 87 | if mask 88 | else StableDiffusionImg2ImgPipeline 89 | ) 90 | 91 | tokenizer, text_encoder = load_embeddings( 92 | embed_path, model_path=model_path, revision=revision 93 | ) 94 | 95 | TextualInversionMixup.pipe = PipelineClass.from_pretrained( 96 | model_path, 97 | use_auth_token=True, 98 | revision=revision, 99 | torch_dtype=torch.float16, 100 | ).to(device) 101 | scheduler = DPMSolverMultistepScheduler.from_config( 102 | TextualInversionMixup.pipe.scheduler.config 103 | ) 104 | TextualInversionMixup.pipe.scheduler = scheduler 105 | self.pipe.tokenizer = tokenizer 106 | self.pipe.text_encoder = text_encoder.to(device) 107 | 108 | logging.disable_progress_bar() 109 | self.pipe.set_progress_bar_config(disable=True) 110 | 111 | if disable_safety_checker: 112 | self.pipe.safety_checker = None 113 | 114 | self.prompt = prompt 115 | self.guidance_scale = guidance_scale 116 | self.format_name = format_name 117 | 118 | self.mask = mask 119 | self.inverted = inverted 120 | self.mask_grow_radius = mask_grow_radius 121 | 122 | self.erasure_word_name = None 123 | 124 | def forward( 125 | self, image: Image.Image, label: int, metadata: dict, strength: float = 0.5 126 | ) -> Tuple[Image.Image, int]: 127 | 128 | canvas = image.resize((512, 512), Image.BILINEAR) 129 | name = self.format_name(metadata.get("name", "")) 130 | prompt = self.prompt.format(name=name) 131 | 132 | if self.mask: 133 | assert "mask" in metadata, "mask=True but no mask present in metadata" 134 | 135 | word_name = metadata.get("name", "").replace(" ", "") 136 | 137 | kwargs 
= dict( 138 | image=canvas, 139 | prompt=[prompt], 140 | strength=strength, 141 | guidance_scale=self.guidance_scale, 142 | ) 143 | 144 | if self.mask: # use focal object mask 145 | # TODO 146 | mask_image = Image.fromarray( 147 | (np.where(metadata["mask"], 255, 0)).astype(np.uint8) 148 | ).resize((512, 512), Image.NEAREST) 149 | 150 | mask_image = Image.fromarray( 151 | maximum_filter(np.array(mask_image), size=self.mask_grow_radius) 152 | ) 153 | 154 | if self.inverted: 155 | 156 | mask_image = ImageOps.invert(mask_image.convert("L")).convert("1") 157 | 158 | kwargs["mask_image"] = mask_image 159 | 160 | has_nsfw_concept = True 161 | while has_nsfw_concept: 162 | with autocast("cuda"): 163 | outputs = self.pipe(**kwargs) 164 | 165 | has_nsfw_concept = ( 166 | self.pipe.safety_checker is not None 167 | and outputs.nsfw_content_detected[0] 168 | ) 169 | 170 | canvas = outputs.images[0].resize(image.size, Image.BILINEAR) 171 | 172 | return canvas, label 173 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .instance.aircraft import * 2 | from .instance.car import * 3 | from .instance.cub import * 4 | from .instance.dog import * 5 | from .instance.flower import * 6 | from .instance.food import * 7 | from .instance.pascal import * 8 | from .instance.pet import * 9 | from .instance.waterbird import * 10 | 11 | DATASET_NAME_MAPPING = { 12 | "cub": CUBBirdHugDataset, 13 | "flower": Flowers102Dataset, 14 | "car": CarHugDataset, 15 | "pet": PetHugDataset, 16 | "aircraft": AircraftHugDataset, 17 | "food": FoodHugDataset, 18 | "pascal": PascalDataset, 19 | "dog": StanfordDogDataset, 20 | } 21 | IMBALANCE_DATASET_NAME_MAPPING = { 22 | "cub": CUBBirdHugImbalanceDataset, 23 | "flower": FlowersImbalanceDataset, 24 | } 25 | T2I_DATASET_NAME_MAPPING = { 26 | "cub": CUBBirdHugDatasetForT2I, 27 | "flower": FlowersDatasetForT2I, 28 | } 29 | T2I_IMBALANCE_DATASET_NAME_MAPPING = { 30 | "cub": CUBBirdHugImbalanceDatasetForT2I, 31 | "flower": FlowersImbalanceDatasetForT2I, 32 | } 33 | -------------------------------------------------------------------------------- /dataset/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import math 3 | import os 4 | import random 5 | from typing import Any, List, Tuple, Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | 14 | 15 | def onehot(size, target): 16 | vec = torch.zeros(size, dtype=torch.float32) 17 | vec[target] = 1.0 18 | return vec 19 | 20 | 21 | class SyntheticDataset(Dataset): 22 | def __init__( 23 | self, 24 | synthetic_dir: Union[str, List[str]] = None, 25 | gamma: int = 1, 26 | soft_scaler: float = 1, 27 | num_syn_seeds: int = 999, 28 | image_size: int = 512, 29 | crop_size: int = 448, 30 | class2label: dict = None, 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.synthetic_dir = synthetic_dir 35 | self.num_syn_seeds = num_syn_seeds # number of seeds to generate synthetic data 36 | self.gamma = gamma 37 | self.soft_scaler = soft_scaler 38 | self.class_names = None 39 | 40 | self.parse_syn_data_pd(synthetic_dir) 41 | 42 | test_transform = transforms.Compose( 43 | [ 44 | transforms.Resize((image_size, image_size)), 45 | transforms.CenterCrop((crop_size, crop_size)), 46 | transforms.ToTensor(), 47 | 
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 48 | ] 49 | ) 50 | 51 | self.transform = test_transform 52 | self.class2label = ( 53 | {name: i for i, name in enumerate(self.class_names)} 54 | if class2label is None 55 | else class2label 56 | ) 57 | self.num_classes = len(self.class2label.keys()) 58 | 59 | def set_transform(self, transform) -> None: 60 | self.transform = transform 61 | 62 | def parse_syn_data_pd(self, synthetic_dir) -> None: 63 | if isinstance(synthetic_dir, list): 64 | pass 65 | elif isinstance(synthetic_dir, str): 66 | synthetic_dir = [synthetic_dir] 67 | else: 68 | raise NotImplementedError("Not supported type") 69 | meta_df_list = [] 70 | 71 | for _dir in synthetic_dir: 72 | meta_dir = os.path.join(_dir, self.csv_file) 73 | meta_df = pd.read_csv(meta_dir) 74 | meta_df.loc[:, "Path"] = meta_df["Path"].apply( 75 | lambda x: os.path.join(_dir, "data", x) 76 | ) 77 | meta_df_list.append(meta_df) 78 | self.meta_df = pd.concat(meta_df_list).reset_index(drop=True) 79 | 80 | self.syn_nums = len(self.meta_df) 81 | self.class_names = list(set(self.meta_df["First Directory"].values)) 82 | print(f"Syn numbers: {self.syn_nums}\n") 83 | 84 | def get_syn_item_raw(self, idx: int): 85 | df_data = self.meta_df.iloc[idx] 86 | src_label = self.class2label[df_data["First Directory"]] 87 | tar_label = self.class2label[df_data["Second Directory"]] 88 | path = df_data["Path"] 89 | return path, src_label, tar_label 90 | 91 | def __len__(self) -> int: 92 | return self.syn_nums 93 | 94 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 95 | path, src_label, target_label = self.get_syn_item_raw(idx) 96 | image = Image.open(path).convert("RGB") 97 | return { 98 | "pixel_values": self.transform(image), 99 | "src_label": src_label, 100 | "tar_label": target_label, 101 | } 102 | 103 | 104 | class HugFewShotDataset(Dataset): 105 | 106 | num_classes: int = None 107 | class_names: int = None 108 | class2label: dict = None 109 | label2class: dict = None 110 | 111 | def __init__( 112 | self, 113 | split: str = "train", 114 | examples_per_class: int = None, 115 | synthetic_probability: float = 0.5, 116 | return_onehot: bool = False, 117 | soft_scaler: float = 1, 118 | synthetic_dir: Union[str, List[str]] = None, 119 | image_size: int = 512, 120 | crop_size: int = 448, 121 | gamma: int = 1, 122 | num_syn_seeds: int = 99999, 123 | clip_filtered_syn: bool = False, 124 | target_class_num: int = None, 125 | **kwargs, 126 | ): 127 | 128 | self.examples_per_class = examples_per_class 129 | self.num_syn_seeds = num_syn_seeds # number of seeds to generate synthetic data 130 | 131 | self.synthetic_dir = synthetic_dir 132 | self.clip_filtered_syn = clip_filtered_syn 133 | self.return_onehot = return_onehot 134 | 135 | if self.synthetic_dir is not None: 136 | assert self.return_onehot == True 137 | self.synthetic_probability = synthetic_probability 138 | self.soft_scaler = soft_scaler 139 | self.gamma = gamma 140 | self.target_class_num = target_class_num 141 | self.parse_syn_data_pd(synthetic_dir) 142 | 143 | train_transform = transforms.Compose( 144 | [ 145 | transforms.Resize((image_size, image_size)), 146 | transforms.RandomCrop(crop_size, padding=8), 147 | transforms.RandomHorizontalFlip(), 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 150 | ] 151 | ) 152 | test_transform = transforms.Compose( 153 | [ 154 | transforms.Resize((image_size, image_size)), 155 | transforms.CenterCrop((crop_size, crop_size)), 156 | 
transforms.ToTensor(), 157 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 158 | ] 159 | ) 160 | self.transform = {"train": train_transform, "val": test_transform}[split] 161 | 162 | def set_transform(self, transform) -> None: 163 | self.transform = transform 164 | 165 | @abc.abstractmethod 166 | def get_image_by_idx(self, idx: int) -> Image.Image: 167 | 168 | return NotImplemented 169 | 170 | @abc.abstractmethod 171 | def get_label_by_idx(self, idx: int) -> int: 172 | 173 | return NotImplemented 174 | 175 | @abc.abstractmethod 176 | def get_metadata_by_idx(self, idx: int) -> dict: 177 | 178 | return NotImplemented 179 | 180 | def parse_syn_data_pd(self, synthetic_dir, filter=True) -> None: 181 | if isinstance(synthetic_dir, list): 182 | pass 183 | elif isinstance(synthetic_dir, str): 184 | synthetic_dir = [synthetic_dir] 185 | else: 186 | raise NotImplementedError("Not supported type") 187 | meta_df_list = [] 188 | for _dir in synthetic_dir: 189 | df_basename = ( 190 | "meta.csv" if not self.clip_filtered_syn else "remained_meta.csv" 191 | ) 192 | meta_dir = os.path.join(_dir, df_basename) 193 | meta_df = self.filter_df(pd.read_csv(meta_dir)) 194 | meta_df.loc[:, "Path"] = meta_df["Path"].apply( 195 | lambda x: os.path.join(_dir, "data", x) 196 | ) 197 | meta_df_list.append(meta_df) 198 | self.meta_df = pd.concat(meta_df_list).reset_index(drop=True) 199 | self.syn_nums = len(self.meta_df) 200 | 201 | print(f"Syn numbers: {self.syn_nums}\n") 202 | 203 | def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: 204 | 205 | if self.target_class_num is not None: 206 | selected_indexs = [] 207 | for source_name in self.class_names: 208 | target_classes = random.sample(self.class_names, self.target_class_num) 209 | indexs = df[ 210 | (df["First Directory"] == source_name) 211 | & (df["Second Directory"].isin(target_classes)) 212 | ] 213 | selected_indexs.append(indexs) 214 | 215 | meta2 = pd.concat(selected_indexs, axis=0) 216 | total_num = min(len(meta2), 18000) 217 | idxs = random.sample(range(len(meta2)), total_num) 218 | meta2 = meta2.iloc[idxs] 219 | meta2.reset_index(drop=True, inplace=True) 220 | df = meta2 221 | print("filter_df", self.target_class_num, len(df)) 222 | return df 223 | 224 | def get_syn_item(self, idx: int): 225 | 226 | df_data = self.meta_df.iloc[idx] 227 | src_label = self.class2label[df_data["First Directory"]] 228 | tar_label = self.class2label[df_data["Second Directory"]] 229 | path = df_data["Path"] 230 | strength = df_data["Strength"] 231 | onehot_label = torch.zeros(self.num_classes) 232 | onehot_label[src_label] += self.soft_scaler * ( 233 | 1 - math.pow(strength, self.gamma) 234 | ) 235 | onehot_label[tar_label] += self.soft_scaler * math.pow(strength, self.gamma) 236 | 237 | return path, onehot_label 238 | 239 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 240 | 241 | if ( 242 | self.synthetic_dir is not None 243 | and np.random.uniform() < self.synthetic_probability 244 | ): 245 | syn_idx = np.random.choice(self.syn_nums) 246 | path, label = self.get_syn_item(syn_idx) 247 | image = Image.open(path).convert("RGB") 248 | else: 249 | image = self.get_image_by_idx(idx) 250 | label = self.get_label_by_idx(idx) 251 | 252 | if self.return_onehot: 253 | if isinstance(label, (int, np.int64)): 254 | label = onehot(self.num_classes, label) 255 | return dict(pixel_values=self.transform(image), label=label) 256 | -------------------------------------------------------------------------------- /dataset/instance/aircraft.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from typing import Any, Dict, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | from datasets import load_dataset 9 | from PIL import Image 10 | 11 | from dataset.base import HugFewShotDataset 12 | from dataset.template import IMAGENET_TEMPLATES_TINY 13 | 14 | SUPER_CLASS_NAME = "aircraft" 15 | HUG_IMAGE_TRAIN_DIR = r"Multimodal-Fatima/FGVC_Aircraft_train" 16 | HUG_IMAGE_TEST_DIR = r"Multimodal-Fatima/FGVC_Aircraft_test" 17 | HUG_LOCAL_IMAGE_TRAIN_DIR = "/home/zhicai/.cache/huggingface/datasets/Multimodal-Fatima___parquet/Multimodal-Fatima--FGVC_Aircraft_train-d13c51225819e71f" 18 | HUG_LOCAL_IMAGE_TEST_DIR = "/home/zhicai/.cache/huggingface/datasets/Multimodal-Fatima___parquet/Multimodal-Fatima--FGVC_Aircraft_test-2d1cae4ba1777be1" 19 | 20 | 21 | class AircraftHugDataset(HugFewShotDataset): 22 | 23 | super_class_name = SUPER_CLASS_NAME 24 | 25 | def __init__( 26 | self, 27 | *args, 28 | split: str = "train", 29 | seed: int = 0, 30 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 31 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR, 32 | examples_per_class: int = -1, 33 | synthetic_probability: float = 0.5, 34 | return_onehot: bool = False, 35 | soft_scaler: float = 0.9, 36 | synthetic_dir: str = None, 37 | image_size: int = 512, 38 | crop_size: int = 448, 39 | **kwargs, 40 | ): 41 | 42 | super().__init__( 43 | *args, 44 | split=split, 45 | examples_per_class=examples_per_class, 46 | synthetic_probability=synthetic_probability, 47 | return_onehot=return_onehot, 48 | soft_scaler=soft_scaler, 49 | synthetic_dir=synthetic_dir, 50 | image_size=image_size, 51 | crop_size=crop_size, 52 | **kwargs, 53 | ) 54 | 55 | if split == "train": 56 | dataset = load_dataset(HUG_LOCAL_IMAGE_TRAIN_DIR, split="train") 57 | else: 58 | dataset = load_dataset(HUG_LOCAL_IMAGE_TEST_DIR, split="test") 59 | 60 | # self.class_names = [name.replace('/',' ') for name in dataset.features['label'].names] 61 | 62 | random.seed(seed) 63 | np.random.seed(seed) 64 | if examples_per_class is not None and examples_per_class > 0: 65 | all_labels = dataset["label"] 66 | label_to_indices = defaultdict(list) 67 | for i, label in enumerate(all_labels): 68 | label_to_indices[label].append(i) 69 | 70 | _all_indices = [] 71 | for key, items in label_to_indices.items(): 72 | try: 73 | sampled_indices = random.sample(items, examples_per_class) 74 | except ValueError: 75 | print( 76 | f"{key}: Sample larger than population or is negative, use random.choices instead" 77 | ) 78 | sampled_indices = random.choices(items, k=examples_per_class) 79 | 80 | label_to_indices[key] = sampled_indices 81 | _all_indices.extend(sampled_indices) 82 | dataset = dataset.select(_all_indices) 83 | 84 | self.dataset = dataset 85 | class2label = self.dataset.features["label"]._str2int 86 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 87 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 88 | self.class_names = [ 89 | name.replace("/", " ") for name in dataset.features["label"].names 90 | ] 91 | self.num_classes = len(self.class_names) 92 | 93 | self.label_to_indices = defaultdict(list) 94 | for i, label in enumerate(self.dataset["label"]): 95 | self.label_to_indices[label].append(i) 96 | 97 | def __len__(self): 98 | 99 | return len(self.dataset) 100 | 101 | def get_image_by_idx(self, idx: int) -> Image.Image: 
102 | 103 | return self.dataset[idx]["image"].convert("RGB") 104 | 105 | def get_label_by_idx(self, idx: int) -> int: 106 | 107 | return self.dataset[idx]["label"] 108 | 109 | def get_metadata_by_idx(self, idx: int) -> dict: 110 | 111 | return dict( 112 | name=self.label2class[self.get_label_by_idx(idx)], 113 | super_class=self.super_class_name, 114 | ) 115 | 116 | 117 | class AircraftHugDatasetForT2I(torch.utils.data.Dataset): 118 | 119 | super_class_name = SUPER_CLASS_NAME 120 | 121 | def __init__( 122 | self, 123 | *args, 124 | split: str = "train", 125 | seed: int = 0, 126 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 127 | max_train_samples: int = -1, 128 | class_prompts_ratio: float = 0.5, 129 | resolution: int = 512, 130 | center_crop: bool = False, 131 | random_flip: bool = False, 132 | use_placeholder: bool = False, 133 | examples_per_class: int = -1, 134 | **kwargs, 135 | ): 136 | 137 | super().__init__() 138 | 139 | dataset = load_dataset( 140 | "/home/zhicai/.cache/huggingface/datasets/Multimodal-Fatima___parquet/Multimodal-Fatima--FGVC_Aircraft_train-d13c51225819e71f", 141 | split="train", 142 | ) 143 | 144 | random.seed(seed) 145 | np.random.seed(seed) 146 | if max_train_samples is not None and max_train_samples > 0: 147 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples)) 148 | if examples_per_class is not None and examples_per_class > 0: 149 | all_labels = dataset["label"] 150 | label_to_indices = defaultdict(list) 151 | for i, label in enumerate(all_labels): 152 | label_to_indices[label].append(i) 153 | 154 | _all_indices = [] 155 | for key, items in label_to_indices.items(): 156 | try: 157 | sampled_indices = random.sample(items, examples_per_class) 158 | except ValueError: 159 | print( 160 | f"{key}: Sample larger than population or is negative, use random.choices instead" 161 | ) 162 | sampled_indices = random.choices(items, k=examples_per_class) 163 | 164 | label_to_indices[key] = sampled_indices 165 | _all_indices.extend(sampled_indices) 166 | dataset = dataset.select(_all_indices) 167 | self.dataset = dataset 168 | class2label = self.dataset.features["label"]._str2int 169 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 170 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 171 | self.class_names = [ 172 | name.replace("/", " ") for name in dataset.features["label"].names 173 | ] 174 | self.num_classes = len(self.class_names) 175 | self.use_placeholder = use_placeholder 176 | self.name2placeholder = {} 177 | self.placeholder2name = {} 178 | self.label_to_indices = defaultdict(list) 179 | for i, label in enumerate(self.dataset["label"]): 180 | self.label_to_indices[label].append(i) 181 | 182 | self.transform = transforms.Compose( 183 | [ 184 | transforms.Resize( 185 | resolution, interpolation=transforms.InterpolationMode.BILINEAR 186 | ), 187 | ( 188 | transforms.CenterCrop(resolution) 189 | if center_crop 190 | else transforms.RandomCrop(resolution) 191 | ), 192 | ( 193 | transforms.RandomHorizontalFlip() 194 | if random_flip 195 | else transforms.Lambda(lambda x: x) 196 | ), 197 | transforms.ToTensor(), 198 | transforms.Normalize([0.5], [0.5]), 199 | ] 200 | ) 201 | 202 | def __len__(self): 203 | 204 | return len(self.dataset) 205 | 206 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 207 | 208 | image = self.get_image_by_idx(idx) 209 | prompt = self.get_prompt_by_idx(idx) 210 | 211 | return dict(pixel_values=self.transform(image), caption=prompt) 212 | 213 | def 
get_image_by_idx(self, idx: int) -> Image.Image: 214 | 215 | return self.dataset[idx]["image"].convert("RGB") 216 | 217 | def get_label_by_idx(self, idx: int) -> int: 218 | 219 | return self.dataset[idx]["label"] 220 | 221 | def get_prompt_by_idx(self, idx: int) -> int: 222 | # randomly choose from class name or description 223 | if self.use_placeholder: 224 | content = ( 225 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]] 226 | + f" {self.super_class_name}" 227 | ) 228 | else: 229 | content = self.label2class[self.dataset[idx]["label"]] 230 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content) 231 | 232 | return prompt 233 | 234 | def get_metadata_by_idx(self, idx: int) -> dict: 235 | 236 | return dict( 237 | name=self.label2class[self.get_label_by_idx(idx)], 238 | super_class=self.super_class_name, 239 | ) 240 | -------------------------------------------------------------------------------- /dataset/instance/car.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from typing import Any, Dict, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | from datasets import load_dataset 9 | from PIL import Image 10 | 11 | from dataset.base import HugFewShotDataset 12 | from dataset.template import IMAGENET_TEMPLATES_TINY 13 | 14 | SUPER_CLASS_NAME = "car" 15 | HUG_LOCAL_IMAGE_TRAIN_DIR = r"Multimodal-Fatima/StanfordCars_train" 16 | HUG_LOCAL_IMAGE_TEST_DIR = r"Multimodal-Fatima/StanfordCars_test" 17 | 18 | 19 | class CarHugDataset(HugFewShotDataset): 20 | 21 | super_class_name = SUPER_CLASS_NAME 22 | 23 | def __init__( 24 | self, 25 | *args, 26 | split: str = "train", 27 | seed: int = 0, 28 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 29 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR, 30 | examples_per_class: int = -1, 31 | synthetic_probability: float = 0.5, 32 | return_onehot: bool = False, 33 | soft_scaler: float = 0.9, 34 | synthetic_dir: str = None, 35 | image_size: int = 512, 36 | crop_size: int = 448, 37 | **kwargs, 38 | ): 39 | 40 | super(CarHugDataset, self).__init__( 41 | *args, 42 | split=split, 43 | examples_per_class=examples_per_class, 44 | synthetic_probability=synthetic_probability, 45 | return_onehot=return_onehot, 46 | soft_scaler=soft_scaler, 47 | synthetic_dir=synthetic_dir, 48 | image_size=image_size, 49 | crop_size=crop_size, 50 | **kwargs, 51 | ) 52 | 53 | if split == "train": 54 | dataset = load_dataset( 55 | "/data/zhicai/cache/huggingface/datasets/Multimodal-Fatima___stanford_cars_train", 56 | split="train", 57 | ) 58 | else: 59 | dataset = load_dataset( 60 | "/data/zhicai/cache/huggingface/datasets/Multimodal-Fatima___stanford_cars_test", 61 | split="test", 62 | ) 63 | 64 | self.class_names = [ 65 | name.replace("/", " ") for name in dataset.features["label"].names 66 | ] 67 | 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | if examples_per_class is not None and examples_per_class > 0: 71 | all_labels = dataset["label"] 72 | label_to_indices = defaultdict(list) 73 | for i, label in enumerate(all_labels): 74 | label_to_indices[label].append(i) 75 | 76 | _all_indices = [] 77 | for key, items in label_to_indices.items(): 78 | try: 79 | sampled_indices = random.sample(items, examples_per_class) 80 | except ValueError: 81 | print( 82 | f"{key}: Sample larger than population or is negative, use random.choices instead" 83 | ) 84 | sampled_indices = random.choices(items, 
k=examples_per_class) 85 | 86 | label_to_indices[key] = sampled_indices 87 | _all_indices.extend(sampled_indices) 88 | dataset = dataset.select(_all_indices) 89 | 90 | self.dataset = dataset 91 | class2label = self.dataset.features["label"]._str2int 92 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 93 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 94 | self.class_names = [ 95 | name.replace("/", " ") for name in dataset.features["label"].names 96 | ] 97 | self.num_classes = len(self.class_names) 98 | 99 | self.label_to_indices = defaultdict(list) 100 | for i, label in enumerate(self.dataset["label"]): 101 | self.label_to_indices[label].append(i) 102 | 103 | def __len__(self): 104 | 105 | return len(self.dataset) 106 | 107 | def get_image_by_idx(self, idx: int) -> Image.Image: 108 | 109 | return self.dataset[idx]["image"].convert("RGB") 110 | 111 | def get_label_by_idx(self, idx: int) -> int: 112 | 113 | return self.dataset[idx]["label"] 114 | 115 | def get_metadata_by_idx(self, idx: int) -> dict: 116 | 117 | return dict( 118 | name=self.label2class[self.get_label_by_idx(idx)], 119 | super_class=self.super_class_name, 120 | ) 121 | 122 | 123 | class CarHugDatasetForT2I(torch.utils.data.Dataset): 124 | 125 | super_class_name = SUPER_CLASS_NAME 126 | 127 | def __init__( 128 | self, 129 | *args, 130 | split: str = "train", 131 | seed: int = 0, 132 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 133 | max_train_samples: int = -1, 134 | class_prompts_ratio: float = 0.5, 135 | resolution: int = 512, 136 | center_crop: bool = False, 137 | random_flip: bool = False, 138 | use_placeholder: bool = False, 139 | examples_per_class: int = -1, 140 | **kwargs, 141 | ): 142 | 143 | super().__init__() 144 | 145 | dataset = load_dataset( 146 | "/data/zhicai/cache/huggingface/datasets/Multimodal-Fatima___stanford_cars_train", 147 | split="train", 148 | ) 149 | 150 | random.seed(seed) 151 | np.random.seed(seed) 152 | if max_train_samples is not None and max_train_samples > 0: 153 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples)) 154 | if examples_per_class is not None and examples_per_class > 0: 155 | all_labels = dataset["label"] 156 | label_to_indices = defaultdict(list) 157 | for i, label in enumerate(all_labels): 158 | label_to_indices[label].append(i) 159 | 160 | _all_indices = [] 161 | for key, items in label_to_indices.items(): 162 | try: 163 | sampled_indices = random.sample(items, examples_per_class) 164 | except ValueError: 165 | print( 166 | f"{key}: Sample larger than population or is negative, use random.choices instead" 167 | ) 168 | sampled_indices = random.choices(items, k=examples_per_class) 169 | 170 | label_to_indices[key] = sampled_indices 171 | _all_indices.extend(sampled_indices) 172 | dataset = dataset.select(_all_indices) 173 | self.dataset = dataset 174 | class2label = self.dataset.features["label"]._str2int 175 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 176 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 177 | self.class_names = [ 178 | name.replace("/", " ") for name in dataset.features["label"].names 179 | ] 180 | self.num_classes = len(self.class_names) 181 | self.use_placeholder = use_placeholder 182 | self.name2placeholder = {} 183 | self.placeholder2name = {} 184 | self.label_to_indices = defaultdict(list) 185 | for i, label in enumerate(self.dataset["label"]): 186 | self.label_to_indices[label].append(i) 187 | 188 | 
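        # The transform below prepares crops for text-to-image fine-tuning:
        # resize the shorter side to `resolution`, take a square center or
        # random crop, optionally flip horizontally, then map pixels from
        # [0, 1] to [-1, 1] via Normalize([0.5], [0.5]).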
self.transform = transforms.Compose( 189 | [ 190 | transforms.Resize( 191 | resolution, interpolation=transforms.InterpolationMode.BILINEAR 192 | ), 193 | ( 194 | transforms.CenterCrop(resolution) 195 | if center_crop 196 | else transforms.RandomCrop(resolution) 197 | ), 198 | ( 199 | transforms.RandomHorizontalFlip() 200 | if random_flip 201 | else transforms.Lambda(lambda x: x) 202 | ), 203 | transforms.ToTensor(), 204 | transforms.Normalize([0.5], [0.5]), 205 | ] 206 | ) 207 | 208 | def __len__(self): 209 | 210 | return len(self.dataset) 211 | 212 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 213 | 214 | image = self.get_image_by_idx(idx) 215 | prompt = self.get_prompt_by_idx(idx) 216 | # label = self.get_label_by_idx(idx) 217 | 218 | return dict(pixel_values=self.transform(image), caption=prompt) 219 | 220 | def get_image_by_idx(self, idx: int) -> Image.Image: 221 | 222 | return self.dataset[idx]["image"].convert("RGB") 223 | 224 | def get_label_by_idx(self, idx: int) -> int: 225 | 226 | return self.dataset[idx]["label"] 227 | 228 | def get_prompt_by_idx(self, idx: int) -> int: 229 | # randomly choose from class name or description 230 | if self.use_placeholder: 231 | content = ( 232 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]] 233 | + f"{self.super_class_name}" 234 | ) 235 | else: 236 | content = self.label2class[self.dataset[idx]["label"]] 237 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content) 238 | 239 | return prompt 240 | 241 | def get_metadata_by_idx(self, idx: int) -> dict: 242 | 243 | return dict(name=self.label2class[self.get_label_by_idx(idx)]) 244 | -------------------------------------------------------------------------------- /dataset/instance/dog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import defaultdict 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | from scipy.io import loadmat 11 | 12 | from dataset.base import HugFewShotDataset 13 | from dataset.template import IMAGENET_TEMPLATES_TINY 14 | 15 | SUPER_CLASS_NAME = "dog" 16 | DEFAULT_IMAGE_DIR = "/data/zhicai/datasets/fgvc_datasets/stanford_dogs/" 17 | 18 | CLASS_NAME = [ 19 | "Chihuahua", 20 | "Japanese spaniel", 21 | "Maltese dog", 22 | "Pekinese", 23 | "Shih-Tzu", 24 | "Blenheim spaniel", 25 | "papillon", 26 | "toy terrier", 27 | "Rhodesian ridgeback", 28 | "Afghan hound", 29 | "basset", 30 | "beagle", 31 | "bloodhound", 32 | "bluetick", 33 | "black-and-tan coonhound", 34 | "Walker hound", 35 | "English foxhound", 36 | "redbone", 37 | "borzoi", 38 | "Irish wolfhound", 39 | "Italian greyhound", 40 | "whippet", 41 | "Ibizan hound", 42 | "Norwegian elkhound", 43 | "otterhound", 44 | "Saluki", 45 | "Scottish deerhound", 46 | "Weimaraner", 47 | "Staffordshire bullterrier", 48 | "American Staffordshire terrier", 49 | "Bedlington terrier", 50 | "Border terrier", 51 | "Kerry blue terrier", 52 | "Irish terrier", 53 | "Norfolk terrier", 54 | "Norwich terrier", 55 | "Yorkshire terrier", 56 | "wire-haired fox terrier", 57 | "Lakeland terrier", 58 | "Sealyham terrier", 59 | "Airedale", 60 | "cairn", 61 | "Australian terrier", 62 | "Dandie Dinmont", 63 | "Boston bull", 64 | "miniature schnauzer", 65 | "giant schnauzer", 66 | "standard schnauzer", 67 | "Scotch terrier", 68 | "Tibetan terrier", 69 | "silky terrier", 70 | "soft-coated wheaten terrier", 71 | "West Highland white 
terrier", 72 | "Lhasa", 73 | "flat-coated retriever", 74 | "curly-coated retriever", 75 | "golden retriever", 76 | "Labrador retriever", 77 | "Chesapeake Bay retriever", 78 | "German short-haired pointer", 79 | "vizsla", 80 | "English setter", 81 | "Irish setter", 82 | "Gordon setter", 83 | "Brittany spaniel", 84 | "clumber", 85 | "English springer", 86 | "Welsh springer spaniel", 87 | "cocker spaniel", 88 | "Sussex spaniel", 89 | "Irish water spaniel", 90 | "kuvasz", 91 | "schipperke", 92 | "groenendael", 93 | "malinois", 94 | "briard", 95 | "kelpie", 96 | "komondor", 97 | "Old English sheepdog", 98 | "Shetland sheepdog", 99 | "collie", 100 | "Border collie", 101 | "Bouvier des Flandres", 102 | "Rottweiler", 103 | "German shepherd", 104 | "Doberman", 105 | "miniature pinscher", 106 | "Greater Swiss Mountain dog", 107 | "Bernese mountain dog", 108 | "Appenzeller", 109 | "EntleBucher", 110 | "boxer", 111 | "bull mastiff", 112 | "Tibetan mastiff", 113 | "French bulldog", 114 | "Great Dane", 115 | "Saint Bernard", 116 | "Eskimo dog", 117 | "malamute", 118 | "Siberian husky", 119 | "affenpinscher", 120 | "basenji", 121 | "pug", 122 | "Leonberg", 123 | "Newfoundland", 124 | "Great Pyrenees", 125 | "Samoyed", 126 | "Pomeranian", 127 | "chow", 128 | "keeshond", 129 | "Brabancon griffon", 130 | "Pembroke", 131 | "Cardigan", 132 | "toy poodle", 133 | "miniature poodle", 134 | "standard poodle", 135 | "Mexican hairless", 136 | "dingo", 137 | "dhole", 138 | "African hunting dog", 139 | ] 140 | 141 | 142 | class StanfordDogDataset(HugFewShotDataset): 143 | 144 | class_names = CLASS_NAME 145 | super_class_name = SUPER_CLASS_NAME 146 | num_classes: int = len(class_names) 147 | 148 | def __init__( 149 | self, 150 | *args, 151 | split: str = "train", 152 | seed: int = 0, 153 | image_dir: str = DEFAULT_IMAGE_DIR, 154 | examples_per_class: int = None, 155 | synthetic_probability: float = 0.5, 156 | return_onehot: bool = False, 157 | soft_scaler: float = 0.9, 158 | synthetic_dir: str = None, 159 | image_size: int = 512, 160 | crop_size: int = 448, 161 | **kwargs, 162 | ): 163 | 164 | super().__init__( 165 | *args, 166 | split=split, 167 | examples_per_class=examples_per_class, 168 | synthetic_probability=synthetic_probability, 169 | return_onehot=return_onehot, 170 | soft_scaler=soft_scaler, 171 | synthetic_dir=synthetic_dir, 172 | image_size=image_size, 173 | crop_size=crop_size, 174 | **kwargs, 175 | ) 176 | 177 | if split == "train": 178 | data_mat = loadmat(os.path.join(image_dir, "train_list.mat")) 179 | else: 180 | data_mat = loadmat(os.path.join(image_dir, "test_list.mat")) 181 | 182 | image_files = [ 183 | os.path.join(image_dir, "Images", i[0][0]) for i in data_mat["file_list"] 184 | ] 185 | imagelabels = data_mat["labels"].squeeze() 186 | class_to_images = defaultdict(list) 187 | 188 | for image_idx, image_path in enumerate(image_files): 189 | class_name = self.class_names[imagelabels[image_idx] - 1] 190 | class_to_images[class_name].append(image_path) 191 | 192 | rng = np.random.default_rng(seed) 193 | class_to_ids = { 194 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names 195 | } 196 | 197 | if examples_per_class is not None and examples_per_class > 0: 198 | class_to_ids = { 199 | key: ids[:examples_per_class] for key, ids in class_to_ids.items() 200 | } 201 | 202 | self.class_to_images = { 203 | key: [class_to_images[key][i] for i in ids] 204 | for key, ids in class_to_ids.items() 205 | } 206 | self.class2label = {key: i for i, key in enumerate(self.class_names)} 207 | 
self.label2class = {v: k for k, v in self.class2label.items()} 208 | self.all_images = sum( 209 | [self.class_to_images[key] for key in self.class_names], [] 210 | ) 211 | self.all_labels = [ 212 | i 213 | for i, key in enumerate(self.class_names) 214 | for _ in self.class_to_images[key] 215 | ] 216 | 217 | self.label_to_indices = defaultdict(list) 218 | for i, label in enumerate(self.all_labels): 219 | self.label_to_indices[label].append(i) 220 | 221 | def __len__(self): 222 | 223 | return len(self.all_images) 224 | 225 | def get_image_by_idx(self, idx: int) -> Image.Image: 226 | 227 | return Image.open(self.all_images[idx]).convert("RGB") 228 | 229 | def get_label_by_idx(self, idx: int) -> int: 230 | 231 | return self.all_labels[idx] 232 | 233 | def get_metadata_by_idx(self, idx: int) -> dict: 234 | 235 | return dict( 236 | name=self.class_names[self.all_labels[idx]], 237 | super_class=self.super_class_name, 238 | ) 239 | 240 | 241 | class StanfordDogDatasetForT2I(torch.utils.data.Dataset): 242 | 243 | class_names = CLASS_NAME 244 | super_class_name = SUPER_CLASS_NAME 245 | 246 | def __init__( 247 | self, 248 | *args, 249 | split: str = "train", 250 | seed: int = 0, 251 | image_dir: str = DEFAULT_IMAGE_DIR, 252 | max_train_samples: int = -1, 253 | class_prompts_ratio: float = 0.5, 254 | resolution: int = 512, 255 | center_crop: bool = False, 256 | random_flip: bool = False, 257 | use_placeholder: bool = False, 258 | examples_per_class: int = -1, 259 | **kwargs, 260 | ): 261 | 262 | super().__init__() 263 | if split == "train": 264 | data_mat = loadmat(os.path.join(image_dir, "train_list.mat")) 265 | else: 266 | data_mat = loadmat(os.path.join(image_dir, "test_list.mat")) 267 | 268 | image_files = [ 269 | os.path.join(image_dir, "Images", i[0][0]) for i in data_mat["file_list"] 270 | ] 271 | imagelabels = data_mat["labels"].squeeze() 272 | class_to_images = defaultdict(list) 273 | random.seed(seed) 274 | np.random.seed(seed) 275 | if max_train_samples is not None and max_train_samples > 0: 276 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples)) 277 | class_to_images = defaultdict(list) 278 | 279 | for image_idx, image_path in enumerate(image_files): 280 | class_name = self.class_names[imagelabels[image_idx] - 1] 281 | class_to_images[class_name].append(image_path) 282 | 283 | rng = np.random.default_rng(seed) 284 | class_to_ids = { 285 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names 286 | } 287 | 288 | # Split 0.9/0.1 as Train/Test subset 289 | class_to_ids = { 290 | key: np.array_split(class_to_ids[key], [int(0.5 * len(class_to_ids[key]))])[ 291 | 0 if split == "train" else 1 292 | ] 293 | for key in self.class_names 294 | } 295 | if examples_per_class is not None and examples_per_class > 0: 296 | class_to_ids = { 297 | key: ids[:examples_per_class] for key, ids in class_to_ids.items() 298 | } 299 | self.class_to_images = { 300 | key: [class_to_images[key][i] for i in ids] 301 | for key, ids in class_to_ids.items() 302 | } 303 | self.all_images = sum( 304 | [self.class_to_images[key] for key in self.class_names], [] 305 | ) 306 | self.class2label = {key: i for i, key in enumerate(self.class_names)} 307 | self.label2class = {v: k for k, v in self.class2label.items()} 308 | self.all_labels = [ 309 | i 310 | for i, key in enumerate(self.class_names) 311 | for _ in self.class_to_images[key] 312 | ] 313 | 314 | self.label_to_indices = defaultdict(list) 315 | for i, label in enumerate(self.all_labels): 316 | 
self.label_to_indices[label].append(i) 317 | 318 | self.num_classes = len(self.class_names) 319 | self.class_prompts_ratio = class_prompts_ratio 320 | self.use_placeholder = use_placeholder 321 | self.name2placeholder = {} 322 | self.placeholder2name = {} 323 | 324 | self.transform = transforms.Compose( 325 | [ 326 | transforms.Resize( 327 | resolution, interpolation=transforms.InterpolationMode.BILINEAR 328 | ), 329 | ( 330 | transforms.CenterCrop(resolution) 331 | if center_crop 332 | else transforms.RandomCrop(resolution) 333 | ), 334 | ( 335 | transforms.RandomHorizontalFlip() 336 | if random_flip 337 | else transforms.Lambda(lambda x: x) 338 | ), 339 | transforms.ToTensor(), 340 | transforms.Normalize([0.5], [0.5]), 341 | ] 342 | ) 343 | 344 | def __len__(self): 345 | 346 | return len(self.all_images) 347 | 348 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 349 | 350 | image = self.get_image_by_idx(idx) 351 | prompt = self.get_prompt_by_idx(idx) 352 | # label = self.get_label_by_idx(idx) 353 | 354 | return dict(pixel_values=self.transform(image), caption=prompt) 355 | 356 | def get_image_by_idx(self, idx: int) -> Image.Image: 357 | 358 | return Image.open(self.all_images[idx]).convert("RGB") 359 | 360 | def get_label_by_idx(self, idx: int) -> int: 361 | 362 | return self.all_labels[idx] 363 | 364 | def get_prompt_by_idx(self, idx: int) -> int: 365 | # randomly choose from class name or description 366 | 367 | if self.use_placeholder: 368 | content = ( 369 | self.name2placeholder[self.label2class[self.get_label_by_idx(idx)]] 370 | + f"{self.super_class_name}" 371 | ) 372 | else: 373 | content = self.label2class[self.get_label_by_idx(idx)] 374 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content) 375 | 376 | return prompt 377 | 378 | def get_metadata_by_idx(self, idx: int) -> dict: 379 | return dict(name=self.class_names[self.all_labels[idx]]) 380 | 381 | 382 | if __name__ == "__main__": 383 | 384 | ds2 = StanfordDogDataset() 385 | # ds = DogDatasetForT2I(return_onehot=True) 386 | print(ds2.class_to_images == ds2.class_to_images) 387 | print("l") 388 | -------------------------------------------------------------------------------- /dataset/instance/food.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from typing import Any, Dict, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | from datasets import load_dataset 9 | from PIL import Image 10 | 11 | from dataset.base import HugFewShotDataset 12 | from dataset.template import IMAGENET_TEMPLATES_TINY 13 | 14 | SUPER_CLASS_NAME = "food" 15 | HUG_LOCAL_IMAGE_TRAIN_DIR = r"Multimodal-Fatima/Food101_train" 16 | HUG_LOCAL_IMAGE_TEST_DIR = r"Multimodal-Fatima/Food101_test" 17 | 18 | 19 | class FoodHugDataset(HugFewShotDataset): 20 | super_class_name = SUPER_CLASS_NAME 21 | 22 | def __init__( 23 | self, 24 | *args, 25 | split: str = "train", 26 | seed: int = 0, 27 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 28 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR, 29 | examples_per_class: int = -1, 30 | synthetic_probability: float = 0.5, 31 | return_onehot: bool = False, 32 | soft_scaler: float = 0.9, 33 | synthetic_dir: str = None, 34 | image_size: int = 512, 35 | crop_size: int = 448, 36 | **kwargs, 37 | ): 38 | 39 | super(FoodHugDataset, self).__init__( 40 | *args, 41 | split=split, 42 | examples_per_class=examples_per_class, 43 | 
synthetic_probability=synthetic_probability, 44 | return_onehot=return_onehot, 45 | soft_scaler=soft_scaler, 46 | synthetic_dir=synthetic_dir, 47 | image_size=image_size, 48 | crop_size=crop_size, 49 | **kwargs, 50 | ) 51 | 52 | if split == "train": 53 | dataset = load_dataset(image_train_dir, split="train") 54 | else: 55 | dataset = load_dataset(image_test_dir, split="test") 56 | 57 | self.class_names = [ 58 | name.replace("/", " ") for name in dataset.features["label"].names 59 | ] 60 | 61 | random.seed(seed) 62 | np.random.seed(seed) 63 | if examples_per_class is not None and examples_per_class > 0: 64 | all_labels = dataset["label"] 65 | label_to_indices = defaultdict(list) 66 | for i, label in enumerate(all_labels): 67 | label_to_indices[label].append(i) 68 | 69 | _all_indices = [] 70 | for key, items in label_to_indices.items(): 71 | try: 72 | sampled_indices = random.sample(items, examples_per_class) 73 | except ValueError: 74 | print( 75 | f"{key}: Sample larger than population or is negative, use random.choices instead" 76 | ) 77 | sampled_indices = random.choices(items, k=examples_per_class) 78 | 79 | label_to_indices[key] = sampled_indices 80 | _all_indices.extend(sampled_indices) 81 | dataset = dataset.select(_all_indices) 82 | 83 | self.dataset = dataset 84 | class2label = self.dataset.features["label"]._str2int 85 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 86 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 87 | self.class_names = [ 88 | name.replace("/", " ") for name in dataset.features["label"].names 89 | ] 90 | self.num_classes = len(self.class_names) 91 | 92 | self.label_to_indices = defaultdict(list) 93 | for i, label in enumerate(self.dataset["label"]): 94 | self.label_to_indices[label].append(i) 95 | 96 | def __len__(self): 97 | 98 | return len(self.dataset) 99 | 100 | def get_image_by_idx(self, idx: int) -> Image.Image: 101 | 102 | return self.dataset[idx]["image"].convert("RGB") 103 | 104 | def get_label_by_idx(self, idx: int) -> int: 105 | 106 | return self.dataset[idx]["label"] 107 | 108 | def get_metadata_by_idx(self, idx: int) -> dict: 109 | 110 | return dict( 111 | name=self.label2class[self.get_label_by_idx(idx)], 112 | super_class=self.super_class_name, 113 | ) 114 | 115 | 116 | class FoodHugDatasetForT2I(torch.utils.data.Dataset): 117 | super_class_name = SUPER_CLASS_NAME 118 | 119 | def __init__( 120 | self, 121 | *args, 122 | split: str = "train", 123 | seed: int = 0, 124 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 125 | max_train_samples: int = -1, 126 | class_prompts_ratio: float = 0.5, 127 | resolution: int = 512, 128 | center_crop: bool = False, 129 | random_flip: bool = False, 130 | use_placeholder: bool = False, 131 | **kwargs, 132 | ): 133 | 134 | super().__init__() 135 | 136 | dataset = load_dataset(image_train_dir, split="train") 137 | 138 | random.seed(seed) 139 | np.random.seed(seed) 140 | if max_train_samples is not None and max_train_samples > 0: 141 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples)) 142 | 143 | self.dataset = dataset 144 | class2label = self.dataset.features["label"]._str2int 145 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 146 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 147 | self.class_names = [ 148 | name.replace("/", " ") for name in dataset.features["label"].names 149 | ] 150 | self.num_classes = len(self.class_names) 151 | self.use_placeholder = use_placeholder 
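        # The placeholder maps below start empty; when use_placeholder is True
        # they are expected to be filled by the caller (e.g. the fine-tuning
        # script) with learned token strings before prompts are requested.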
152 | self.name2placeholder = {} 153 | self.placeholder2name = {} 154 | self.label_to_indices = defaultdict(list) 155 | for i, label in enumerate(self.dataset["label"]): 156 | self.label_to_indices[label].append(i) 157 | 158 | self.transform = transforms.Compose( 159 | [ 160 | transforms.Resize( 161 | resolution, interpolation=transforms.InterpolationMode.BILINEAR 162 | ), 163 | ( 164 | transforms.CenterCrop(resolution) 165 | if center_crop 166 | else transforms.RandomCrop(resolution) 167 | ), 168 | ( 169 | transforms.RandomHorizontalFlip() 170 | if random_flip 171 | else transforms.Lambda(lambda x: x) 172 | ), 173 | transforms.ToTensor(), 174 | transforms.Normalize([0.5], [0.5]), 175 | ] 176 | ) 177 | 178 | def __len__(self): 179 | 180 | return len(self.dataset) 181 | 182 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 183 | 184 | image = self.get_image_by_idx(idx) 185 | prompt = self.get_prompt_by_idx(idx) 186 | # label = self.get_label_by_idx(idx) 187 | 188 | return dict(pixel_values=self.transform(image), caption=prompt) 189 | 190 | def get_image_by_idx(self, idx: int) -> Image.Image: 191 | 192 | return self.dataset[idx]["image"].convert("RGB") 193 | 194 | def get_label_by_idx(self, idx: int) -> int: 195 | 196 | return self.dataset[idx]["label"] 197 | 198 | def get_prompt_by_idx(self, idx: int) -> int: 199 | # randomly choose from class name or description 200 | if self.use_placeholder: 201 | content = ( 202 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]] 203 | + f" {self.super_class_name}" 204 | ) 205 | else: 206 | content = self.label2class[self.dataset[idx]["label"]] 207 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content) 208 | 209 | return prompt 210 | 211 | def get_metadata_by_idx(self, idx: int) -> dict: 212 | 213 | return dict(name=self.label2class[self.get_label_by_idx(idx)]) 214 | 215 | 216 | if __name__ == "__main__": 217 | ds_train = FoodHugDataset() 218 | print(ds_train[0]) 219 | -------------------------------------------------------------------------------- /dataset/instance/pascal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import defaultdict 4 | from typing import Any, Dict, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | from datasets import load_dataset 10 | from PIL import Image 11 | 12 | from dataset.base import HugFewShotDataset 13 | from dataset.template import IMAGENET_TEMPLATES_TINY 14 | 15 | PASCAL_DIR = "/data/zhicai/datasets/VOCdevkit/VOC2012/" 16 | 17 | TRAIN_IMAGE_SET = os.path.join(PASCAL_DIR, "ImageSets/Segmentation/train.txt") 18 | VAL_IMAGE_SET = os.path.join(PASCAL_DIR, "ImageSets/Segmentation/val.txt") 19 | DEFAULT_IMAGE_DIR = os.path.join(PASCAL_DIR, "JPEGImages") 20 | DEFAULT_LABEL_DIR = os.path.join(PASCAL_DIR, "SegmentationClass") 21 | DEFAULT_INSTANCE_DIR = os.path.join(PASCAL_DIR, "SegmentationObject") 22 | 23 | SUPER_CLASS_NAME = "" 24 | CLASS_NAME = [ 25 | "airplane", 26 | "bicycle", 27 | "bird", 28 | "boat", 29 | "bottle", 30 | "bus", 31 | "car", 32 | "cat", 33 | "chair", 34 | "cow", 35 | "dining table", 36 | "dog", 37 | "horse", 38 | "motorcycle", 39 | "person", 40 | "potted plant", 41 | "sheep", 42 | "sofa", 43 | "train", 44 | "television", 45 | ] 46 | 47 | 48 | class PascalDataset(HugFewShotDataset): 49 | 50 | class_names = CLASS_NAME 51 | num_classes: int = len(class_names) 52 | super_class_name = SUPER_CLASS_NAME 53 | 54 | def __init__( 55 | 
self, 56 | *args, 57 | split: str = "train", 58 | seed: int = 0, 59 | image_dir: str = DEFAULT_IMAGE_DIR, 60 | examples_per_class: int = None, 61 | synthetic_probability: float = 0.5, 62 | return_onehot: bool = False, 63 | soft_scaler: float = 0.9, 64 | synthetic_dir: str = None, 65 | image_size: int = 512, 66 | crop_size: int = 448, 67 | **kwargs, 68 | ): 69 | 70 | super().__init__( 71 | *args, 72 | split=split, 73 | examples_per_class=examples_per_class, 74 | synthetic_probability=synthetic_probability, 75 | return_onehot=return_onehot, 76 | soft_scaler=soft_scaler, 77 | synthetic_dir=synthetic_dir, 78 | image_size=image_size, 79 | crop_size=crop_size, 80 | **kwargs, 81 | ) 82 | 83 | image_set = {"train": TRAIN_IMAGE_SET, "val": VAL_IMAGE_SET}[split] 84 | 85 | with open(image_set, "r") as f: 86 | image_set_lines = [x.strip() for x in f.readlines()] 87 | 88 | class_to_images = defaultdict(list) 89 | class_to_annotations = defaultdict(list) 90 | 91 | for image_id in image_set_lines: 92 | 93 | labels = os.path.join(DEFAULT_LABEL_DIR, image_id + ".png") 94 | instances = os.path.join(DEFAULT_INSTANCE_DIR, image_id + ".png") 95 | 96 | labels = np.asarray(Image.open(labels)) 97 | instances = np.asarray(Image.open(instances)) 98 | 99 | instance_ids, pixel_loc, counts = np.unique( 100 | instances, return_index=True, return_counts=True 101 | ) 102 | 103 | counts[0] = counts[-1] = 0 # remove background 104 | 105 | argmax_index = counts.argmax() 106 | 107 | mask = np.equal(instances, instance_ids[argmax_index]) 108 | class_name = self.class_names[labels.flat[pixel_loc[argmax_index]] - 1] 109 | 110 | class_to_images[class_name].append( 111 | os.path.join(image_dir, image_id + ".jpg") 112 | ) 113 | class_to_annotations[class_name].append(dict(mask=mask)) 114 | 115 | rng = np.random.default_rng(seed) 116 | class_to_ids = { 117 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names 118 | } 119 | 120 | if examples_per_class is not None and examples_per_class > 0: 121 | class_to_ids = { 122 | key: ids[:examples_per_class] for key, ids in class_to_ids.items() 123 | } 124 | 125 | self.class_to_images = { 126 | key: [class_to_images[key][i] for i in ids] 127 | for key, ids in class_to_ids.items() 128 | } 129 | 130 | self.class_to_annotations = { 131 | key: [class_to_annotations[key][i] for i in ids] 132 | for key, ids in class_to_ids.items() 133 | } 134 | 135 | self.class2label = {key: i for i, key in enumerate(self.class_names)} 136 | self.label2class = {v: k for k, v in self.class2label.items()} 137 | self.all_images = sum( 138 | [self.class_to_images[key] for key in self.class_names], [] 139 | ) 140 | self.all_labels = [ 141 | i 142 | for i, key in enumerate(self.class_names) 143 | for _ in self.class_to_images[key] 144 | ] 145 | 146 | self.label_to_indices = defaultdict(list) 147 | for i, label in enumerate(self.all_labels): 148 | self.label_to_indices[label].append(i) 149 | 150 | def __len__(self): 151 | 152 | return len(self.all_images) 153 | 154 | def get_image_by_idx(self, idx: int) -> Image.Image: 155 | 156 | return Image.open(self.all_images[idx]).convert("RGB") 157 | 158 | def get_label_by_idx(self, idx: int) -> int: 159 | 160 | return self.all_labels[idx] 161 | 162 | def get_metadata_by_idx(self, idx: int) -> dict: 163 | 164 | return dict( 165 | name=self.class_names[self.all_labels[idx]], 166 | super_class=self.super_class_name, 167 | ) 168 | 169 | 170 | class PascalDatasetForT2I(torch.utils.data.Dataset): 171 | super_class_name = SUPER_CLASS_NAME 172 | class_names = 
CLASS_NAME 173 | 174 | def __init__( 175 | self, 176 | *args, 177 | split: str = "train", 178 | seed: int = 0, 179 | image_dir: str = DEFAULT_IMAGE_DIR, 180 | max_train_samples: int = -1, 181 | class_prompts_ratio: float = 0.5, 182 | resolution: int = 512, 183 | center_crop: bool = False, 184 | random_flip: bool = False, 185 | use_placeholder: bool = False, 186 | examples_per_class: int = -1, 187 | **kwargs, 188 | ): 189 | 190 | super().__init__() 191 | image_set = {"train": TRAIN_IMAGE_SET, "val": VAL_IMAGE_SET}[split] 192 | 193 | with open(image_set, "r") as f: 194 | image_set_lines = [x.strip() for x in f.readlines()] 195 | 196 | class_to_images = defaultdict(list) 197 | class_to_annotations = defaultdict(list) 198 | 199 | for image_id in image_set_lines: 200 | 201 | labels = os.path.join(DEFAULT_LABEL_DIR, image_id + ".png") 202 | instances = os.path.join(DEFAULT_INSTANCE_DIR, image_id + ".png") 203 | 204 | labels = np.asarray(Image.open(labels)) 205 | instances = np.asarray(Image.open(instances)) 206 | 207 | instance_ids, pixel_loc, counts = np.unique( 208 | instances, return_index=True, return_counts=True 209 | ) 210 | 211 | counts[0] = counts[-1] = 0 # remove background 212 | 213 | argmax_index = counts.argmax() 214 | 215 | mask = np.equal(instances, instance_ids[argmax_index]) 216 | class_name = self.class_names[labels.flat[pixel_loc[argmax_index]] - 1] 217 | 218 | class_to_images[class_name].append( 219 | os.path.join(image_dir, image_id + ".jpg") 220 | ) 221 | class_to_annotations[class_name].append(dict(mask=mask)) 222 | 223 | rng = np.random.default_rng(seed) 224 | class_to_ids = { 225 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names 226 | } 227 | 228 | if examples_per_class is not None and examples_per_class > 0: 229 | class_to_ids = { 230 | key: ids[:examples_per_class] for key, ids in class_to_ids.items() 231 | } 232 | 233 | self.class_to_images = { 234 | key: [class_to_images[key][i] for i in ids] 235 | for key, ids in class_to_ids.items() 236 | } 237 | 238 | self.class_to_annotations = { 239 | key: [class_to_annotations[key][i] for i in ids] 240 | for key, ids in class_to_ids.items() 241 | } 242 | 243 | self.class2label = {key: i for i, key in enumerate(self.class_names)} 244 | self.label2class = {v: k for k, v in self.class2label.items()} 245 | self.all_images = sum( 246 | [self.class_to_images[key] for key in self.class_names], [] 247 | ) 248 | self.all_labels = [ 249 | i 250 | for i, key in enumerate(self.class_names) 251 | for _ in self.class_to_images[key] 252 | ] 253 | 254 | self.label_to_indices = defaultdict(list) 255 | for i, label in enumerate(self.all_labels): 256 | self.label_to_indices[label].append(i) 257 | self.num_classes = len(self.class_names) 258 | self.class_prompts_ratio = class_prompts_ratio 259 | self.use_placeholder = use_placeholder 260 | self.name2placeholder = {} 261 | self.placeholder2name = {} 262 | 263 | self.transform = transforms.Compose( 264 | [ 265 | transforms.Resize( 266 | resolution, interpolation=transforms.InterpolationMode.BILINEAR 267 | ), 268 | ( 269 | transforms.CenterCrop(resolution) 270 | if center_crop 271 | else transforms.RandomCrop(resolution) 272 | ), 273 | ( 274 | transforms.RandomHorizontalFlip() 275 | if random_flip 276 | else transforms.Lambda(lambda x: x) 277 | ), 278 | transforms.ToTensor(), 279 | transforms.Normalize([0.5], [0.5]), 280 | ] 281 | ) 282 | 283 | def __len__(self): 284 | 285 | return len(self.all_images) 286 | 287 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 
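        # Despite the Tuple annotation, this returns a dict with the normalized
        # image tensor ('pixel_values') and a templated text prompt ('caption').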
288 | 289 | image = self.get_image_by_idx(idx) 290 | prompt = self.get_prompt_by_idx(idx) 291 | # label = self.get_label_by_idx(idx) 292 | 293 | return dict(pixel_values=self.transform(image), caption=prompt) 294 | 295 | def get_image_by_idx(self, idx: int) -> Image.Image: 296 | 297 | return Image.open(self.all_images[idx]).convert("RGB") 298 | 299 | def get_label_by_idx(self, idx: int) -> int: 300 | 301 | return self.all_labels[idx] 302 | 303 | def get_prompt_by_idx(self, idx: int) -> int: 304 | # randomly choose from class name or description 305 | 306 | if self.use_placeholder: 307 | content = ( 308 | self.name2placeholder[self.label2class[self.get_label_by_idx(idx)]] 309 | + f"{self.super_class_name}" 310 | ) 311 | else: 312 | content = self.label2class[self.get_label_by_idx(idx)] 313 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content) 314 | 315 | return prompt 316 | 317 | def get_metadata_by_idx(self, idx: int) -> dict: 318 | return dict(name=self.class_names[self.all_labels[idx]]) 319 | 320 | 321 | if __name__ == "__main__": 322 | ds = PascalDataset(return_onehot=True) 323 | ds2 = PascalDatasetForT2I() 324 | print(ds.class_to_images == ds2.class_to_images) 325 | print("l") 326 | -------------------------------------------------------------------------------- /dataset/instance/pet.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | from datasets import load_dataset 9 | from PIL import Image 10 | 11 | from dataset.base import HugFewShotDataset 12 | from dataset.template import IMAGENET_TEMPLATES_TINY 13 | 14 | SUPER_CLASS_NAME = "animal" 15 | HUG_LOCAL_IMAGE_TRAIN_DIR = r"jonathancui/oxford-pets" 16 | HUG_LOCAL_IMAGE_TEST_DIR = r"jonathancui/oxford-pets" 17 | DEFAULT_IMAGE_LOCAL_DIR = r"/home/zhicai/.cache/huggingface/local/pet" 18 | HUB_LOCAL_DIR = "/home/zhicai/.cache/huggingface/datasets/pcuenq___oxford-pets/default-7fcafd63f4da1c6c" 19 | 20 | 21 | class PetHugDataset(HugFewShotDataset): 22 | 23 | super_class_name = SUPER_CLASS_NAME 24 | 25 | def __init__( 26 | self, 27 | *args, 28 | split: str = "train", 29 | seed: int = 0, 30 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 31 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR, 32 | examples_per_class: int = -1, 33 | synthetic_probability: float = 0.5, 34 | return_onehot: bool = False, 35 | soft_scaler: float = 0.9, 36 | synthetic_dir: str = None, 37 | image_size: int = 512, 38 | crop_size: int = 448, 39 | **kwargs, 40 | ): 41 | 42 | super(PetHugDataset, self).__init__( 43 | *args, 44 | split=split, 45 | examples_per_class=examples_per_class, 46 | synthetic_probability=synthetic_probability, 47 | return_onehot=return_onehot, 48 | soft_scaler=soft_scaler, 49 | synthetic_dir=synthetic_dir, 50 | image_size=image_size, 51 | crop_size=crop_size, 52 | **kwargs, 53 | ) 54 | 55 | if split == "train": 56 | dataset = load_dataset(HUB_LOCAL_DIR, split="train") 57 | else: 58 | dataset = load_dataset(HUB_LOCAL_DIR, split="test") 59 | 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | if examples_per_class is not None and examples_per_class > 0: 63 | all_labels = dataset["label"] 64 | label_to_indices = defaultdict(list) 65 | for i, label in enumerate(all_labels): 66 | label_to_indices[label].append(i) 67 | 68 | _all_indices = [] 69 | for key, items in label_to_indices.items(): 70 | try: 71 | sampled_indices = 
random.sample(items, examples_per_class) 72 | except ValueError: 73 | print( 74 | f"{key}: Sample larger than population or is negative, use random.choices instead" 75 | ) 76 | sampled_indices = random.choices(items, k=examples_per_class) 77 | 78 | label_to_indices[key] = sampled_indices 79 | _all_indices.extend(sampled_indices) 80 | dataset = dataset.select(_all_indices) 81 | 82 | self.dataset = dataset 83 | class2label = self.dataset.features["label"]._str2int 84 | self.class2label = { 85 | k.replace("/", " ").replace("_", " "): v for k, v in class2label.items() 86 | } 87 | self.label2class = { 88 | v: k.replace("/", " ").replace("_", " ") for k, v in class2label.items() 89 | } 90 | self.class_names = [ 91 | name.replace("/", " ").replace("_", " ") 92 | for name in dataset.features["label"].names 93 | ] 94 | self.num_classes = len(self.class_names) 95 | 96 | self.label_to_indices = defaultdict(list) 97 | for i, label in enumerate(self.dataset["label"]): 98 | self.label_to_indices[label].append(i) 99 | 100 | def __len__(self): 101 | 102 | return len(self.dataset) 103 | 104 | def get_image_by_idx(self, idx: int) -> Image.Image: 105 | 106 | return self.dataset[idx]["image"].convert("RGB") 107 | 108 | def get_label_by_idx(self, idx: int) -> int: 109 | 110 | return self.dataset[idx]["label"] 111 | 112 | def get_metadata_by_idx(self, idx: int) -> dict: 113 | 114 | return dict( 115 | name=self.label2class[self.get_label_by_idx(idx)], 116 | super_class=self.super_class_name, 117 | ) 118 | 119 | 120 | class PetHugDatasetForT2I(torch.utils.data.Dataset): 121 | 122 | super_class_name = SUPER_CLASS_NAME 123 | 124 | def __init__( 125 | self, 126 | *args, 127 | split: str = "train", 128 | seed: int = 0, 129 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR, 130 | max_train_samples: int = -1, 131 | class_prompts_ratio: float = 0.5, 132 | resolution: int = 512, 133 | center_crop: bool = False, 134 | random_flip: bool = False, 135 | use_placeholder: bool = False, 136 | **kwargs, 137 | ): 138 | 139 | super().__init__() 140 | 141 | dataset = load_dataset(HUB_LOCAL_DIR, split="train") 142 | dataset = dataset.class_encode_column("label") 143 | 144 | random.seed(seed) 145 | np.random.seed(seed) 146 | if max_train_samples is not None and max_train_samples > 0: 147 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples)) 148 | 149 | self.dataset = dataset 150 | class2label = self.dataset.features["label"]._str2int 151 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()} 152 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()} 153 | self.class_names = [ 154 | name.replace("/", " ") for name in dataset.features["label"].names 155 | ] 156 | self.num_classes = len(self.class_names) 157 | self.use_placeholder = use_placeholder 158 | self.name2placeholder = {} 159 | self.placeholder2name = {} 160 | self.label_to_indices = defaultdict(list) 161 | for i, label in enumerate(self.dataset["label"]): 162 | self.label_to_indices[label].append(i) 163 | 164 | self.transform = transforms.Compose( 165 | [ 166 | transforms.Resize( 167 | resolution, interpolation=transforms.InterpolationMode.BILINEAR 168 | ), 169 | ( 170 | transforms.CenterCrop(resolution) 171 | if center_crop 172 | else transforms.RandomCrop(resolution) 173 | ), 174 | ( 175 | transforms.RandomHorizontalFlip() 176 | if random_flip 177 | else transforms.Lambda(lambda x: x) 178 | ), 179 | transforms.ToTensor(), 180 | transforms.Normalize([0.5], [0.5]), 181 | ] 182 | ) 183 | 184 | def __len__(self): 
185 | 186 | return len(self.dataset) 187 | 188 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 189 | 190 | image = self.get_image_by_idx(idx) 191 | prompt = self.get_prompt_by_idx(idx) 192 | # label = self.get_label_by_idx(idx) 193 | 194 | return dict(pixel_values=self.transform(image), caption=prompt) 195 | 196 | def get_image_by_idx(self, idx: int) -> Image.Image: 197 | 198 | return self.dataset[idx]["image"].convert("RGB") 199 | 200 | def get_label_by_idx(self, idx: int) -> int: 201 | 202 | return self.dataset[idx]["label"] 203 | 204 | def get_prompt_by_idx(self, idx: int) -> int: 205 | # randomly choose from class name or description 206 | if self.use_placeholder: 207 | content = ( 208 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]] 209 | + f" {self.super_class_name}" 210 | ) 211 | else: 212 | content = self.label2class[self.dataset[idx]["label"]] 213 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content) 214 | 215 | return prompt 216 | 217 | def get_metadata_by_idx(self, idx: int) -> dict: 218 | 219 | return dict(name=self.label2class[self.get_label_by_idx(idx)]) 220 | -------------------------------------------------------------------------------- /dataset/instance/waterbird.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | DATA_DIR = "/data/zhicai/datasets/waterbird_complete95_forest2water2/" 11 | 12 | 13 | def onehot(size: int, target: int): 14 | vec = torch.zeros(size, dtype=torch.float32) 15 | vec[target] = 1.0 16 | return vec 17 | 18 | 19 | class WaterBird(Dataset): 20 | def __init__(self, split=2, image_size=256, crop_size=224, return_onehot=False): 21 | self.root_dir = DATA_DIR 22 | dataframe = pd.read_csv(os.path.join(self.root_dir, "metadata.csv")) 23 | dataframe = dataframe[dataframe["split"] == split].reset_index() 24 | self.labels = list( 25 | map(lambda x: int(x.split(".")[0]) - 1, dataframe["img_filename"]) 26 | ) 27 | self.dataframe = dataframe 28 | self.image_paths = dataframe["img_filename"] 29 | self.groups = dataframe.apply( 30 | lambda row: f"{row['y']}{row['place']}", axis=1 31 | ).tolist() 32 | self.return_onehot = return_onehot 33 | self.num_classes = len(set(self.labels)) 34 | self.transform = transforms.Compose( 35 | [ 36 | transforms.Resize((image_size, image_size)), 37 | transforms.CenterCrop((crop_size, crop_size)), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 40 | ] 41 | ) 42 | 43 | def __len__(self): 44 | return len(self.dataframe) 45 | 46 | def __getitem__(self, idx): 47 | 48 | img_path = self.image_paths[idx] 49 | # Load image 50 | img = Image.open(os.path.join(self.root_dir, img_path)).convert("RGB") 51 | 52 | label = self.labels[idx] 53 | group = self.groups[idx] 54 | if self.transform: 55 | img = self.transform(img) 56 | if self.return_onehot: 57 | if isinstance(label, (int, np.int64)): 58 | label = onehot(self.num_classes, label) 59 | return img, label, group 60 | -------------------------------------------------------------------------------- /dataset/template.py: -------------------------------------------------------------------------------- 1 | IMAGENET_TEMPLATES_TINY = [ 2 | "a photo of a {}", 3 | ] 4 | 5 | IMAGENET_TEMPLATES_SMALL = [ 6 | "a photo of a {}", 7 | "a rendering of a {}", 8 | "a cropped photo 
of the {}", 9 | "the photo of a {}", 10 | "a photo of a clean {}", 11 | "a photo of a dirty {}", 12 | "a dark photo of the {}", 13 | "a photo of my {}", 14 | "a photo of the cool {}", 15 | "a close-up photo of a {}", 16 | "a bright photo of the {}", 17 | "a cropped photo of a {}", 18 | "a photo of the {}", 19 | "a good photo of the {}", 20 | "a photo of one {}", 21 | "a close-up photo of the {}", 22 | "a rendition of the {}", 23 | "a photo of the clean {}", 24 | "a rendition of a {}", 25 | "a photo of a nice {}", 26 | "a good photo of a {}", 27 | "a photo of the nice {}", 28 | "a photo of the small {}", 29 | "a photo of the weird {}", 30 | "a photo of the large {}", 31 | "a photo of a cool {}", 32 | "a photo of a small {}", 33 | ] 34 | 35 | IMAGENET_STYLE_TEMPLATES_SMALL = [ 36 | "a painting in the style of {}", 37 | "a rendering in the style of {}", 38 | "a cropped painting in the style of {}", 39 | "the painting in the style of {}", 40 | "a clean painting in the style of {}", 41 | "a dirty painting in the style of {}", 42 | "a dark painting in the style of {}", 43 | "a picture in the style of {}", 44 | "a cool painting in the style of {}", 45 | "a close-up painting in the style of {}", 46 | "a bright painting in the style of {}", 47 | "a cropped painting in the style of {}", 48 | "a good painting in the style of {}", 49 | "a close-up painting in the style of {}", 50 | "a rendition in the style of {}", 51 | "a nice painting in the style of {}", 52 | "a small painting in the style of {}", 53 | "a weird painting in the style of {}", 54 | "a large painting in the style of {}", 55 | ] 56 | -------------------------------------------------------------------------------- /downstream_tasks/imb_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhicaiwww/Diff-Mix/a81337b1492bcc7f7a8a61921836e94191a3a0ef/downstream_tasks/imb_utils/__init__.py -------------------------------------------------------------------------------- /downstream_tasks/imb_utils/autoaug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from: https://github.com/dvlab-research/Parametric-Contrastive-Learning/blob/main/LT/autoaug.py 3 | """ 4 | 5 | from PIL import Image, ImageEnhance, ImageOps 6 | import numpy as np 7 | import random 8 | import torch 9 | 10 | 11 | class Cutout(object): 12 | def __init__(self, n_holes, length): 13 | self.n_holes = n_holes 14 | self.length = length 15 | 16 | def __call__(self, img): 17 | h = img.size(1) 18 | w = img.size(2) 19 | 20 | mask = np.ones((h, w), np.float32) 21 | 22 | for n in range(self.n_holes): 23 | y = np.random.randint(h) 24 | x = np.random.randint(w) 25 | 26 | y1 = np.clip(y - self.length // 2, 0, h) 27 | y2 = np.clip(y + self.length // 2, 0, h) 28 | x1 = np.clip(x - self.length // 2, 0, w) 29 | x2 = np.clip(x + self.length // 2, 0, w) 30 | 31 | mask[y1: y2, x1: x2] = 0. 32 | 33 | mask = torch.from_numpy(mask) 34 | mask = mask.expand_as(img) 35 | img = img * mask 36 | 37 | return img 38 | 39 | class ImageNetPolicy(object): 40 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
41 | Example: 42 | >>> policy = ImageNetPolicy() 43 | >>> transformed = policy(image) 44 | Example as a PyTorch Transform: 45 | >>> transform=transforms.Compose([ 46 | >>> transforms.Resize(256), 47 | >>> ImageNetPolicy(), 48 | >>> transforms.ToTensor()]) 49 | """ 50 | def __init__(self, fillcolor=(128, 128, 128)): 51 | self.policies = [ 52 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 53 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 54 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 55 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 56 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 57 | 58 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 59 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 60 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 61 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 62 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 63 | 64 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 65 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 66 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 67 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 68 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 69 | 70 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 71 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 72 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 73 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 74 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 75 | 76 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 77 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 78 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 79 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 80 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 81 | ] 82 | 83 | 84 | def __call__(self, img): 85 | policy_idx = random.randint(0, len(self.policies) - 1) 86 | return self.policies[policy_idx](img) 87 | 88 | def __repr__(self): 89 | return "AutoAugment ImageNet Policy" 90 | 91 | 92 | class CIFAR10Policy(object): 93 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
94 | Example: 95 | >>> policy = CIFAR10Policy() 96 | >>> transformed = policy(image) 97 | Example as a PyTorch Transform: 98 | >>> transform=transforms.Compose([ 99 | >>> transforms.Resize(256), 100 | >>> CIFAR10Policy(), 101 | >>> transforms.ToTensor()]) 102 | """ 103 | def __init__(self, fillcolor=(128, 128, 128)): 104 | self.policies = [ 105 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 106 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 107 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 108 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 109 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 110 | 111 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 112 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 113 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 114 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 115 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 116 | 117 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 118 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 119 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 120 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 121 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 122 | 123 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 124 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 125 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 126 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 127 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 128 | 129 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 130 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 131 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 132 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 133 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 134 | ] 135 | 136 | 137 | def __call__(self, img): 138 | policy_idx = random.randint(0, len(self.policies) - 1) 139 | return self.policies[policy_idx](img) 140 | 141 | def __repr__(self): 142 | return "AutoAugment CIFAR10 Policy" 143 | 144 | 145 | class SVHNPolicy(object): 146 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
147 | Example: 148 | >>> policy = SVHNPolicy() 149 | >>> transformed = policy(image) 150 | Example as a PyTorch Transform: 151 | >>> transform=transforms.Compose([ 152 | >>> transforms.Resize(256), 153 | >>> SVHNPolicy(), 154 | >>> transforms.ToTensor()]) 155 | """ 156 | def __init__(self, fillcolor=(128, 128, 128)): 157 | self.policies = [ 158 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 159 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 160 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 161 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 162 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 163 | 164 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 165 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 166 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 167 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 168 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 169 | 170 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 171 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 172 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 173 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 174 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 175 | 176 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 177 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 178 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 179 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 180 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 181 | 182 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 183 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 184 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 185 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 186 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 187 | ] 188 | 189 | 190 | def __call__(self, img): 191 | policy_idx = random.randint(0, len(self.policies) - 1) 192 | return self.policies[policy_idx](img) 193 | 194 | def __repr__(self): 195 | return "AutoAugment SVHN Policy" 196 | 197 | 198 | class SubPolicy(object): 199 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 200 | ranges = { 201 | "shearX": np.linspace(0, 0.3, 10), 202 | "shearY": np.linspace(0, 0.3, 10), 203 | "translateX": np.linspace(0, 150 / 331, 10), 204 | "translateY": np.linspace(0, 150 / 331, 10), 205 | "rotate": np.linspace(0, 30, 10), 206 | "color": np.linspace(0.0, 0.9, 10), 207 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 208 | "solarize": np.linspace(256, 0, 10), 209 | "contrast": np.linspace(0.0, 0.9, 10), 210 | "sharpness": np.linspace(0.0, 0.9, 10), 211 | "brightness": np.linspace(0.0, 0.9, 10), 212 | "autocontrast": [0] * 10, 213 | "equalize": [0] * 10, 214 | "invert": [0] * 10 215 | } 216 | 217 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 218 | def rotate_with_fill(img, magnitude): 219 | rot = img.convert("RGBA").rotate(magnitude) 220 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 221 | 222 | func = { 223 | "shearX": lambda img, magnitude: img.transform( 224 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 225 | 
Image.BICUBIC, fillcolor=fillcolor), 226 | "shearY": lambda img, magnitude: img.transform( 227 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 228 | Image.BICUBIC, fillcolor=fillcolor), 229 | "translateX": lambda img, magnitude: img.transform( 230 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 231 | fillcolor=fillcolor), 232 | "translateY": lambda img, magnitude: img.transform( 233 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 234 | fillcolor=fillcolor), 235 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 236 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 237 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 238 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 239 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 240 | 1 + magnitude * random.choice([-1, 1])), 241 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 242 | 1 + magnitude * random.choice([-1, 1])), 243 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 244 | 1 + magnitude * random.choice([-1, 1])), 245 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 246 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 247 | "invert": lambda img, magnitude: ImageOps.invert(img) 248 | } 249 | 250 | self.p1 = p1 251 | self.operation1 = func[operation1] 252 | self.magnitude1 = ranges[operation1][magnitude_idx1] 253 | self.p2 = p2 254 | self.operation2 = func[operation2] 255 | self.magnitude2 = ranges[operation2][magnitude_idx2] 256 | 257 | 258 | def __call__(self, img): 259 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 260 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 261 | return img -------------------------------------------------------------------------------- /downstream_tasks/imb_utils/moco_loader.py: -------------------------------------------------------------------------------- 1 | """https://github.com/facebookresearch/moco""" 2 | 3 | import random 4 | 5 | from PIL import ImageFilter 6 | 7 | 8 | class TwoCropsTransform: 9 | """Take two random crops of one image as the query and key.""" 10 | 11 | def __init__(self, base_transform): 12 | self.base_transform = base_transform 13 | 14 | def __call__(self, x): 15 | q = self.base_transform(x) 16 | k = self.base_transform(x) 17 | return [q, k] 18 | 19 | 20 | class GaussianBlur(object): 21 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 22 | 23 | def __init__(self, sigma=[0.1, 2.0]): 24 | self.sigma = sigma 25 | 26 | def __call__(self, x): 27 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 28 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 29 | return x 30 | -------------------------------------------------------------------------------- /downstream_tasks/imb_utils/randaugment.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/dvlab-research/Parametric-Contrastive-Learning/blob/main/LT/randaugment.py 2 | """ AutoAugment and RandAugment 3 | Implementation adapted from: 4 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 5 | Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and 
https://arxiv.org/abs/1909.13719 6 | Hacked together by Ross Wightman 7 | """ 8 | import random 9 | import math 10 | import re 11 | from PIL import Image, ImageOps, ImageEnhance 12 | import PIL 13 | import numpy as np 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torchvision.transforms as transforms 18 | 19 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) 20 | 21 | _FILL = (128, 128, 128) 22 | 23 | # This signifies the max integer that the controller RNN could predict for the 24 | # augmentation scheme. 25 | _MAX_LEVEL = 10. 26 | 27 | _HPARAMS_DEFAULT = dict( 28 | translate_const=250, 29 | img_mean=_FILL, 30 | ) 31 | 32 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 33 | 34 | 35 | def _interpolation(kwargs): 36 | interpolation = kwargs.pop('resample', Image.BILINEAR) 37 | if isinstance(interpolation, (list, tuple)): 38 | return random.choice(interpolation) 39 | else: 40 | return interpolation 41 | 42 | 43 | def _check_args_tf(kwargs): 44 | if 'fillcolor' in kwargs and _PIL_VER < (5, 0): 45 | kwargs.pop('fillcolor') 46 | kwargs['resample'] = _interpolation(kwargs) 47 | 48 | 49 | def shear_x(img, factor, **kwargs): 50 | _check_args_tf(kwargs) 51 | return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) 52 | 53 | 54 | def shear_y(img, factor, **kwargs): 55 | _check_args_tf(kwargs) 56 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) 57 | 58 | 59 | def translate_x_rel(img, pct, **kwargs): 60 | pixels = pct * img.size[0] 61 | _check_args_tf(kwargs) 62 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 63 | 64 | 65 | def translate_y_rel(img, pct, **kwargs): 66 | pixels = pct * img.size[1] 67 | _check_args_tf(kwargs) 68 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 69 | 70 | 71 | def translate_x_abs(img, pixels, **kwargs): 72 | _check_args_tf(kwargs) 73 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 74 | 75 | 76 | def translate_y_abs(img, pixels, **kwargs): 77 | _check_args_tf(kwargs) 78 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 79 | 80 | 81 | def rotate(img, degrees, **kwargs): 82 | _check_args_tf(kwargs) 83 | if _PIL_VER >= (5, 2): 84 | return img.rotate(degrees, **kwargs) 85 | elif _PIL_VER >= (5, 0): 86 | w, h = img.size 87 | post_trans = (0, 0) 88 | rotn_center = (w / 2.0, h / 2.0) 89 | angle = -math.radians(degrees) 90 | matrix = [ 91 | round(math.cos(angle), 15), 92 | round(math.sin(angle), 15), 93 | 0.0, 94 | round(-math.sin(angle), 15), 95 | round(math.cos(angle), 15), 96 | 0.0, 97 | ] 98 | 99 | def transform(x, y, matrix): 100 | (a, b, c, d, e, f) = matrix 101 | return a * x + b * y + c, d * x + e * y + f 102 | 103 | matrix[2], matrix[5] = transform( 104 | -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix 105 | ) 106 | matrix[2] += rotn_center[0] 107 | matrix[5] += rotn_center[1] 108 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 109 | else: 110 | return img.rotate(degrees, resample=kwargs['resample']) 111 | 112 | 113 | def auto_contrast(img, **__): 114 | return ImageOps.autocontrast(img) 115 | 116 | 117 | def invert(img, **__): 118 | return ImageOps.invert(img) 119 | 120 | 121 | def identity(img, **__): 122 | return img 123 | 124 | 125 | def equalize(img, **__): 126 | return ImageOps.equalize(img) 127 | 128 | 129 | def solarize(img, thresh, **__): 130 | return ImageOps.solarize(img, thresh) 131 | 132 | 
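# A minimal usage sketch of the helpers above, assuming `img` is any RGB PIL.Image;
# `_demo_ops` is a hypothetical name introduced only for illustration.
def _demo_ops(img):
    """Chain a few of the geometric/photometric ops defined above on a PIL image."""
    sheared = shear_x(img, 0.2)     # shear along x by a factor of 0.2
    rotated = rotate(img, 15)       # rotate by 15 degrees, with PIL-version-aware handling
    darkened = solarize(img, 128)   # invert every pixel value above the threshold 128
    return sheared, rotated, darkened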
133 | def solarize_add(img, add, thresh=128, **__): 134 | lut = [] 135 | for i in range(256): 136 | if i < thresh: 137 | lut.append(min(255, i + add)) 138 | else: 139 | lut.append(i) 140 | if img.mode in ("L", "RGB"): 141 | if img.mode == "RGB" and len(lut) == 256: 142 | lut = lut + lut + lut 143 | return img.point(lut) 144 | else: 145 | return img 146 | 147 | 148 | def posterize(img, bits_to_keep, **__): 149 | if bits_to_keep >= 8: 150 | return img 151 | return ImageOps.posterize(img, bits_to_keep) 152 | 153 | 154 | def contrast(img, factor, **__): 155 | return ImageEnhance.Contrast(img).enhance(factor) 156 | 157 | 158 | def color(img, factor, **__): 159 | return ImageEnhance.Color(img).enhance(factor) 160 | 161 | 162 | def brightness(img, factor, **__): 163 | return ImageEnhance.Brightness(img).enhance(factor) 164 | 165 | 166 | def sharpness(img, factor, **__): 167 | return ImageEnhance.Sharpness(img).enhance(factor) 168 | 169 | 170 | def _randomly_negate(v): 171 | """With 50% prob, negate the value""" 172 | return -v if random.random() > 0.5 else v 173 | 174 | 175 | def _rotate_level_to_arg(level, _hparams): 176 | # range [-30, 30] 177 | level = (level / _MAX_LEVEL) * 30. 178 | level = _randomly_negate(level) 179 | return level, 180 | 181 | 182 | def _enhance_level_to_arg(level, _hparams): 183 | # range [0.1, 1.9] 184 | return (level / _MAX_LEVEL) * 1.8 + 0.1, 185 | 186 | 187 | def _shear_level_to_arg(level, _hparams): 188 | # range [-0.3, 0.3] 189 | level = (level / _MAX_LEVEL) * 0.3 190 | level = _randomly_negate(level) 191 | return level, 192 | 193 | 194 | def _translate_abs_level_to_arg(level, hparams): 195 | translate_const = hparams['translate_const'] 196 | level = (level / _MAX_LEVEL) * float(translate_const) 197 | level = _randomly_negate(level) 198 | return level, 199 | 200 | 201 | def _translate_rel_level_to_arg(level, _hparams): 202 | # range [-0.45, 0.45] 203 | level = (level / _MAX_LEVEL) * 0.45 204 | level = _randomly_negate(level) 205 | return level, 206 | 207 | 208 | def _posterize_original_level_to_arg(level, _hparams): 209 | # As per original AutoAugment paper description 210 | # range [4, 8], 'keep 4 up to 8 MSB of image' 211 | return int((level / _MAX_LEVEL) * 4) + 4, 212 | 213 | 214 | def _posterize_research_level_to_arg(level, _hparams): 215 | # As per Tensorflow models research and UDA impl 216 | # range [4, 0], 'keep 4 down to 0 MSB of original image' 217 | return 4 - int((level / _MAX_LEVEL) * 4), 218 | 219 | 220 | def _posterize_tpu_level_to_arg(level, _hparams): 221 | # As per Tensorflow TPU EfficientNet impl 222 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 223 | return int((level / _MAX_LEVEL) * 4), 224 | 225 | 226 | def _solarize_level_to_arg(level, _hparams): 227 | # range [0, 256] 228 | return int((level / _MAX_LEVEL) * 256), 229 | 230 | 231 | def _solarize_add_level_to_arg(level, _hparams): 232 | # range [0, 110] 233 | return int((level / _MAX_LEVEL) * 110), 234 | 235 | 236 | LEVEL_TO_ARG = { 237 | 'AutoContrast': None, 238 | 'Equalize': None, 239 | 'Invert': None, 240 | 'Identity': None, 241 | 'Rotate': _rotate_level_to_arg, 242 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 243 | 'PosterizeOriginal': _posterize_original_level_to_arg, 244 | 'PosterizeResearch': _posterize_research_level_to_arg, 245 | 'PosterizeTpu': _posterize_tpu_level_to_arg, 246 | 'Solarize': _solarize_level_to_arg, 247 | 'SolarizeAdd': _solarize_add_level_to_arg, 248 | 'Color': _enhance_level_to_arg, 249 | 
'Contrast': _enhance_level_to_arg, 250 | 'Brightness': _enhance_level_to_arg, 251 | 'Sharpness': _enhance_level_to_arg, 252 | 'ShearX': _shear_level_to_arg, 253 | 'ShearY': _shear_level_to_arg, 254 | 'TranslateX': _translate_abs_level_to_arg, 255 | 'TranslateY': _translate_abs_level_to_arg, 256 | 'TranslateXRel': _translate_rel_level_to_arg, 257 | 'TranslateYRel': _translate_rel_level_to_arg, 258 | } 259 | 260 | 261 | NAME_TO_OP = { 262 | 'AutoContrast': auto_contrast, 263 | 'Equalize': equalize, 264 | 'Invert': invert, 265 | 'Identity': identity, 266 | 'Rotate': rotate, 267 | 'PosterizeOriginal': posterize, 268 | 'PosterizeResearch': posterize, 269 | 'PosterizeTpu': posterize, 270 | 'Solarize': solarize, 271 | 'SolarizeAdd': solarize_add, 272 | 'Color': color, 273 | 'Contrast': contrast, 274 | 'Brightness': brightness, 275 | 'Sharpness': sharpness, 276 | 'ShearX': shear_x, 277 | 'ShearY': shear_y, 278 | 'TranslateX': translate_x_abs, 279 | 'TranslateY': translate_y_abs, 280 | 'TranslateXRel': translate_x_rel, 281 | 'TranslateYRel': translate_y_rel, 282 | } 283 | 284 | 285 | class AutoAugmentOp: 286 | 287 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 288 | hparams = hparams or _HPARAMS_DEFAULT 289 | self.aug_fn = NAME_TO_OP[name] 290 | self.level_fn = LEVEL_TO_ARG[name] 291 | self.prob = prob 292 | self.magnitude = magnitude 293 | self.hparams = hparams.copy() 294 | self.kwargs = dict( 295 | fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, 296 | resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, 297 | ) 298 | 299 | # If magnitude_std is > 0, we introduce some randomness 300 | # in the usually fixed policy and sample magnitude from a normal distribution 301 | # with mean `magnitude` and std-dev of `magnitude_std`. 302 | # NOTE This is my own hack, being tested, not in papers or reference impls. 303 | self.magnitude_std = self.hparams.get('magnitude_std', 0) 304 | 305 | def __call__(self, img): 306 | if random.random() > self.prob: 307 | return img 308 | magnitude = self.magnitude 309 | if self.magnitude_std and self.magnitude_std > 0: 310 | magnitude = random.gauss(magnitude, self.magnitude_std) 311 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 312 | level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() 313 | return self.aug_fn(img, *level_args, **self.kwargs) 314 | 315 | 316 | _RAND_TRANSFORMS = [ 317 | 'AutoContrast', 318 | 'Equalize', 319 | 'Invert', 320 | 'Rotate', 321 | 'PosterizeTpu', 322 | 'Solarize', 323 | 'SolarizeAdd', 324 | 'Color', 325 | 'Contrast', 326 | 'Brightness', 327 | 'Sharpness', 328 | 'ShearX', 329 | 'ShearY', 330 | 'TranslateXRel', 331 | 'TranslateYRel', 332 | #'Cutout' # FIXME I implement this as random erasing separately 333 | ] 334 | 335 | _RAND_TRANSFORMS_CMC = [ 336 | 'AutoContrast', 337 | 'Identity', 338 | 'Rotate', 339 | 'Sharpness', 340 | 'ShearX', 341 | 'ShearY', 342 | 'TranslateXRel', 343 | 'TranslateYRel', 344 | #'Cutout' # FIXME I implement this as random erasing separately 345 | ] 346 | 347 | 348 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 349 | # They may not result in increased performance, but could likely be tuned to so. 
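# (The values below are relative weights rather than probabilities: _select_rand_weights
# normalizes them, and RandAugment passes the normalized vector as `p` to np.random.choice
# when a weight index is requested via the 'w' key of the config string.)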
350 | _RAND_CHOICE_WEIGHTS_0 = { 351 | 'Rotate': 0.3, 352 | 'ShearX': 0.2, 353 | 'ShearY': 0.2, 354 | 'TranslateXRel': 0.1, 355 | 'TranslateYRel': 0.1, 356 | 'Color': .025, 357 | 'Sharpness': 0.025, 358 | 'AutoContrast': 0.025, 359 | 'Solarize': .005, 360 | 'SolarizeAdd': .005, 361 | 'Contrast': .005, 362 | 'Brightness': .005, 363 | 'Equalize': .005, 364 | 'PosterizeTpu': 0, 365 | 'Invert': 0, 366 | } 367 | 368 | 369 | def _select_rand_weights(weight_idx=0, transforms=None): 370 | transforms = transforms or _RAND_TRANSFORMS 371 | assert weight_idx == 0 # only one set of weights currently 372 | rand_weights = _RAND_CHOICE_WEIGHTS_0 373 | probs = [rand_weights[k] for k in transforms] 374 | probs /= np.sum(probs) 375 | return probs 376 | 377 | 378 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 379 | """rand augment ops for RGB images""" 380 | hparams = hparams or _HPARAMS_DEFAULT 381 | transforms = transforms or _RAND_TRANSFORMS 382 | return [AutoAugmentOp( 383 | name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] 384 | 385 | 386 | def rand_augment_ops_cmc(magnitude=10, hparams=None, transforms=None): 387 | """rand augment ops for CMC images (removing color ops)""" 388 | hparams = hparams or _HPARAMS_DEFAULT 389 | transforms = transforms or _RAND_TRANSFORMS_CMC 390 | return [AutoAugmentOp( 391 | name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] 392 | 393 | 394 | class RandAugment: 395 | def __init__(self, ops, num_layers=2, choice_weights=None): 396 | self.ops = ops 397 | self.num_layers = num_layers 398 | self.choice_weights = choice_weights 399 | 400 | def __call__(self, img): 401 | # no replacement when using weighted choice 402 | ops = np.random.choice( 403 | self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) 404 | for op in ops: 405 | img = op(img) 406 | return img 407 | 408 | 409 | def rand_augment_transform(config_str, hparams, use_cmc=False): 410 | """ 411 | Create a RandAugment transform 412 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by 413 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 414 | sections, not order sepecific determine 415 | 'm' - integer magnitude of rand augment 416 | 'n' - integer num layers (number of transform ops selected per image) 417 | 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 418 | 'mstd' - float std deviation of magnitude noise applied 419 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 420 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 421 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 422 | :param use_cmc: Flag indicates removing augmentation for coloring ops. 
423 | :return: A PyTorch compatible Transform 424 | """ 425 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 426 | num_layers = 2 # default to 2 ops per image 427 | weight_idx = None # default to no probability weights for op choice 428 | config = config_str.split('-') 429 | assert config[0] == 'rand' 430 | config = config[1:] 431 | for c in config: 432 | cs = re.split(r'(\d.*)', c) 433 | if len(cs) < 2: 434 | continue 435 | key, val = cs[:2] 436 | if key == 'mstd': 437 | # noise param injected via hparams for now 438 | hparams.setdefault('magnitude_std', float(val)) 439 | elif key == 'm': 440 | magnitude = int(val) 441 | elif key == 'n': 442 | num_layers = int(val) 443 | elif key == 'w': 444 | weight_idx = int(val) 445 | else: 446 | assert False, 'Unknown RandAugment config section' 447 | if use_cmc: 448 | ra_ops = rand_augment_ops_cmc(magnitude=magnitude, hparams=hparams) 449 | else: 450 | ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) 451 | choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) 452 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 453 | 454 | 455 | class GaussianBlur(object): 456 | """blur a single image on CPU""" 457 | def __init__(self, kernel_size): 458 | radias = kernel_size // 2 459 | kernel_size = radias * 2 + 1 460 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), 461 | stride=1, padding=0, bias=False, groups=3) 462 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), 463 | stride=1, padding=0, bias=False, groups=3) 464 | self.k = kernel_size 465 | self.r = radias 466 | 467 | self.blur = nn.Sequential( 468 | nn.ReflectionPad2d(radias), 469 | self.blur_h, 470 | self.blur_v 471 | ) 472 | 473 | self.pil_to_tensor = transforms.ToTensor() 474 | self.tensor_to_pil = transforms.ToPILImage() 475 | 476 | def __call__(self, img): 477 | img = self.pil_to_tensor(img).unsqueeze(0) 478 | 479 | sigma = np.random.uniform(0.1, 2.0) 480 | x = np.arange(-self.r, self.r + 1) 481 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 482 | x = x / x.sum() 483 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 484 | 485 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1)) 486 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k)) 487 | 488 | with torch.no_grad(): 489 | img = self.blur(img) 490 | img = img.squeeze() 491 | 492 | img = self.tensor_to_pil(img) 493 | 494 | return img 495 | -------------------------------------------------------------------------------- /downstream_tasks/imb_utils/util.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/kaidic/LDAM-DRW/blob/master/utils.py 2 | import torch 3 | import torch.distributed as dist 4 | import shutil 5 | import os 6 | import numpy as np 7 | import matplotlib 8 | 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | from sklearn.metrics import confusion_matrix 12 | from sklearn.utils.multiclass import unique_labels 13 | 14 | 15 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 16 | 17 | def __init__(self, dataset, indices=None, num_samples=None): 18 | # if indices is not provided, 19 | # all elements in the dataset will be considered 20 | self.indices = list(range(len(dataset))) \ 21 | if indices is None else indices 22 | 23 | # if num_samples is not provided, 24 | # draw `len(indices)` samples in each iteration 25 | self.num_samples = len(self.indices) \ 26 | if num_samples is None else num_samples 27 | 28 | # 
distribution of classes in the dataset 29 | label_to_count = [0] * len(np.unique(dataset.targets)) 30 | for idx in self.indices: 31 | label = self._get_label(dataset, idx) 32 | label_to_count[label] += 1 33 | 34 | beta = 0.9999 35 | effective_num = 1.0 - np.power(beta, label_to_count) 36 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 37 | 38 | # weight for each sample 39 | weights = [per_cls_weights[self._get_label(dataset, idx)] 40 | for idx in self.indices] 41 | self.weights = torch.DoubleTensor(weights) 42 | 43 | def _get_label(self, dataset, idx): 44 | return dataset.targets[idx] 45 | 46 | def __iter__(self): 47 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 48 | 49 | def __len__(self): 50 | return self.num_samples 51 | 52 | 53 | def calc_confusion_mat(val_loader, model, args): 54 | model.eval() 55 | all_preds = [] 56 | all_targets = [] 57 | with torch.no_grad(): 58 | for i, (input, target) in enumerate(val_loader): 59 | if args.gpu is not None: 60 | input = input.cuda(args.gpu, non_blocking=True) 61 | target = target.cuda(args.gpu, non_blocking=True) 62 | 63 | # compute output 64 | output = model(input) 65 | _, pred = torch.max(output, 1) 66 | all_preds.extend(pred.cpu().numpy()) 67 | all_targets.extend(target.cpu().numpy()) 68 | cf = confusion_matrix(all_targets, all_preds).astype(float) 69 | 70 | cls_cnt = cf.sum(axis=1) 71 | cls_hit = np.diag(cf) 72 | 73 | cls_acc = cls_hit / cls_cnt 74 | 75 | print('Class Accuracy : ') 76 | print(cls_acc) 77 | classes = [str(x) for x in args.cls_num_list] 78 | plot_confusion_matrix(all_targets, all_preds, classes) 79 | plt.savefig(os.path.join(args.root_log, args.store_name, 'confusion_matrix.png')) 80 | 81 | 82 | def plot_confusion_matrix(y_true, y_pred, classes, 83 | normalize=False, 84 | title=None, 85 | cmap=plt.cm.Blues): 86 | if not title: 87 | if normalize: 88 | title = 'Normalized confusion matrix' 89 | else: 90 | title = 'Confusion matrix, without normalization' 91 | 92 | # Compute confusion matrix 93 | cm = confusion_matrix(y_true, y_pred) 94 | 95 | fig, ax = plt.subplots() 96 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 97 | ax.figure.colorbar(im, ax=ax) 98 | # We want to show all ticks... 99 | ax.set(xticks=np.arange(cm.shape[1]), 100 | yticks=np.arange(cm.shape[0]), 101 | # ... and label them with the respective list entries 102 | xticklabels=classes, yticklabels=classes, 103 | title=title, 104 | ylabel='True label', 105 | xlabel='Predicted label') 106 | 107 | # Rotate the tick labels and set their alignment. 108 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 109 | rotation_mode="anchor") 110 | 111 | # Loop over data dimensions and create text annotations. 112 | fmt = '.2f' if normalize else 'd' 113 | thresh = cm.max() / 2. 
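# Annotate every cell with its count (or normalized rate), drawing white text on cells
# darker than half of the colormap maximum so the labels remain readable.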
114 | for i in range(cm.shape[0]): 115 | for j in range(cm.shape[1]): 116 | ax.text(j, i, format(cm[i, j], fmt), 117 | ha="center", va="center", 118 | color="white" if cm[i, j] > thresh else "black") 119 | fig.tight_layout() 120 | return ax 121 | 122 | 123 | def prepare_folders(args): 124 | folders_util = [args.root_log, args.root_model, 125 | os.path.join(args.root_log, args.store_name), 126 | os.path.join(args.root_model, args.store_name)] 127 | for folder in folders_util: 128 | if not os.path.exists(folder): 129 | print('creating folder ' + folder) 130 | os.mkdir(folder) 131 | 132 | 133 | def save_checkpoint(args, state, is_best, epoch): 134 | filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name) 135 | torch.save(state, filename) 136 | if is_best: 137 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 138 | if epoch % 20 == 0: 139 | filename = '%s/%s/%s_ckpt.pth.tar' % (args.root_model, args.store_name, str(epoch)) 140 | torch.save(state, filename) 141 | 142 | 143 | class AverageMeter(object): 144 | 145 | def __init__(self, name, fmt=':f'): 146 | self.name = name 147 | self.fmt = fmt 148 | self.reset() 149 | 150 | def reset(self): 151 | self.val = 0 152 | self.avg = 0 153 | self.sum = 0 154 | self.count = 0 155 | 156 | def update(self, val, n=1): 157 | self.val = val 158 | self.sum += val * n 159 | self.count += n 160 | self.avg = self.sum / self.count 161 | 162 | def __str__(self): 163 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 164 | return fmtstr.format(**self.__dict__) 165 | 166 | 167 | def accuracy(output, target, topk=(1,)): 168 | # print("output", output.shape) 169 | # print("target", target.shape) 170 | with torch.no_grad(): 171 | maxk = max(topk) 172 | batch_size = target.size(0) 173 | 174 | _, pred = output.topk(maxk, 1, True, True) 175 | pred = pred.t() 176 | if target.dim() == 1: 177 | target = target.reshape(1, -1) 178 | elif target.dim() == 2: 179 | target = target.argmax(dim=1).reshape(1, -1) 180 | correct = pred.eq(target.expand_as(pred)) 181 | 182 | res = [] 183 | for k in topk: 184 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 185 | res.append(correct_k.mul_(100.0 / batch_size)) 186 | return res 187 | 188 | -------------------------------------------------------------------------------- /downstream_tasks/losses.py: -------------------------------------------------------------------------------- 1 | # reference code: https://github.com/kaidic/LDAM-DRW/blob/master/cifar_train.py 2 | 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | 11 | def focal_loss(input_values, gamma): 12 | """Computes the focal loss""" 13 | p = torch.exp(-input_values) 14 | # loss = (1 - p) ** gamma * input_values 15 | loss = (1 - p) ** gamma * input_values * 10 16 | return loss.mean() 17 | 18 | 19 | class FocalLoss(nn.Module): 20 | def __init__(self, weight=None, gamma=0.0): 21 | super(FocalLoss, self).__init__() 22 | assert gamma >= 0 23 | self.gamma = gamma 24 | self.weight = weight 25 | 26 | def forward(self, input, target): 27 | return focal_loss( 28 | F.cross_entropy(input, target, reduction="none", weight=self.weight), 29 | self.gamma, 30 | ) 31 | 32 | 33 | class LDAMLoss(nn.Module): 34 | def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30): 35 | super(LDAMLoss, self).__init__() 36 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) 37 | m_list = m_list * (max_m / np.max(m_list)) 38 | m_list = 
torch.cuda.FloatTensor(m_list) 39 | self.m_list = m_list 40 | assert s > 0 41 | self.s = s 42 | self.weight = weight 43 | 44 | def forward(self, x, target): 45 | index = torch.zeros_like(x, dtype=torch.uint8) 46 | index.scatter_(1, target.data.view(-1, 1), 1) 47 | 48 | index_float = index.type(torch.cuda.FloatTensor) 49 | batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1)) 50 | batch_m = batch_m.view((-1, 1)) 51 | x_m = x - batch_m 52 | 53 | output = torch.where(index, x_m, x) 54 | return F.cross_entropy(self.s * output, target, weight=self.weight) 55 | 56 | 57 | class BalancedSoftmaxLoss(nn.Module): 58 | def __init__(self, cls_num_list): 59 | super().__init__() 60 | cls_prior = cls_num_list / sum(cls_num_list) 61 | self.log_prior = torch.log(cls_prior).unsqueeze(0) 62 | # self.min_prob = 1e-9 63 | # print(f'Use BalancedSoftmaxLoss, class_prior: {cls_prior}') 64 | 65 | def forward(self, logits, labels): 66 | adjusted_logits = logits + self.log_prior 67 | label_loss = F.cross_entropy(adjusted_logits, labels) 68 | 69 | return label_loss 70 | 71 | 72 | class LabelSmoothing(nn.Module): 73 | # "Implement label smoothing." 74 | 75 | def __init__(self, size, smoothing=0.0): 76 | super(LabelSmoothing, self).__init__() 77 | self.criterion = nn.KLDivLoss(size_average=False) 78 | # self.padding_idx = padding_idx 79 | self.confidence = 1.0 - smoothing 80 | self.smoothing = smoothing 81 | self.size = size 82 | self.true_dist = None 83 | 84 | def forward(self, x, target): 85 | """ 86 | x: the input of shape (M, N) - N samples, M total classes, giving the log-probability log P of each class 87 | target: the label tensor of shape (M,) 88 | """ 89 | assert x.size(1) == self.size 90 | x = x.log() 91 | true_dist = x.data.clone() # make a deep copy first 92 | # print true_dist 93 | true_dist.fill_(self.smoothing / (self.size - 1)) # the 'otherwise' branch of the smoothing formula 94 | # print true_dist 95 | # convert the target into a one-hot encoding; dim=1 means filling along the columns, 96 | # target.data.unsqueeze(1) gives the indices and confidence is the value to fill in 97 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 98 | 99 | self.true_dist = true_dist 100 | print(x.shape, true_dist.shape) 101 | 102 | return self.criterion(x, Variable(true_dist, requires_grad=False)) 103 | 104 | 105 | class LabelSmoothingLoss(nn.Module): 106 | def __init__(self, classes, smoothing=0.0, dim=-1): 107 | super(LabelSmoothingLoss, self).__init__() 108 | self.confidence = 1.0 - smoothing 109 | self.smoothing = smoothing 110 | self.cls = classes 111 | self.dim = dim 112 | 113 | def forward(self, pred, target): 114 | pred = pred.log_softmax(dim=self.dim) 115 | with torch.no_grad(): 116 | # true_dist = pred.data.clone() 117 | true_dist = torch.zeros_like(pred) 118 | true_dist.fill_(self.smoothing / (self.cls - 1)) 119 | if target.dim() == 1: 120 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 121 | else: 122 | true_dist = true_dist + target * self.confidence 123 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) 124 | -------------------------------------------------------------------------------- /downstream_tasks/mixup.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from cutmix.cutmix import CutMix 6 | from cutmix.utils import onehot, rand_bbox 7 | from torch.utils.data.dataset import Dataset 8 | 9 | 10 | def calculate_confusion_matrix(pred, target): 11 | """Calculate confusion matrix according to the prediction and target. 12 | 13 | Args: 14 | pred (torch.Tensor | np.array): The model prediction with shape (N, C). 
15 | target (torch.Tensor | np.array): The target of each prediction with 16 | shape (N, 1) or (N,). 17 | 18 | Returns: 19 | torch.Tensor: Confusion matrix 20 | The shape is (C, C), where C is the number of classes. 21 | """ 22 | 23 | if isinstance(pred, np.ndarray): 24 | pred = torch.from_numpy(pred) 25 | if isinstance(target, np.ndarray): 26 | target = torch.from_numpy(target) 27 | assert isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor), ( 28 | f"pred and target should be torch.Tensor or np.ndarray, " 29 | f"but got {type(pred)} and {type(target)}." 30 | ) 31 | 32 | # Modified from PyTorch-Ignite 33 | num_classes = pred.size(1) 34 | pred_label = torch.argmax(pred, dim=1).flatten() 35 | target_label = target.flatten() 36 | assert len(pred_label) == len(target_label) 37 | 38 | with torch.no_grad(): 39 | indices = num_classes * target_label + pred_label 40 | matrix = torch.bincount(indices, minlength=num_classes**2) 41 | matrix = matrix.reshape(num_classes, num_classes) 42 | return matrix.detach().cpu().numpy() 43 | 44 | 45 | def calculate_accuracy(pred, target): 46 | _, predicted_labels = torch.max(pred, dim=1) 47 | correct_predictions = torch.sum(predicted_labels == target) 48 | total_samples = target.size(0) 49 | 50 | accuracy = correct_predictions.item() / total_samples 51 | return accuracy 52 | 53 | 54 | def is_vector_label(x): 55 | if isinstance(x, np.ndarray): 56 | return x.size > 1 57 | elif isinstance(x, torch.Tensor): 58 | return x.size().numel() > 1 59 | elif isinstance(x, int): 60 | return False 61 | else: 62 | raise TypeError(f"Unknown type {type(x)}") 63 | 64 | 65 | class CutMix(Dataset): 66 | def __init__(self, dataset, num_class, num_mix=1, beta=1.0, prob=1.0): 67 | self.dataset = dataset 68 | self.num_class = num_class 69 | self.num_mix = num_mix 70 | self.beta = beta 71 | self.prob = prob 72 | 73 | def __getitem__(self, index): 74 | example = self.dataset[index] 75 | img, lb = example["pixel_values"], example["label"] 76 | lb_onehot = lb if is_vector_label(lb) else onehot(self.num_class, lb) 77 | 78 | for _ in range(self.num_mix): 79 | r = np.random.rand(1) 80 | if self.beta <= 0 or r > self.prob: 81 | continue 82 | 83 | # generate mixed sample 84 | lam = np.random.beta(self.beta, self.beta) 85 | rand_index = random.choice(range(len(self))) 86 | 87 | rand_example = self.dataset[rand_index] 88 | img2, lb2 = rand_example["pixel_values"], rand_example["label"] 89 | lb2_onehot = lb2 if is_vector_label(lb2) else onehot(self.num_class, lb2) 90 | bbx1, bby1, bbx2, bby2 = rand_bbox(img.size(), lam) 91 | img[:, bbx1:bbx2, bby1:bby2] = img2[:, bbx1:bbx2, bby1:bby2] 92 | lam = 1 - ( 93 | (bbx2 - bbx1) * (bby2 - bby1) / (img.size()[-1] * img.size()[-2]) 94 | ) 95 | lb_onehot = lb_onehot * lam + lb2_onehot * (1.0 - lam) 96 | 97 | return {"pixel_values": img, "label": lb_onehot} 98 | 99 | def __len__(self): 100 | return len(self.dataset) 101 | 102 | 103 | def mixup_data(x, y, alpha=1, num_classes=200): 104 | """Compute the mixup data. 
Return mixed inputs, pairs of targets, and lambda""" 105 | if alpha > 0.0: 106 | lam = np.random.beta(alpha, alpha) 107 | else: 108 | lam = 1.0 109 | batch_size = x.size()[0] 110 | index = torch.randperm(batch_size).to(x.device) 111 | 112 | mixed_x = lam * x + (1 - lam) * x[index, :] 113 | if is_vector_label(y): 114 | mixed_y = lam * y + (1 - lam) * y[index] 115 | else: 116 | mixed_y = onehot(y, num_classes) * lam + onehot(y[index], num_classes) * ( 117 | 1 - lam 118 | ) 119 | return mixed_x, mixed_y 120 | -------------------------------------------------------------------------------- /outputs: -------------------------------------------------------------------------------- 1 | ../../outputs/da-fusion/outputs/ -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhicaiwww/Diff-Mix/a81337b1492bcc7f7a8a61921836e94191a3a0ef/requirement.txt -------------------------------------------------------------------------------- /scripts/classification.sh: -------------------------------------------------------------------------------- 1 | #dataset (dog cub car aircraft flower pet food chest caltech pascal ) 2 | #lr(resnet) (0.001 0.05 0.10 0.10 0.05 0.01 0.01 0.01 0.01 0.01 ) 3 | #lr(vit) (0.00001 0.001 0.001 0.001 0.0005 0.0001 0.00005 0.00005 0.00005 0.00005 ) 4 | 5 | GPU=1 6 | DATASET="cub" 7 | SHOT=-1 8 | # "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}" 9 | SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot 10 | SYNDATA_P=0.1 11 | GAMMA=0.8 12 | 13 | python downstream_tasks/train_hub.py \ 14 | --dataset $DATASET \ 15 | --syndata_dir $SYNDATA_DIR \ 16 | --syndata_p $SYNDATA_P \ 17 | --model "resnet50" \ 18 | --gamma $GAMMA \ 19 | --examples_per_class $SHOT \ 20 | --gpu $GPU \ 21 | --amp 2 \ 22 | --note $(date +%m%d%H%M) \ 23 | --group_note "fullshot" \ 24 | --nepoch 120 \ 25 | --res_mode 224 \ 26 | --lr 0.05 \ 27 | --seed 0 \ 28 | --weight_decay 0.0005 -------------------------------------------------------------------------------- /scripts/classification_imb.sh: -------------------------------------------------------------------------------- 1 | # CMO 2 | gpu=1 3 | datast='cub' 4 | imb_factor=0.01 5 | GAMMA=0.8 6 | # "imb{args.imbalance_factor}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}" 7 | SYNDATA_DIR="aug_samples/cub/imb${imb_factor}_diff-mix_fixed_0.7" 8 | SYNDATA_P=0.1 9 | 10 | python downstream_tasks/train_hub_imb.py \ 11 | --dataset $datast \ 12 | --loss_type CE \ 13 | --lr 0.005 \ 14 | --epochs 200 \ 15 | --imb_factor $imb_factor \ 16 | -b 128 \ 17 | --gpu $gpu \ 18 | --root_log outputs/results_cmo \ 19 | --data_aug CMO 20 | 21 | # DRW 22 | python downstream_tasks/train_hub_imb.py \ 23 | --dataset $datast \ 24 | --loss_type CE \ 25 | --lr 0.005 \ 26 | --epochs 200 \ 27 | --imb_factor $imb_factor \ 28 | -b 128 \ 29 | --gpu $gpu \ 30 | --data_aug vanilla \ 31 | --root_log outputs/results_cmo \ 32 | --train_rule DRW 33 | 34 | # baseline 35 | python downstream_tasks/train_hub_imb.py \ 36 | --dataset $datast \ 37 | --loss_type CE \ 38 | --lr 0.005 \ 39 | --epochs 200 \ 40 | --imb_factor $imb_factor \ 41 | -b 128 \ 42 | --gpu $gpu \ 43 | --data_aug vanilla \ 44 | --root_log outputs/results_cmo 45 | 46 | # weightedSyn 47 | python downstream_tasks/train_hub_imb.py \ 48 | --dataset $datast \ 49 | --loss_type CE \ 50 | --lr 
0.005 \ 51 | --epochs 200 \ 52 | --imb_factor $imb_factor \ 53 | -b 128 \ 54 | --gpu $gpu \ 55 | --data_aug vanilla \ 56 | --root_log outputs/results_cmo \ 57 | --syndata_dir $SYNDATA_DIR \ 58 | --syndata_p $SYNDATA_P \ 59 | --gamma $GAMMA \ 60 | --use_weighted_syn 61 | 62 | -------------------------------------------------------------------------------- /scripts/classification_waterbird.sh: -------------------------------------------------------------------------------- 1 | GPU=1 2 | DATASET="cub" 3 | SHOT=-1 4 | SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot 5 | SYNDATA_P=0.1 6 | GAMMA=0.8 7 | 8 | python downstream_tasks/train_hub_waterbird.py \ 9 | --dataset $DATASET \ 10 | --syndata_dir $SYNDATA_DIR \ 11 | --syndata_p $SYNDATA_P \ 12 | --model "resnet50" \ 13 | --gamma $GAMMA \ 14 | --examples_per_class $SHOT \ 15 | --gpu $GPU \ 16 | --amp 2 \ 17 | --note $(date +%m%d%H%M) \ 18 | --group_note "robustness" \ 19 | --nepoch 120 \ 20 | --res_mode 224 \ 21 | --lr 0.05 \ 22 | --seed 0 \ 23 | --weight_decay 0.0005 -------------------------------------------------------------------------------- /scripts/compose_syn_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | import shutil 5 | from collections import defaultdict 6 | 7 | import pandas as pd 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from utils.misc import check_synthetic_dir_valid 12 | 13 | 14 | def generate_meta_csv(output_path): 15 | rootdir = os.path.join(output_path, "data") 16 | if os.path.exists(os.path.join(output_path, "meta.csv")): 17 | return 18 | pattern_level_1 = r"(.+)" 19 | pattern_level_2 = r"(.+)-(\d+)-(.+).png" 20 | 21 | # Generate meta.csv for indexing images 22 | data_dict = defaultdict(list) 23 | for dir in os.listdir(rootdir): 24 | if not os.path.isdir(os.path.join(rootdir, dir)): 25 | continue 26 | match_1 = re.match(pattern_level_1, dir) 27 | first_dir = match_1.group(1).replace("_", " ") 28 | for file in os.listdir(os.path.join(rootdir, dir)): 29 | match_2 = re.match(pattern_level_2, file) 30 | second_dir = match_2.group(1).replace("_", " ") 31 | num = int(match_2.group(2)) 32 | floating_num = float(match_2.group(3)) 33 | data_dict["First Directory"].append(first_dir) 34 | data_dict["Second Directory"].append(second_dir) 35 | data_dict["Number"].append(num) 36 | data_dict["Strength"].append(floating_num) 37 | data_dict["Path"].append(os.path.join(dir, file)) 38 | 39 | df = pd.DataFrame(data_dict) 40 | 41 | # Validate generated images 42 | valid_rows = [] 43 | for index, row in tqdm(df.iterrows(), total=len(df)): 44 | image_path = os.path.join(output_path, "data", row["Path"]) 45 | try: 46 | img = Image.open(image_path) 47 | img.close() 48 | valid_rows.append(row) 49 | except Exception as e: 50 | os.remove(image_path) 51 | print(f"Deleted {image_path} due to error: {str(e)}") 52 | 53 | valid_df = pd.DataFrame(valid_rows) 54 | csv_path = os.path.join(output_path, "meta.csv") 55 | valid_df.to_csv(csv_path, index=False) 56 | 57 | print("DataFrame:") 58 | print(df) 59 | 60 | 61 | def main(source_directory_list, target_directory, num_samples): 62 | 63 | target_directory = os.path.join(target_directory, "data") 64 | 65 | os.makedirs(target_directory, exist_ok=True) 66 | 67 | image_files = [] 68 | image_class_names = [] 69 | for source_directory in source_directory_list: 70 | source_directory = os.path.join(source_directory, "data") 71 | for class_name in 
os.listdir(source_directory): 72 | class_directory = os.path.join(source_directory, class_name) 73 | target_class_directory = os.path.join(target_directory, class_name) 74 | os.makedirs(target_class_directory, exist_ok=True) 75 | for filename in os.listdir(class_directory): 76 | if filename.endswith(".png"): 77 | image_class_names.append(class_name) 78 | image_files.append(os.path.join(class_directory, filename)) 79 | # random sample idx 80 | random_indices = random.sample( 81 | range(len(image_files)), min(int(num_samples), len(image_files)) 82 | ) 83 | 84 | selected_image_files = [image_files[i] for i in random_indices] 85 | selected_image_class_names = [image_class_names[i] for i in random_indices] 86 | 87 | for class_name, image_file in tqdm( 88 | zip(selected_image_class_names, selected_image_files), 89 | desc="Copying data", 90 | total=num_samples, 91 | ): 92 | shutil.copy(image_file, os.path.join(target_directory, class_name)) 93 | 94 | 95 | # python sample_synthetic_subset.py --num_samples 1000 --dataset cub --source_syn realgen --target_directory outputs/aug_samples_1shot/cub/real-gen-Multi5 96 | # python sample_synthetic_subset.py --num_samples 16670 --dataset aircraft --source_syn mixup0.5 mixup0.7 mixup0.9 --target_directory outputs/aug_samples/aircraft/diff-mix-Uniform 97 | if __name__ == "__main__": 98 | import argparse 99 | 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument( 102 | "--num_samples", type=int, default=40000, help="Number of samples" 103 | ) 104 | parser.add_argument("--dataset", type=str, default="cub", help="Dataset name") 105 | parser.add_argument("--source_aug_dir", type=str, nargs="+", default="mixup") 106 | parser.add_argument( 107 | "--target_directory", 108 | type=str, 109 | default="aug_samples/cub/diff-mix-Uniform", 110 | help="Target directory", 111 | ) 112 | args = parser.parse_args() 113 | source_directory_list = [] 114 | for synthetic_dir in args.source_aug_dir: 115 | check_synthetic_dir_valid(synthetic_dir) 116 | generate_meta_csv(synthetic_dir) 117 | source_directory_list.append(synthetic_dir) 118 | target_directory = args.target_directory 119 | main(source_directory_list, target_directory, args.num_samples) 120 | generate_meta_csv(target_directory) 121 | -------------------------------------------------------------------------------- /scripts/filter_syn_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | import tqdm 8 | from transformers import CLIPModel, CLIPProcessor 9 | 10 | from dataset.base import SyntheticDataset 11 | 12 | FILTER_CAPTIONS_MAPPING = { 13 | "cub": ["a photo with a bird on it", "a photo without a bird on it"], 14 | "aircraft": [ 15 | "a photo with an aricraft on it. ", 16 | "a photo without an aricraft on it. ", 17 | ], 18 | "dog": [ 19 | "a photo with a dog on it. ", 20 | "a photo without a dog on it. ", 21 | ], 22 | "flower": [ 23 | "a photo with a flower on it. ", 24 | "a photo without a flower on it. ", 25 | ], 26 | "car": [ 27 | "a photo with a car on it. ", 28 | "a photo without an car on it. 
", 29 | ], 30 | } 31 | 32 | 33 | def to_tensor(x): 34 | if isinstance(x, int): 35 | return torch.tensor(x) 36 | elif isinstance(x, torch.Tensor): 37 | return x 38 | else: 39 | raise NotImplementedError 40 | 41 | 42 | def syn_collate_fn(examples): 43 | pixel_values = [example["pixel_values"] for example in examples] 44 | src_labels = torch.stack([to_tensor(example["src_label"]) for example in examples]) 45 | tar_labels = torch.stack([to_tensor(example["tar_label"]) for example in examples]) 46 | dtype = torch.float32 if len(src_labels.size()) == 2 else torch.long 47 | src_labels.to(dtype=dtype) 48 | tar_labels.to(dtype=dtype) 49 | return { 50 | "pixel_values": pixel_values, 51 | "src_labels": src_labels, 52 | "tar_labels": tar_labels, 53 | } 54 | 55 | 56 | def main(args): 57 | device = f"cuda:{args.gpu}" 58 | bs = args.batch_size 59 | model = CLIPModel.from_pretrained( 60 | "openai/clip-vit-base-patch32", local_files_only=True 61 | ).to(device) 62 | processor = CLIPProcessor.from_pretrained( 63 | "openai/clip-vit-base-patch32", local_files_only=True 64 | ) 65 | ds_syn = SyntheticDataset(synthetic_dir=args.syndata_dir) 66 | ds_syn.transform = torch.nn.Identity() 67 | dataloader_syn = torch.utils.data.DataLoader( 68 | ds_syn, batch_size=bs, collate_fn=syn_collate_fn, shuffle=False, num_workers=4 69 | ) 70 | positive_confidence = [] 71 | 72 | for batch in tqdm.tqdm(dataloader_syn, total=len(dataloader_syn)): 73 | 74 | images = batch["pixel_values"] 75 | inputs = processor( 76 | text=FILTER_CAPTIONS_MAPPING[args.dataset], 77 | images=images, 78 | return_tensors="pt", 79 | padding=True, 80 | ).to(device) 81 | 82 | outputs = model(**inputs) 83 | logits_per_image = ( 84 | outputs.logits_per_image 85 | ) # this is the image-text similarity score 86 | probs = logits_per_image.softmax( 87 | dim=1 88 | ) # we can take the softmax to get the label probabilities 89 | probs = probs.cpu().detach().numpy() 90 | positive_confidence = np.concatenate((positive_confidence, probs[:, 0])) 91 | # filter the least 10% confident samples 92 | positive_confidence = np.array(positive_confidence) 93 | bottom_threshold = np.percentile(positive_confidence, 10) 94 | up_threshold = np.percentile(positive_confidence, 90) 95 | meta_df = pd.read_csv(os.path.join(args.synthetic_dir, "meta.csv")) 96 | meta_df1 = meta_df[positive_confidence >= bottom_threshold] 97 | meta_df2 = meta_df[positive_confidence < bottom_threshold] 98 | meta_df3 = meta_df[positive_confidence >= up_threshold] 99 | 100 | meta_df1.to_csv(os.path.join(args.synthetic_dir, "meta_10-100per.csv"), index=False) 101 | meta_df2.to_csv(os.path.join(args.synthetic_dir, "meta_0-10per.csv"), index=False) 102 | meta_df3.to_csv(os.path.join(args.synthetic_dir, "meta_90-100per.csv"), index=False) 103 | 104 | 105 | if __name__ == "__main__": 106 | parsers = argparse.ArgumentParser() 107 | parsers.add_argument("--dataset", type=str, default="cub") 108 | parsers.add_argument("--syndata_dir", type=str, required=True) 109 | parsers.add_argument("-g", "--gpu", type=str, default="0") 110 | parsers.add_argument("-b", "--batch_size", type=int, default=200) 111 | args = parsers.parse_args() 112 | main(args) 113 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | MODEL_NAME="runwayml/stable-diffusion-v1-5" 2 | DATASET='cub' 3 | SHOT=-1 # set -1 for full shot 4 | OUTPUT_DIR="ckpts/${DATASET}/shot${SHOT}_lora_rank10" 5 | 6 | accelerate launch 
--mixed_precision='fp16' --main_process_port 29507 \ 7 | train_lora.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --dataset_name=$DATASET \ 10 | --resolution=224 \ 11 | --random_flip \ 12 | --max_train_steps=35000 \ 13 | --num_train_epochs=10 \ 14 | --checkpointing_steps=5000 \ 15 | --learning_rate=5e-05 \ 16 | --lr_scheduler='constant' \ 17 | --lr_warmup_steps=0 \ 18 | --seed=42 \ 19 | --rank=10 \ 20 | --local_files_only \ 21 | --examples_per_class $SHOT \ 22 | --train_batch_size 2 \ 23 | --output_dir=$OUTPUT_DIR \ 24 | --report_to='tensorboard' 25 | -------------------------------------------------------------------------------- /scripts/finetune_imb.sh: -------------------------------------------------------------------------------- 1 | MODEL_NAME="runwayml/stable-diffusion-v1-5" 2 | DATASET='cub' 3 | SHOT=-1 # set -1 for full shot 4 | IMB_FACTOR=0.01 5 | OUTPUT_DIR="ckpts/${DATASET}/imb${IMB_FACTOR}_lora_rank10" 6 | 7 | accelerate launch --mixed_precision='fp16' --main_process_port 29507 \ 8 | train_lora.py \ 9 | --pretrained_model_name_or_path=$MODEL_NAME \ 10 | --dataset_name=$DATASET \ 11 | --resolution=224 \ 12 | --random_flip \ 13 | --max_train_steps=35000 \ 14 | --num_train_epochs=10 \ 15 | --checkpointing_steps=5000 \ 16 | --learning_rate=5e-05 \ 17 | --lr_scheduler='constant' \ 18 | --lr_warmup_steps=0 \ 19 | --seed=42 \ 20 | --rank=10 \ 21 | --local_files_only \ 22 | --examples_per_class $SHOT \ 23 | --train_batch_size 2 \ 24 | --output_dir=$OUTPUT_DIR \ 25 | --report_to='tensorboard' \ 26 | --task='imbalanced' \ 27 | --imbalance_factor $IMB_FACTOR 28 | 29 | -------------------------------------------------------------------------------- /scripts/sample.sh: -------------------------------------------------------------------------------- 1 | 2 | DATASET='cub' 3 | # set -1 for full shot 4 | SHOT=-1 5 | FINETUNED_CKPT="ckpts/cub/shot${SHOT}-lora-rank10" 6 | # ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug'] 7 | SAMPLE_STRATEGY='diff-mix' 8 | STRENGTH=0.8 9 | # ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9] 10 | STRENGTH_STRATEGY='fixed' 11 | # expand the dataset by 5 times 12 | MULTIPLIER=5 13 | # spawn 4 processes 14 | GPU_IDS=(0 1 2 3) 15 | 16 | python scripts/sample_mp.py \ 17 | --model-path='runwayml/stable-diffusion-v1-5' \ 18 | --output_root='outputs/aug_samples' \ 19 | --dataset=$DATASET \ 20 | --finetuned_ckpt=$FINETUNED_CKPT \ 21 | --syn_dataset_mulitiplier=$MULTIPLIER \ 22 | --strength_strategy=$STRENGTH_STRATEGY \ 23 | --sample_strategy=$SAMPLE_STRATEGY \ 24 | --examples_per_class=$SHOT \ 25 | --resolution=512 \ 26 | --batch_size=1 \ 27 | --aug_strength=0.8 \ 28 | --gpu-ids=${GPU_IDS[@]} 29 | 30 | -------------------------------------------------------------------------------- /scripts/sample_imb.sh: -------------------------------------------------------------------------------- 1 | 2 | DATASET='cub' 3 | # set -1 for full shot 4 | SHOT=-1 5 | # ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug'] 6 | SAMPLE_STRATEGY='diff-mix' 7 | STRENGTH=0.8 8 | # ['fixed', 'uniform']. 
'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9] 9 | STRENGTH_STRATEGY='fixed' 10 | # expand the dataset by 5 times 11 | MULTIPLIER=5 12 | # spwan 4 processes 13 | IMB_FACTOR=0.01 14 | FINETUNED_CKPT="ckpts/cub/imb${IMB_FACTOR}-lora-rank10" 15 | GPU_IDS=(0 1 2 3) 16 | 17 | python scripts/sample_mp.py \ 18 | --model-path='runwayml/stable-diffusion-v1-5' \ 19 | --output_root='outputs/aug_samples' \ 20 | --dataset=$DATASET \ 21 | --finetuned_ckpt=$FINETUNED_CKPT \ 22 | --syn_dataset_mulitiplier=$MULTIPLIER \ 23 | --strength_strategy=$STRENGTH_STRATEGY \ 24 | --sample_strategy=$SAMPLE_STRATEGY \ 25 | --examples_per_class=$SHOT \ 26 | --resolution=512 \ 27 | --batch_size=1 \ 28 | --aug_strength=0.8 \ 29 | --gpu-ids=${GPU_IDS[@]} \ 30 | --task='imbalanced' \ 31 | --imbalance_factor=$IMB_FACTOR 32 | 33 | -------------------------------------------------------------------------------- /scripts/sample_mp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import re 5 | import sys 6 | import time 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import yaml 12 | 13 | os.environ["CURL_CA_BUNDLE"] = "" 14 | 15 | from collections import defaultdict 16 | from multiprocessing import Process, Queue 17 | from queue import Empty 18 | 19 | from PIL import Image 20 | from tqdm import tqdm 21 | 22 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 23 | from augmentation import AUGMENT_METHODS 24 | from dataset import DATASET_NAME_MAPPING, IMBALANCE_DATASET_NAME_MAPPING 25 | from utils.misc import parse_finetuned_ckpt 26 | 27 | 28 | def check_args_valid(args): 29 | 30 | if args.sample_strategy == "real-gen": 31 | args.lora_path = None 32 | args.embed_path = None 33 | args.aug_strength = 1 34 | elif args.sample_strategy == "diff-gen": 35 | lora_path, embed_path = parse_finetuned_ckpt(args.finetuned_ckpt) 36 | args.lora_path = lora_path 37 | args.embed_path = embed_path 38 | args.aug_strength = 1 39 | elif args.sample_strategy in ["real-aug", "real-mix"]: 40 | args.lora_path = None 41 | args.embed_path = None 42 | elif args.sample_strategy in ["diff-aug", "diff-mix"]: 43 | lora_path, embed_path = parse_finetuned_ckpt(args.finetuned_ckpt) 44 | args.lora_path = lora_path 45 | args.embed_path = embed_path 46 | 47 | 48 | def sample_func(args, in_queue, out_queue, gpu_id, process_id): 49 | 50 | os.environ["CURL_CA_BUNDLE"] = "" 51 | 52 | random.seed(args.seed + process_id) 53 | np.random.seed(args.seed + process_id) 54 | torch.manual_seed(args.seed + process_id) 55 | 56 | if args.task == "imbalanced": 57 | train_dataset = IMBALANCE_DATASET_NAME_MAPPING[args.dataset]( 58 | split="train", 59 | seed=args.seed, 60 | resolution=args.resolution, 61 | imbalance_factor=args.imbalance_factor, 62 | ) 63 | else: 64 | train_dataset = DATASET_NAME_MAPPING[args.dataset]( 65 | split="train", 66 | seed=args.seed, # dataset seed is fixed for all processes 67 | examples_per_class=args.examples_per_class, 68 | resolution=args.resolution, 69 | ) 70 | 71 | model = AUGMENT_METHODS[args.sample_strategy]( 72 | model_path=args.model_path, 73 | embed_path=args.embed_path, 74 | lora_path=args.lora_path, 75 | prompt=args.prompt, 76 | guidance_scale=args.guidance_scale, 77 | device=f"cuda:{gpu_id}", 78 | ) 79 | batch_size = args.batch_size 80 | 81 | while True: 82 | index_list = [] 83 | source_label_list = [] 84 | target_label_list = [] 85 | strength_list = [] 86 | for _ in range(batch_size): 87 | 
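# Drain up to batch_size pending tasks from the shared in_queue; a 1-second timeout
# (queue.Empty) ends collection for this round before the worker generates the batch.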
try: 88 | index, source_label, target_label, strength = in_queue.get(timeout=1) 89 | index_list.append(index) 90 | source_label_list.append(source_label) 91 | target_label_list.append(target_label) 92 | strength_list.append(strength) 93 | except Empty: 94 | print("queue empty, exit") 95 | break 96 | target_label = target_label_list[0] 97 | target_indice = random.sample(train_dataset.label_to_indices[target_label], 1)[ 98 | 0 99 | ] 100 | target_metadata = train_dataset.get_metadata_by_idx(target_indice) 101 | target_name = target_metadata["name"].replace(" ", "_").replace("/", "_") 102 | 103 | source_images = [] 104 | save_paths = [] 105 | if args.task == "vanilla": 106 | source_indices = [ 107 | random.sample(train_dataset.label_to_indices[source_label], 1)[0] 108 | for source_label in source_label_list 109 | ] 110 | elif args.task == "imbalanced": 111 | source_indices = random.sample(range(len(train_dataset)), batch_size) 112 | for index, source_indice in zip(index_list, source_indices): 113 | source_images.append(train_dataset.get_image_by_idx(source_indice)) 114 | source_metadata = train_dataset.get_metadata_by_idx(source_indice) 115 | source_name = source_metadata["name"].replace(" ", "_").replace("/", "_") 116 | save_name = os.path.join( 117 | source_name, f"{target_name}-{index:06d}-{strength}.png" 118 | ) 119 | save_paths.append(os.path.join(args.output_path, "data", save_name)) 120 | 121 | if os.path.exists(save_paths[0]): 122 | print(f"skip {save_paths[0]}") 123 | else: 124 | image, _ = model( 125 | image=source_images, 126 | label=target_label, 127 | strength=strength, 128 | metadata=target_metadata, 129 | resolution=args.resolution, 130 | ) 131 | for image, save_path in zip(image, save_paths): 132 | image.save(save_path) 133 | print(f"save {save_path}") 134 | 135 | 136 | def main(args): 137 | 138 | torch.multiprocessing.set_start_method("spawn") 139 | 140 | os.makedirs(os.path.join(args.output_root, args.dataset), exist_ok=True) 141 | 142 | check_args_valid(args) 143 | if args.task == "vanilla": 144 | output_name = f"shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}" 145 | else: # imbalanced 146 | output_name = f"imb{args.imbalance_factor}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}" 147 | args.output_path = os.path.join(args.output_root, args.dataset, output_name) 148 | 149 | os.makedirs(args.output_path, exist_ok=True) 150 | torch.manual_seed(args.seed) 151 | np.random.seed(args.seed) 152 | random.seed(args.seed) 153 | 154 | gpu_ids = args.gpu_ids 155 | in_queue = Queue() 156 | out_queue = Queue() 157 | 158 | if args.task == "imbalanced": 159 | train_dataset = IMBALANCE_DATASET_NAME_MAPPING[args.dataset]( 160 | split="train", 161 | seed=args.seed, 162 | resolution=args.resolution, 163 | imbalance_factor=args.imbalance_factor, 164 | ) 165 | else: 166 | train_dataset = DATASET_NAME_MAPPING[args.dataset]( 167 | split="train", 168 | seed=args.seed, 169 | examples_per_class=args.examples_per_class, 170 | resolution=args.resolution, 171 | ) 172 | 173 | num_classes = len(train_dataset.class_names) 174 | 175 | for name in train_dataset.class_names: 176 | name = name.replace(" ", "_").replace("/", "_") 177 | os.makedirs(os.path.join(args.output_path, "data", name), exist_ok=True) 178 | 179 | num_tasks = args.syn_dataset_mulitiplier * len(train_dataset) 180 | 181 | if args.sample_strategy in [ 182 | "real-gen", 183 | "real-aug", 184 | "diff-aug", 185 | "diff-gen", 186 | "ti-aug", 187 | ]: 188 | source_classes = 
random.choices( 189 | range(len(train_dataset.class_names)), k=num_tasks 190 | ) 191 | target_classes = source_classes 192 | elif args.sample_strategy in ["real-mix", "diff-mix", "ti-mix"]: 193 | source_classes = random.choices( 194 | range(len(train_dataset.class_names)), k=num_tasks 195 | ) 196 | target_classes = random.choices( 197 | range(len(train_dataset.class_names)), k=num_tasks 198 | ) 199 | else: 200 | raise ValueError(f"Augmentation strategy {args.sample_strategy} not supported") 201 | 202 | if args.strength_strategy == "fixed": 203 | strength_list = [args.aug_strength] * num_tasks 204 | elif args.strength_strategy == "uniform": 205 | strength_list = random.choices([0.3, 0.5, 0.7, 0.9], k=num_tasks) 206 | 207 | options = zip(range(num_tasks), source_classes, target_classes, strength_list) 208 | 209 | for option in options: 210 | in_queue.put(option) 211 | 212 | sample_config = vars(args) 213 | sample_config["num_classes"] = num_classes 214 | sample_config["total_tasks"] = num_tasks 215 | sample_config["sample_strategy"] = args.sample_strategy 216 | 217 | with open( 218 | os.path.join(args.output_path, "config.yaml"), "w", encoding="utf-8" 219 | ) as f: 220 | yaml.dump(sample_config, f) 221 | 222 | processes = [] 223 | total_tasks = in_queue.qsize() 224 | print("Number of total tasks", total_tasks) 225 | 226 | with tqdm(total=total_tasks, desc="Processing") as pbar: 227 | for process_id, gpu_id in enumerate(gpu_ids): 228 | process = Process( 229 | target=sample_func, 230 | args=(args, in_queue, out_queue, gpu_id, process_id), 231 | ) 232 | process.start() 233 | processes.append(process) 234 | 235 | while any(process.is_alive() for process in processes): 236 | current_queue_size = in_queue.qsize() 237 | pbar.n = total_tasks - current_queue_size 238 | pbar.refresh() 239 | time.sleep(1) 240 | 241 | for process in processes: 242 | process.join() 243 | 244 | # Generate meta.csv for indexing images 245 | rootdir = os.path.join(args.output_path, "data") 246 | pattern_level_1 = r"(.+)" 247 | pattern_level_2 = r"(.+)-(\d+)-(.+).png" 248 | data_dict = defaultdict(list) 249 | for dir in os.listdir(rootdir): 250 | if not os.path.isdir(os.path.join(rootdir, dir)): 251 | continue 252 | match_1 = re.match(pattern_level_1, dir) 253 | first_dir = match_1.group(1).replace("_", " ") 254 | for file in os.listdir(os.path.join(rootdir, dir)): 255 | match_2 = re.match(pattern_level_2, file) 256 | second_dir = match_2.group(1).replace("_", " ") 257 | num = int(match_2.group(2)) 258 | floating_num = float(match_2.group(3)) 259 | data_dict["First Directory"].append(first_dir) 260 | data_dict["Second Directory"].append(second_dir) 261 | data_dict["Number"].append(num) 262 | data_dict["Strength"].append(floating_num) 263 | data_dict["Path"].append(os.path.join(dir, file)) 264 | 265 | df = pd.DataFrame(data_dict) 266 | 267 | # Validate generated images 268 | valid_rows = [] 269 | for index, row in tqdm(df.iterrows(), total=len(df)): 270 | image_path = os.path.join(args.output_path, "data", row["Path"]) 271 | try: 272 | img = Image.open(image_path) 273 | img.close() 274 | valid_rows.append(row) 275 | except Exception as e: 276 | os.remove(image_path) 277 | print(f"Deleted {image_path} due to error: {str(e)}") 278 | 279 | valid_df = pd.DataFrame(valid_rows) 280 | csv_path = os.path.join(args.output_path, "meta.csv") 281 | valid_df.to_csv(csv_path, index=False) 282 | 283 | print("DataFrame:") 284 | print(df) 285 | 286 | 287 | if __name__ == "__main__": 288 | parser = argparse.ArgumentParser("Inference script") 
289 | parser.add_argument( 290 | "--finetuned_ckpt", 291 | type=str, 292 | required=True, 293 | help="key for indexing finetuned model", 294 | ) 295 | parser.add_argument( 296 | "--output_root", 297 | type=str, 298 | default="outputs/aug_samples", 299 | help="output root directory", 300 | ) 301 | parser.add_argument( 302 | "--model_path", type=str, default="CompVis/stable-diffusion-v1-4" 303 | ) 304 | parser.add_argument("--dataset", type=str, default="pascal", help="dataset name") 305 | parser.add_argument("--seed", type=int, default=0, help="random seed") 306 | parser.add_argument( 307 | "--examples_per_class", 308 | type=int, 309 | default=-1, 310 | help="synthetic examples per class", 311 | ) 312 | parser.add_argument("--resolution", type=int, default=512, help="image resolution") 313 | parser.add_argument("--batch_size", type=int, default=1, help="batch size") 314 | parser.add_argument( 315 | "--prompt", type=str, default="a photo of a {name}", help="prompt for synthesis" 316 | ) 317 | parser.add_argument( 318 | "--sample_strategy", 319 | type=str, 320 | default="ti-mix", 321 | choices=[ 322 | "real-gen", 323 | "real-aug", # real guidance 324 | "real-mix", 325 | "ti-aug", 326 | "ti-mix", 327 | "diff-aug", 328 | "diff-mix", 329 | "diff-gen", 330 | ], 331 | help="sampling strategy for synthetic data", 332 | ) 333 | parser.add_argument( 334 | "--guidance-scale", 335 | type=float, 336 | default=7.5, 337 | help="classifier free guidance scale", 338 | ) 339 | parser.add_argument("--gpu_ids", type=int, nargs="+", default=[0], help="gpu ids") 340 | parser.add_argument( 341 | "--task", 342 | type=str, 343 | default="vanilla", 344 | choices=["vanilla", "imbalanced"], 345 | help="task", 346 | ) 347 | parser.add_argument( 348 | "--imbalance_factor", 349 | type=float, 350 | default=0.01, 351 | choices=[0.01, 0.02, 0.1], 352 | help="imbalanced factor, only for imbalanced task", 353 | ) 354 | parser.add_argument( 355 | "--syn_dataset_mulitiplier", 356 | type=int, 357 | default=5, 358 | help="multiplier for the number of synthetic images compared to the number of real images", 359 | ) 360 | parser.add_argument( 361 | "--strength_strategy", 362 | type=str, 363 | default="fixed", 364 | choices=["fixed", "uniform"], 365 | ) 366 | parser.add_argument( 367 | "--aug_strength", type=float, default=0.5, help="augmentation strength" 368 | ) 369 | args = parser.parse_args() 370 | 371 | main(args) 372 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import re 4 | from typing import List, Union 5 | 6 | import pandas as pd 7 | import yaml 8 | 9 | 10 | def count_files_in_directory(directory): 11 | count = 0 12 | for root, dirs, files in os.walk(directory): 13 | count += len(files) 14 | return count 15 | 16 | 17 | def check_synthetic_dir_valid(synthetic_dir): 18 | 19 | if not os.path.exists(synthetic_dir): 20 | raise FileNotFoundError(f"Directory '{synthetic_dir}' does not exist.") 21 | 22 | total_files = count_files_in_directory(synthetic_dir) 23 | if total_files > 100: 24 | print(f"Directory '{synthetic_dir}' is valid with {total_files} files.") 25 | else: 26 | raise ValueError( 27 | f"Directory '{synthetic_dir}' contains less than 100 files, which is insufficient." 
28 | ) 29 | 30 | 31 | def parse_finetuned_ckpt(finetuned_ckpt): 32 | lora_path = None 33 | embed_path = None 34 | for file in os.listdir(finetuned_ckpt): 35 | if "pytorch_lora_weights" in file: 36 | lora_path = os.path.join(finetuned_ckpt, file) 37 | elif "learned_embeds-steps-last" in file: 38 | embed_path = os.path.join(finetuned_ckpt, file) 39 | return lora_path, embed_path 40 | 41 | 42 | def checked_has_run(exp_dir, args): 43 | parent_dir = os.path.abspath(os.path.join(exp_dir, os.pardir)) 44 | current_args = copy.deepcopy(args) 45 | current_args.pop("gpu", None) 46 | current_args.pop("note", None) 47 | current_args.pop("target_class_num", None) 48 | 49 | for dirpath, dirnames, filenames in os.walk(parent_dir): 50 | for dirname in dirnames: 51 | config_file = os.path.join(dirpath, dirname, "config.yaml") 52 | if os.path.exists(config_file): 53 | with open(config_file, "r") as file: 54 | saved_args = yaml.load(file, Loader=yaml.FullLoader) 55 | 56 | if ( 57 | current_args["syndata_dir"] is None 58 | or "aug" in current_args["syndata_dir"] 59 | or "gen" in current_args["syndata_dir"] 60 | ): 61 | current_args.pop("gamma", None) 62 | saved_args.pop("gamma", None) 63 | saved_args.pop("gpu", None) 64 | saved_args.pop("note", None) 65 | saved_args.pop("target_class_num", None) 66 | if saved_args == current_args: 67 | print( 68 | f"This program has already been run in directory: {dirpath}/{dirname}" 69 | ) 70 | return True 71 | return False 72 | 73 | 74 | def parse_result(target_dir, extra_column=[]): 75 | results = [] 76 | for file in os.listdir(target_dir): 77 | config_file = os.path.join(target_dir, file, "config.yaml") 78 | config = yaml.safe_load(open(config_file, "r")) 79 | if isinstance(config["syndata_dir"], list): 80 | syndata_dir = config["syndata_dir"][0] 81 | else: 82 | syndata_dir = config["syndata_dir"] 83 | 84 | if syndata_dir is None: 85 | strategy = "baseline" 86 | strength = 0 87 | else: 88 | match = re.match(r"([a-zA-Z]+)([0-9.]*).*", syndata_dir) 89 | if match: 90 | strategy = match.group(1) 91 | strength = match.group(2) 92 | else: 93 | continue 94 | for basefile in os.listdir(os.path.join(target_dir, file)): 95 | if "acc_eval" in basefile: 96 | acc = float(basefile.split("_")[-1]) 97 | results.append( 98 | ( 99 | config["dir"], 100 | config["res_mode"], 101 | config["lr"], 102 | strategy, 103 | strength, 104 | config["gamma"], 105 | config["seed"], 106 | *[str(config.pop(key, "False")) for key in extra_column], 107 | acc, 108 | ) 109 | ) 110 | break 111 | 112 | df = pd.DataFrame( 113 | results, 114 | columns=[ 115 | "dataset", 116 | "resolution", 117 | "lr", 118 | "strategy", 119 | "strength", 120 | "soft power", 121 | "seed", 122 | *extra_column, 123 | "acc", 124 | ], 125 | ) 126 | df["acc"] = df["acc"].astype(float) 127 | result_seed = ( 128 | df.groupby( 129 | [ 130 | "dataset", 131 | "resolution", 132 | "lr", 133 | "strength", 134 | "strategy", 135 | "soft power", 136 | *extra_column, 137 | ] 138 | ) 139 | .agg({"acc": ["mean", "var"]}) 140 | .reset_index() 141 | ) 142 | result_sorted = result_seed.sort_values( 143 | by=["dataset", "resolution", "lr", "strategy", "strength", *extra_column] 144 | ) 145 | result_seed.columns = ["_".join(col).strip() for col in result_seed.columns.values] 146 | 147 | return result_sorted 148 | -------------------------------------------------------------------------------- /utils/network.py: -------------------------------------------------------------------------------- 1 | def count_parameters(model): 2 | trainable_params = 
sum(p.numel() for p in model.parameters() if p.requires_grad) 3 | total_params = sum(p.numel() for p in model.parameters()) 4 | return trainable_params, total_params 5 | 6 | 7 | def freeze_model(model, finetune_strategy="linear"): 8 | if finetune_strategy == "linear": 9 | for name, param in model.named_parameters(): 10 | if "fc" in name: 11 | param.requires_grad = True 12 | else: 13 | param.requires_grad = False 14 | elif finetune_strategy == "stages4+linear": 15 | for name, param in model.named_parameters(): 16 | if any(list(map(lambda x: x in name, ["layer4", "fc"]))): 17 | param.requires_grad = True 18 | else: 19 | param.requires_grad = False 20 | elif finetune_strategy == "stages3-4+linear": 21 | for name, param in model.named_parameters(): 22 | if any(list(map(lambda x: x in name, ["layer3", "layer4", "fc"]))): 23 | param.requires_grad = True 24 | else: 25 | param.requires_grad = False 26 | elif finetune_strategy == "stages2-4+linear": 27 | for name, param in model.named_parameters(): 28 | if any( 29 | list(map(lambda x: x in name, ["layer2", "layer3", "layer4", "fc"])) 30 | ): 31 | param.requires_grad = True 32 | else: 33 | param.requires_grad = False 34 | elif finetune_strategy == "stages1-4+linear": 35 | for name, param in model.named_parameters(): 36 | if any( 37 | list( 38 | map( 39 | lambda x: x in name, 40 | ["layer1", "layer2", "layer3", "layer4", "fc"], 41 | ) 42 | ) 43 | ): 44 | param.requires_grad = True 45 | else: 46 | param.requires_grad = False 47 | elif finetune_strategy == "all": 48 | for name, param in model.named_parameters(): 49 | param.requires_grad = True 50 | else: 51 | raise NotImplementedError(f"{finetune_strategy}") 52 | 53 | trainable_params, total_params = count_parameters(model) 54 | ratio = trainable_params / total_params 55 | 56 | print(f"{finetune_strategy}, Trainable / Total Parameter Ratio: {ratio:.4f}") 57 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from einops import rearrange 8 | from PIL import Image 9 | from torchvision import transforms 10 | from torchvision.utils import make_grid 11 | 12 | 13 | def visualize_images( 14 | images: List[Union[Image.Image, torch.Tensor, np.ndarray]], 15 | nrow: int = 4, 16 | show=False, 17 | save=True, 18 | outpath=None, 19 | ): 20 | 21 | if isinstance(images[0], Image.Image): 22 | transform = transforms.ToTensor() 23 | images_ts = torch.stack([transform(image) for image in images]) 24 | elif isinstance(images[0], torch.Tensor): 25 | images_ts = torch.stack(images) 26 | elif isinstance(images[0], np.ndarray): 27 | images_ts = torch.stack([torch.from_numpy(image) for image in images]) 28 | # save images to a grid 29 | grid = make_grid(images_ts, nrow=nrow, normalize=True, scale_each=True) 30 | # set plt figure size to (4,16) 31 | 32 | if show: 33 | plt.figure( 34 | figsize=(4 * nrow, 4 * (len(images) // nrow + (len(images) % nrow > 0))) 35 | ) 36 | plt.imshow(grid.permute(1, 2, 0)) 37 | plt.axis("off") 38 | plt.show() 39 | # remove the axis 40 | grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() 41 | img = Image.fromarray(grid.astype(np.uint8)) 42 | if save: 43 | assert outpath is not None 44 | if os.path.dirname(outpath) and not os.path.exists(os.path.dirname(outpath)): 45 | os.makedirs(os.path.dirname(outpath), exist_ok=True) 46 | 
img.save(f"{outpath}") 47 | return img 48 | -------------------------------------------------------------------------------- /visualization/visualize_attn_map.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from collections import defaultdict 5 | 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | from diffusers.models.attention import Attention 10 | from utils.utils import ( 11 | AUGMENT_METHODS, 12 | DATASET_NAME_MAPPING, 13 | T2I_DATASET_NAME_MAPPING, 14 | finetuned_ckpt_dir, 15 | ) 16 | 17 | 18 | class AttentionVisualizer: 19 | def __init__(self, model, hook_target_name): 20 | self.model = model 21 | self.hook_target_name = hook_target_name 22 | self.activation = defaultdict(list) 23 | self.hooks = [] 24 | 25 | def get_attn_softmax(self, name): 26 | def hook(unet, input, kwargs, output): 27 | scale = 1.0 28 | with torch.no_grad(): 29 | hidden_states = input[0] 30 | encoder_hidden_states = kwargs["encoder_hidden_states"] 31 | attention_mask = kwargs["attention_mask"] 32 | batch_size, sequence_length, _ = ( 33 | hidden_states.shape 34 | if encoder_hidden_states is None 35 | else encoder_hidden_states.shape 36 | ) 37 | attention_mask = unet.prepare_attention_mask( 38 | attention_mask, sequence_length, batch_size 39 | ) 40 | if hasattr(unet, "preocessor"): 41 | query = unet.to_q(hidden_states) + scale * unet.processor.to_q_lora( 42 | hidden_states 43 | ) 44 | query = unet.head_to_batch_dim(query) 45 | 46 | if encoder_hidden_states is None: 47 | encoder_hidden_states = hidden_states 48 | elif unet.norm_cross: 49 | encoder_hidden_states = unet.norm_encoder_hidden_states( 50 | encoder_hidden_states 51 | ) 52 | 53 | key = unet.to_k( 54 | encoder_hidden_states 55 | ) + scale * unet.processor.to_k_lora(encoder_hidden_states) 56 | value = unet.to_v( 57 | encoder_hidden_states 58 | ) + scale * unet.processor.to_v_lora(encoder_hidden_states) 59 | else: 60 | query = unet.to_q(hidden_states) 61 | query = unet.head_to_batch_dim(query) 62 | 63 | if encoder_hidden_states is None: 64 | encoder_hidden_states = hidden_states 65 | elif unet.norm_cross: 66 | encoder_hidden_states = unet.norm_encoder_hidden_states( 67 | encoder_hidden_states 68 | ) 69 | 70 | key = unet.to_k(encoder_hidden_states) 71 | value = unet.to_v(encoder_hidden_states) 72 | 73 | key = unet.head_to_batch_dim(key) 74 | value = unet.head_to_batch_dim(value) 75 | 76 | attention_probs = unet.get_attention_scores(query, key, attention_mask) 77 | 78 | self.activation[name].append(attention_probs) 79 | 80 | return hook 81 | 82 | def __enter__(self): 83 | unet = self.model.pipe.unet 84 | for name, module in unet.named_modules(): 85 | if self.hook_target_name is not None: 86 | if self.hook_target_name == name: 87 | print("Added hook to", name) 88 | hook = module.register_forward_hook( 89 | self.get_attn_softmax(name), with_kwargs=True 90 | ) 91 | self.hooks.append(hook) 92 | return self 93 | 94 | def __exit__(self, exc_type, exc_value, traceback): 95 | for hook in self.hooks: 96 | hook.remove() 97 | 98 | 99 | def plot_attn_map(attn_map, path="figures/attn_map/"): 100 | # Ensure the output directory exists 101 | os.makedirs(path, exist_ok=True) 102 | 103 | for i, attention_map in enumerate(attn_map): 104 | # Attention map is of shape [8+8, 4096, 77] 105 | num_heads, hw, num_tokens = attention_map.size() 106 | 107 | # Reshape to (num_heads, sqrt(num_tokens), sqrt(num_tokens), num_classes) 108 | H = int(math.sqrt(hw)) 109 | W = int(math.sqrt(hw)) 110 | 
vis_map = attention_map.view(num_heads, H, W, -1) 111 | 112 | # Split into unconditional and conditional attention maps 113 | uncond_attn_map, cond_attn_map = torch.chunk(vis_map, 2, dim=0) 114 | 115 | # Mean over heads [h, w, num_classes] 116 | cond_attn_map = cond_attn_map.mean(0) 117 | uncond_attn_map = uncond_attn_map.mean(0) 118 | 119 | # Plot and save attention maps 120 | fig, ax = plt.subplots(1, 10, figsize=(20, 2)) 121 | for j in range(10): 122 | attn_slice = cond_attn_map[:, :, j].unsqueeze(-1).cpu().numpy() 123 | ax[j].imshow(attn_slice) 124 | ax[j].axis("off") 125 | 126 | # Save the plot 127 | save_path = os.path.join(path, f"attn_map_{i:03d}.jpg") 128 | plt.tight_layout() 129 | plt.savefig(save_path) 130 | plt.close(fig) 131 | 132 | print(f"Saved attention map at: {save_path}") 133 | 134 | 135 | def synthesize_images( 136 | model, 137 | strength, 138 | dataset="cub", 139 | finetuned_ckpt="db_ti_latest", 140 | source_label=1, 141 | target_label=2, 142 | source_image=None, 143 | seed=0, 144 | hook_target_name: str = "up_blocks.2.attentions.1.transformer_blocks.0.attn2", 145 | ): 146 | 147 | random.seed(seed) 148 | train_dataset = DATASET_NAME_MAPPING[dataset](split="train") 149 | target_indice = random.sample(train_dataset.label_to_indices[target_label], 1)[0] 150 | 151 | if source_image is None: 152 | # source_indice = random.sample(train_dataset.label_to_indices[source_label], 1)[0] 153 | source_indice = train_dataset.label_to_indices[source_label][0] 154 | source_image = train_dataset.get_image_by_idx(source_indice) 155 | target_metadata = train_dataset.get_metadata_by_idx(target_indice) 156 | with AttentionVisualizer(model, hook_target_name) as visualizer: 157 | image, _ = model( 158 | image=[source_image], 159 | label=target_label, 160 | strength=strength, 161 | metadata=target_metadata, 162 | ) 163 | attn_map = visualizer.activation[hook_target_name] 164 | path = os.path.join("figures/attn_map/", dataset, finetuned_ckpt) 165 | plot_attn_map(attn_map, path=path) 166 | return image 167 | 168 | 169 | if __name__ == "__main__": 170 | dataset_list = ["cub"] 171 | aug = "diff-mix" #'diff-aug/mixup" "real-mix" 172 | finetuned_ckpt = "db_latest" 173 | guidance_scale = 7 174 | prompt = "a photo of a {name}" 175 | 176 | for dataset in dataset_list: 177 | lora_path, embed_path = finetuned_ckpt_dir( 178 | dataset=dataset, finetuned_ckpt=finetuned_ckpt 179 | ) 180 | AUGMENT_METHODS[aug].pipe = None 181 | model = AUGMENT_METHODS[aug]( 182 | embed_path=embed_path, 183 | lora_path=lora_path, 184 | prompt=prompt, 185 | guidance_scale=guidance_scale, 186 | mask=False, 187 | inverted=False, 188 | device=f"cuda:1", 189 | ) 190 | image = synthesize_images( 191 | model, 192 | 0.5, 193 | dataset, 194 | finetuned_ckpt=finetuned_ckpt, 195 | source_label=13, 196 | target_label=2, 197 | source_image=None, 198 | seed=0, 199 | ) 200 | -------------------------------------------------------------------------------- /visualization/visualize_cases.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | 5 | os.environ["DISABLE_TELEMETRY"] = "YES" 6 | sys.path.append("../") 7 | 8 | from augmentation import AUGMENT_METHODS 9 | from dataset import DATASET_NAME_MAPPING 10 | from utils.misc import finetuned_ckpt_dir 11 | from utils.visualization import visualize_images 12 | 13 | 14 | def synthesize_images( 15 | model, strength, train_dataset, source_label=1, target_label=2, source_image=None 16 | ): 17 | num = 1 18 | 
random.seed(seed) 19 | target_indice = random.sample(train_dataset.label_to_indices[target_label], 1)[0] 20 | 21 | if source_image is None: 22 | source_indice = random.sample(train_dataset.label_to_indices[source_label], 1)[ 23 | 0 24 | ] 25 | source_image = train_dataset.get_image_by_idx(source_indice) 26 | target_metadata = train_dataset.get_metadata_by_idx(target_indice) 27 | image_list = [] 28 | image, _ = model( 29 | image=[source_image], 30 | label=target_label, 31 | strength=strength, 32 | metadata=target_metadata, 33 | ) 34 | return image 35 | 36 | 37 | if __name__ == "__main__": 38 | device = "cuda:1" 39 | dataset = "pascal" 40 | aug = "diff-mix" #'diff-aug/mixup" "real-mix" 41 | finetuned_ckpt = "db_latest_5shot" 42 | guidance_scale = 7 43 | strength_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0] 44 | 45 | seed = 0 46 | random.seed(seed) 47 | source_label = 5 48 | for target_label in [4, 7, 6]: 49 | for dataset in ["pascal"]: 50 | for aug in ["diff-mix"]: 51 | train_dataset = DATASET_NAME_MAPPING[dataset]( 52 | split="train", examples_per_class=5 53 | ) 54 | lora_path, embed_path = finetuned_ckpt_dir( 55 | dataset=dataset, finetuned_ckpt=finetuned_ckpt 56 | ) 57 | 58 | AUGMENT_METHODS[aug].pipe = None 59 | model = AUGMENT_METHODS[aug]( 60 | embed_path=embed_path, 61 | lora_path=lora_path, 62 | prompt="a photo of a {name}", 63 | guidance_scale=guidance_scale, 64 | mask=False, 65 | inverted=False, 66 | device=device, 67 | ) 68 | 69 | image_list = [] 70 | for strength in strength_list: 71 | source_image = train_dataset.get_image_by_idx( 72 | train_dataset.label_to_indices[source_label][3] 73 | ) 74 | image_list.append( 75 | synthesize_images( 76 | model, 77 | strength, 78 | train_dataset, 79 | source_label=source_label, 80 | target_label=target_label, 81 | source_image=source_image, 82 | )[0] 83 | ) 84 | 85 | outpath = ( 86 | f"figures/cases/{dataset}/{aug}_{source_label}_{target_label}.png" 87 | ) 88 | visualize_images( 89 | image_list, nrow=6, show=False, save=True, outpath=outpath 90 | ) 91 | -------------------------------------------------------------------------------- /visualization/visualize_filtered_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | 5 | os.environ["DISABLE_TELEMETRY"] = "YES" 6 | sys.path.append("../") 7 | 8 | from augmentation import AUGMENT_METHODS, load_embeddings 9 | from dataset import DATASET_NAME_MAPPING, IMBALANCE_DATASET_NAME_MAPPING 10 | from dataset.base import SyntheticDataset 11 | from utils.misc import parse_synthetic_dir 12 | from utils.visualization import visualize_images 13 | 14 | if __name__ == "__main__": 15 | device = "cuda:1" 16 | dataset = "aircraft" 17 | csv_file = "meta_90-100per.csv" 18 | csv_file = "meta_0-10-per.csv" 19 | row = 5 20 | column = 2 21 | synthetic_type = "mixup_uniform" 22 | synthetic_dir = parse_synthetic_dir(dataset, synthetic_type=synthetic_type) 23 | 24 | for csv_file in ["meta_0-10per.csv", "meta_90-100per.csv"]: 25 | ds = SyntheticDataset(synthetic_dir, csv_file=csv_file) 26 | num = row * column 27 | indices = random.sample(range(len(ds)), num) 28 | 29 | image_list = [] 30 | for index in indices: 31 | image_list.append(ds.get_image_by_idx(index).resize((224, 224))) 32 | 33 | outpath = f"figures/cases_filtered/{dataset}/{synthetic_type}_{csv_file}.png" 34 | visualize_images(image_list, nrow=row, show=False, save=True, outpath=outpath) 35 | --------------------------------------------------------------------------------