├── .gitignore
├── .gitmodules
├── .vscode
│   ├── launch.json
│   └── settings.json
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── augmentation
│   ├── __init__.py
│   ├── base_augmentation.py
│   ├── diff_mix.py
│   ├── real_mix.py
│   └── ti_mix.py
├── dataset
│   ├── __init__.py
│   ├── base.py
│   ├── instance
│   │   ├── aircraft.py
│   │   ├── car.py
│   │   ├── cub.py
│   │   ├── dog.py
│   │   ├── flower.py
│   │   ├── food.py
│   │   ├── pascal.py
│   │   ├── pet.py
│   │   └── waterbird.py
│   └── template.py
├── downstream_tasks
│   ├── imb_utils
│   │   ├── __init__.py
│   │   ├── autoaug.py
│   │   ├── moco_loader.py
│   │   ├── randaugment.py
│   │   └── util.py
│   ├── losses.py
│   ├── mixup.py
│   ├── train_hub.py
│   ├── train_hub_imb.py
│   └── train_hub_waterbird.py
├── outputs
├── requirement.txt
├── scripts
│   ├── classification.sh
│   ├── classification_imb.sh
│   ├── classification_waterbird.sh
│   ├── compose_syn_data.py
│   ├── filter_syn_data.py
│   ├── finetune.sh
│   ├── finetune_imb.sh
│   ├── sample.sh
│   ├── sample_imb.sh
│   └── sample_mp.py
├── train_lora.py
├── utils
│   ├── misc.py
│   ├── network.py
│   └── visualization.py
└── visualization
    ├── visualize_attn_map.py
    ├── visualize_cases.py
    └── visualize_filtered_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | debug/
7 | src/
8 | notebook/
9 | figures/
10 | ckpts/*
11 | aug_samples/*
12 | _others/
13 | pyrightconfig.json
14 | # images
15 | *.jpg
16 | *.pdf
17 |
18 | # experiment data
19 | *.out
20 | *.csv
21 | *.pt
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | pip-wheel-metadata/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | .vscode/*
59 | htmlcov/
60 | .tox/
61 | .nox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | *.py,cover
69 | .hypothesis/
70 | .pytest_cache/
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 | db.sqlite3-journal
81 |
82 | # Flask stuff:
83 | .webassets-cache
84 |
85 | # Scrapy stuff:
86 | .scrapy
87 |
88 | # Sphinx documentation
89 | docs/_build/
90 |
91 | # PyBuilder
92 | target/
93 |
94 | # Jupyter Notebook
95 | .ipynb_checkpoints
96 |
97 | # IPython
98 | profile_default/
99 | ipython_config.py
100 |
101 | # pyenv
102 | .python-version
103 |
104 | # pipenv
105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
108 | # install all needed dependencies.
109 | #Pipfile.lock
110 |
111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
112 | __pypackages__/
113 |
114 | # Celery stuff
115 | celerybeat-schedule
116 | celerybeat.pid
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .env
123 | .venv
124 | env/
125 | venv/
126 | ENV/
127 | env.bak/
128 | venv.bak/
129 |
130 | # Spyder project settings
131 | .spyderproject
132 | .spyproject
133 |
134 | # Rope project settings
135 | .ropeproject
136 |
137 | # mkdocs documentation
138 | /site
139 |
140 | # mypy
141 | .mypy_cache/
142 | .dmypy.json
143 | dmypy.json
144 |
145 | # Pyre type checker
146 | .pyre/
147 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "stable-diffusion"]
2 | path = stable-diffusion
3 | url = https://github.com/CompVis/stable-diffusion.git
4 | [submodule "erasing"]
5 | path = erasing
6 | url = https://github.com/brandontrabucco/erasing.git
7 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 |
8 | {
9 | "name": "Python: Finetune Script",
10 | "type": "python",
11 | "request": "launch",
12 | "module": "accelerate.commands.launch",
13 | "console": "integratedTerminal",
14 | "env":{"CUDA_VISIBLE_DEVICES": "0,1"},
15 | "args": [
16 | "train_lora.py",
17 | "--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5",
18 |                 "--dataset_name=cub", // replace with the actual dataset name
19 | "--resolution=512", //
20 | "--random_flip",
21 | "--max_train_steps=35000",
22 | "--num_train_epochs=10", //
23 | "--checkpointing_steps=5000",
24 | "--learning_rate=5e-05",
25 | "--lr_scheduler=constant",
26 | "--lr_warmup_steps=0",
27 | "--seed=42",
28 | "--rank=10",
29 | "--local_files_only",
30 | "--examples_per_class=5",
31 | "--train_batch_size=1",
32 |                 "--output_dir=/data/zhicai/code/Diff-Mix/outputs/finetune_model/finetune_ti_db/sd-cub-5shot-model-lora-rank10", // replace with the actual output directory
33 | "--report_to=tensorboard"
34 | ]
35 | },
36 | {
37 | "name": "Python: Finetune IMB Script",
38 | "type": "python",
39 | "request": "launch",
40 | "module": "accelerate.commands.launch",
41 | "console": "integratedTerminal",
42 | "env":{"CUDA_VISIBLE_DEVICES": "4"},
43 | "args": [
44 | "train_lora.py",
45 | "--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5",
46 |                 "--dataset_name=cub", // replace with the actual dataset name
47 | "--resolution=512", //
48 | "--random_flip",
49 | "--max_train_steps=35000",
50 | "--num_train_epochs=10", //
51 | "--checkpointing_steps=5000",
52 | "--learning_rate=5e-05",
53 | "--lr_scheduler=constant",
54 | "--lr_warmup_steps=0",
55 | "--seed=42",
56 | "--rank=10",
57 | "--local_files_only",
58 | "--train_batch_size=1",
59 |                 "--output_dir=ckpts/cub/IMB0.01_lora_rank10", // replace with the actual output directory
60 | "--report_to=tensorboard",
61 | "--task=imbalanced",
62 | "--imbalance_factor=0.01"
63 |
64 | ]
65 | },
66 | {
67 | "name": "Python: Sample MP",
68 | "type": "python",
69 | "request": "launch",
70 | "program": "${workspaceFolder}/scripts/sample_mp.py",
71 | "args": [
72 | "--output_root", "aug_samples",
73 | "--finetuned_ckpt", "ckpts/cub/shot-1-lora-rank10",
74 | "--dataset", "cub",
75 | "--syn_dataset_mulitiplier", "5",
76 | "--strength_strategy", "fixed",
77 | "--resolution", "512",
78 | "--batch_size", "1",
79 | "--aug_strength", "0.7",
80 | "--model_path", "runwayml/stable-diffusion-v1-5",
81 | "--sample_strategy", "diff-mix",
82 | "--gpu_ids", "3", "4",
83 | ],
84 | "console": "integratedTerminal"
85 | },
86 | {
87 | "name": "Python: Sample MP IMB",
88 | "type": "python",
89 | "request": "launch",
90 | "program": "${workspaceFolder}/scripts/sample_mp.py",
91 | "args": [
92 | "--output_root", "aug_samples",
93 | "--finetuned_ckpt", "ckpts/cub/shot-1-lora-rank10",
94 | "--dataset", "cub",
95 | "--syn_dataset_mulitiplier", "5",
96 | "--strength_strategy", "fixed",
97 | "--resolution", "512",
98 | "--batch_size", "1",
99 | "--aug_strength", "0.7",
100 | "--model_path", "runwayml/stable-diffusion-v1-5",
101 | "--sample_strategy", "diff-mix",
102 | "--gpu_ids", "3", "4",
103 | "--task", "imbalanced",
104 | "--imbalance_factor", "0.01"
105 | ],
106 | "console": "integratedTerminal"
107 | },
108 | {
109 | "name": "Python: Train Hub",
110 | "type": "python",
111 | "request": "launch",
112 | "program": "${workspaceFolder}/downstream_tasks/train_hub.py",
113 | "args": [
114 | "--dataset", "cub",
115 | "--syndata_p", "0.1",
116 | "--syndata_dir", "outputs/aug_samples/cub/dreambooth-lora-mixup-Multi7-db_ti10000-Strength0.5",
117 | "--model", "resnet50",
118 | "--gamma", "0.8",
119 | "--examples_per_class", "-1",
120 | "--gpu", "1",
121 | "--amp", "2",
122 | "--note", "${env:DATE}",
123 | "--group_note", "test",
124 | "--nepoch", "60",
125 | "--res_mode", "224",
126 | "--lr", "0.05",
127 | "--seed", "0",
128 | "--weight_decay", "0.0005"
129 | ],
130 | "console": "integratedTerminal",
131 | "env": {
132 | "DATE": "${command:python.interpreterPath} -c \"import datetime; print(datetime.datetime.now().strftime('%m%d%H%M'))\""
133 | }
134 | },
135 | {
136 | "name": "Python: Train Hub Waterbird",
137 | "type": "python",
138 | "request": "launch",
139 | "program": "${workspaceFolder}/downstream_tasks/train_hub_waterbird.py",
140 | "args": [
141 | "--dataset", "cub",
142 | "--syndata_p", "0.1",
143 | "--syndata_dir", "outputs/aug_samples/cub/dreambooth-lora-mixup-Multi7-db_ti10000-Strength0.5",
144 | "--model", "resnet50",
145 | "--gamma", "0.8",
146 | "--examples_per_class", "-1",
147 | "--gpu", "4",
148 | "--amp", "2",
149 | "--note", "${env:DATE}",
150 | "--group_note", "test",
151 | "--nepoch", "60",
152 | "--res_mode", "224",
153 | "--lr", "0.05",
154 | "--seed", "0",
155 | "--weight_decay", "0.0005"
156 | ],
157 | "console": "integratedTerminal",
158 | "env": {
159 | "DATE": "${command:python.interpreterPath} -c \"import datetime; print(datetime.datetime.now().strftime('%m%d%H%M'))\""
160 | }
161 | },
162 | {
163 | "name": "Python: Train Hub Imb",
164 | "type": "python",
165 | "request": "launch",
166 | "program": "${workspaceFolder}/downstream_tasks/train_hub_imb.py",
167 | "args": [
168 | "--dataset", "cub",
169 | "--loss_type", "CE",
170 | "--lr", "0.005",
171 | "--epochs", "200",
172 | "--imb_factor", "0.01",
173 | "-b", "128",
174 | "--gpu", "4",
175 | "--data_aug", "vanilla",
176 | "--root_log", "outputs/results_cmo",
177 | "--syndata_dir", "outputs/aug_samples_imbalance/cub/dreambooth-lora-mixup-Multi10-db_ti_latest_imb_0.1-Strength0.7",
178 | "--syndata_p", "0.5",
179 | "--gamma", "0.8",
180 | "--use_weighted_syn"
181 | ],
182 | "console": "integratedTerminal",
183 | }
184 | ]
185 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.typeCheckingMode": "off"
3 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Zhicai Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Enhance Image Classification Via Inter-Class Image Mixup With Diffusion Model
5 |
6 |
7 |
8 |
9 |
10 | ![teaser](assets/teaser.png)
11 |
12 |
13 | ## Introduction 👋
14 | This repository implements various **generative data augmentation** strategies with Stable Diffusion to create synthetic datasets aimed at enhancing image classification.
15 | ## Requirements
16 | The key packages and their versions are listed below. The code is tested on a single node with 4 NVIDIA RTX3090 GPUs.
17 | ```
18 | torch==2.0.1+cu118
19 | diffusers==0.25.1
20 | transformers==4.36.2
21 | datasets==2.16.1
22 | accelerate==0.26.1
23 | numpy==1.24.4
24 | ```
25 | ## Datasets
26 | For convenience, well-structured datasets hosted on Hugging Face can be used. The fine-grained datasets `CUB` and `Aircraft` we experimented with can be downloaded from [Multimodal-Fatima/CUB_train](https://huggingface.co/datasets/Multimodal-Fatima/CUB_train) and [Multimodal-Fatima/FGVC_Aircraft_train](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_train), respectively. If you run into network connection problems during training, pre-download the data from the website and set the saved local path `HUG_LOCAL_IMAGE_TRAIN_DIR` in `dataset/instance/cub.py`.
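As a minimal sketch (not part of the original README), you can pre-download a split while the network is available, assuming the standard Hugging Face cache layout, and then point `HUG_LOCAL_IMAGE_TRAIN_DIR` at the cached copy:

```python
# Hypothetical pre-download step; run once on a machine with network access.
from datasets import load_dataset

train_ds = load_dataset("Multimodal-Fatima/CUB_train", split="train")

# Print where the cached parquet shards live, then set
# HUG_LOCAL_IMAGE_TRAIN_DIR in dataset/instance/cub.py to that directory.
print(train_ds.cache_files[0]["filename"])
```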
27 |
28 |
29 |
30 | ## Fine-tune on a dataset 🔥
31 | ### Pre-trained lora weights
32 | We provide LoRA weights fine-tuned on the full dataset for fast reproduction on the given datasets. Download them using the links below and unzip the file into the `ckpts` directory so that the file structure looks like:
33 |
34 | ```
35 | ckpts
36 | ├── cub
37 | │ └── shot-1-lora-rank10
38 | │       ├── learned_embeds-steps-last.bin
39 | │ └── pytorch_lora_weights.safetensors
40 | └── put_finetuned_ckpts_here.txt
41 | ```
42 |
43 | | Dataset | data | ckpts (fullshot) |
44 | |---------|---------------------------------------------------------------------|---------------------------------------------------------------------|
45 | | CUB | huggingface ([train](https://huggingface.co/datasets/Multimodal-Fatima/CUB_train)/[test](https://huggingface.co/datasets/Multimodal-Fatima/CUB_test))| [google drive](https://drive.google.com/file/d/1AOX4TcXSPGRSmxSgB08L8P-28c5TPkxw/view?usp=sharing) |
46 | | Flower | [official website ](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) | [google drive](https://drive.google.com/file/d/1hBodBaLb_GokxfMXvQyhr4OGzyBgyBm0/view?usp=sharing) |
47 | | Aircraft | huggingface ([train](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_train)/[test](https://huggingface.co/datasets/Multimodal-Fatima/FGVC_Aircraft_test)) | [google drive](https://drive.google.com/file/d/19PuRbIsurv1IKeu-jx5WieocMy5rfIKg/view?usp=sharing) |
48 |
49 |
50 |
51 | ### Customized fine-tuning
52 | The `scripts/finetune.sh` script allows users to perform fine-tuning on their own datasets. By default, it implements a fine-tuning strategy combining `DreamBooth` and `Textual Inversion`. Users can set the `examples_per_class` argument to fine-tune the model in an {examples_per_class}-shot setting. The tuning process takes around 4 hours on 4 RTX3090 GPUs for the full-shot CUB dataset.
53 |
54 | ```
55 | MODEL_NAME="runwayml/stable-diffusion-v1-5"
56 | DATASET='cub'
57 | SHOT=-1 # set -1 for full shot
58 | OUTPUT_DIR="ckpts/${DATASET}/shot${SHOT}_lora_rank10"
59 |
60 | accelerate launch --mixed_precision='fp16' --main_process_port 29507 \
61 | train_lora.py \
62 | --pretrained_model_name_or_path=$MODEL_NAME \
63 | --dataset_name=$DATASET \
64 | --resolution=224 \
65 | --random_flip \
66 | --max_train_steps=35000 \
67 | --num_train_epochs=10 \
68 | --checkpointing_steps=5000 \
69 | --learning_rate=5e-05 \
70 | --lr_scheduler='constant' \
71 | --lr_warmup_steps=0 \
72 | --seed=42 \
73 | --rank=10 \
74 | --local_files_only \
75 | --examples_per_class $SHOT \
76 | --train_batch_size 2 \
77 | --output_dir=$OUTPUT_DIR \
78 |   --report_to='tensorboard'
79 | ```
80 |
81 | ## Construct synthetic data
82 | `scripts/sample.sh` provides a script to synthesize augmented images in a multi-process fashion. Each item in `GPU_IDS` denotes a process running on the indexed GPU. A simplified command for sampling a $5\times$ synthetic subset in an inter-class translation manner (`diff-mix`) with strength $s=0.7$ is:
83 |
84 | ```bash
85 | DATASET='cub'
86 | # set -1 for full shot
87 | SHOT=-1
88 | FINETUNED_CKPT="ckpts/cub/shot${SHOT}-lora-rank10"
89 | # ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug']
90 | SAMPLE_STRATEGY='diff-mix'
91 | STRENGTH=0.7
92 | # ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9]
93 | STRENGTH_STRATEGY='fixed'
94 | # expand the dataset by 5 times
95 | MULTIPLIER=5
96 | # spawn 4 processes
97 | GPU_IDS=(0 1 2 3)
98 |
99 | python scripts/sample_mp.py \
100 |     --model_path='runwayml/stable-diffusion-v1-5' \
101 | --output_root='outputs/aug_samples' \
102 | --dataset=$DATASET \
103 | --finetuned_ckpt=$FINETUNED_CKPT \
104 | --syn_dataset_mulitiplier=$MULTIPLIER \
105 | --strength_strategy=$STRENGTH_STRATEGY \
106 | --sample_strategy=$SAMPLE_STRATEGY \
107 | --examples_per_class=$SHOT \
108 | --resolution=512 \
109 | --batch_size=1 \
110 |     --aug_strength=$STRENGTH \
111 |     --gpu_ids ${GPU_IDS[@]}
112 | ```
113 | The synthetic data will be located at `aug_samples/cub/diff-mix_-1_fixed_0.7`. To create a 5-shot setting, set the `examples_per_class` argument to 5; the output directory will then be `aug_samples/cub/diff-mix_5_fixed_0.7`. Please ensure that the `finetuned_ckpt` was also fine-tuned under the same 5-shot setting.
114 |
115 |
116 | ## Downstream classification
117 | After completing the sampling process, you can integrate the synthetic data into downstream classification and initiate training using the script `scripts/classification.sh`:
118 | ```
119 | GPU=1
120 | DATASET="cub"
121 | SHOT=-1
122 | # "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
123 | SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot
124 | SYNDATA_P=0.1
125 | GAMMA=0.8
126 |
127 | python downstream_tasks/train_hub.py \
128 | --dataset $DATASET \
129 | --syndata_dir $SYNDATA_DIR \
130 | --syndata_p $SYNDATA_P \
131 | --model "resnet50" \
132 | --gamma $GAMMA \
133 | --examples_per_class $SHOT \
134 | --gpu $GPU \
135 | --amp 2 \
136 | --note $(date +%m%d%H%M) \
137 | --group_note "fullshot" \
138 | --nepoch 120 \
139 | --res_mode 224 \
140 | --lr 0.05 \
141 | --seed 0 \
142 | --weight_decay 0.0005
143 | ```
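For reference, `GAMMA` controls how the soft label of each synthetic image is split between its source and target classes. Below is a minimal sketch of the labeling rule, mirroring `HugFewShotDataset.get_syn_item` in `dataset/base.py` (the class indices and strength here are illustrative):

```python
import math
import torch

def soft_label(num_classes, src_label, tar_label, strength, gamma=0.8, soft_scaler=1.0):
    # The target class receives soft_scaler * strength**gamma; the source class
    # keeps the remaining soft_scaler * (1 - strength**gamma).
    label = torch.zeros(num_classes)
    label[src_label] += soft_scaler * (1 - math.pow(strength, gamma))
    label[tar_label] += soft_scaler * math.pow(strength, gamma)
    return label

# e.g. a CUB sample translated from class 3 to class 17 with strength 0.7
print(soft_label(num_classes=200, src_label=3, tar_label=17, strength=0.7))
```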
144 |
145 | We also provide scripts for the robustness test and long-tail classification in `scripts/classification_waterbird.sh` and `scripts/classification_imb.sh`, respectively.
146 | ## Acknowledgements
147 |
148 | This project is built upon the repositories [Da-fusion](https://github.com/brandontrabucco/da-fusion) and [diffusers](https://github.com/huggingface/diffusers). Special thanks to the contributors.
149 |
150 |
151 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhicaiwww/Diff-Mix/a81337b1492bcc7f7a8a61921836e94191a3a0ef/assets/teaser.png
--------------------------------------------------------------------------------
/augmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from .diff_mix import *
2 | from .real_mix import *
3 | from .ti_mix import *
4 |
5 | AUGMENT_METHODS = {
6 | "ti-mix": TextualInversionMixup,
7 | "ti_aug": TextualInversionMixup,
8 | "real-aug": DreamboothLoraMixup,
9 | "real-mix": DreamboothLoraMixup,
10 | "real-gen": RealGeneration,
11 | "diff-mix": DreamboothLoraMixup,
12 | "diff-aug": DreamboothLoraMixup,
13 | "diff-gen": DreamboothLoraGeneration,
14 | }
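# Note (added): the sampling entry point (scripts/sample_mp.py) presumably selects
# the augmentation class from this mapping via its --sample_strategy argument,
# e.g. "diff-mix" resolves to DreamboothLoraMixup, which is then built from the
# fine-tuned LoRA and textual-inversion checkpoints.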
15 |
--------------------------------------------------------------------------------
/augmentation/base_augmentation.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Any, Tuple
3 |
4 | import torch.nn as nn
5 | from PIL import Image
6 |
7 |
8 | class GenerativeAugmentation(nn.Module, abc.ABC):
9 |
10 | @abc.abstractmethod
11 | def forward(
12 | self, image: Image.Image, label: int, metadata: dict
13 | ) -> Tuple[Image.Image, int]:
14 |
15 | return NotImplemented
16 |
17 |
18 | class GenerativeMixup(nn.Module, abc.ABC):
19 |
20 | @abc.abstractmethod
21 | def forward(
22 | self, image: Image.Image, label: int, metadata: dict, strength: float
23 | ) -> Tuple[Image.Image, int]:
24 |
25 | return NotImplemented
26 |
--------------------------------------------------------------------------------
/augmentation/diff_mix.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable, Tuple
3 |
4 | import torch
5 | from PIL import Image
6 | from torch import autocast
7 | from transformers import CLIPTextModel, CLIPTokenizer
8 |
9 | from augmentation.base_augmentation import GenerativeMixup
10 | from diffusers import (
11 | DPMSolverMultistepScheduler,
12 | StableDiffusionImg2ImgPipeline,
13 | StableDiffusionPipeline,
14 | )
15 | from diffusers.utils import logging
16 |
17 | os.environ["WANDB_DISABLED"] = "true"
18 | ERROR_MESSAGE = "Tokenizer already contains the token {token}. \
19 | Please pass a different `token` that is not already in the tokenizer."
20 |
21 |
22 | def format_name(name):
23 | return f"<{name.replace(' ', '_')}>"
24 |
25 |
26 | def load_diffmix_embeddings(
27 | embed_path: str,
28 | text_encoder: CLIPTextModel,
29 | tokenizer: CLIPTokenizer,
30 | device="cuda",
31 | ):
32 |
33 | embedding_ckpt = torch.load(embed_path, map_location="cpu")
34 | learned_embeds_dict = embedding_ckpt["learned_embeds_dict"]
35 | name2placeholder = embedding_ckpt["name2placeholder"]
36 | placeholder2name = embedding_ckpt["placeholder2name"]
37 |
38 | name2placeholder = {
39 | k.replace("/", " ").replace("_", " "): v for k, v in name2placeholder.items()
40 | }
41 | placeholder2name = {
42 | v: k.replace("/", " ").replace("_", " ") for k, v in name2placeholder.items()
43 | }
44 |
45 | for token, token_embedding in learned_embeds_dict.items():
46 |
47 | # add the token in tokenizer
48 | num_added_tokens = tokenizer.add_tokens(token)
49 | assert num_added_tokens > 0, ERROR_MESSAGE.format(token=token)
50 |
51 | # resize the token embeddings
52 | text_encoder.resize_token_embeddings(len(tokenizer))
53 | added_token_id = tokenizer.convert_tokens_to_ids(token)
54 |
55 | # get the old word embeddings
56 | embeddings = text_encoder.get_input_embeddings()
57 |
58 | # get the id for the token and assign new embeds
59 | embeddings.weight.data[added_token_id] = token_embedding.to(
60 | embeddings.weight.dtype
61 | )
62 |
63 | return name2placeholder, placeholder2name
64 |
65 |
66 | def identity(*args):
67 | return args
68 |
69 |
70 | class IdentityMap:
71 | def __getitem__(self, key):
72 | return key
73 |
74 |
75 | class DreamboothLoraMixup(GenerativeMixup):
76 |
77 | pipe = None # global sharing is a hack to avoid OOM
78 |
79 | def __init__(
80 | self,
81 | lora_path: str,
82 | model_path: str = "runwayml/stable-diffusion-v1-5",
83 | embed_path: str = None,
84 | prompt: str = "a photo of a {name}",
85 | format_name: Callable = format_name,
86 | guidance_scale: float = 7.5,
87 | disable_safety_checker: bool = True,
88 | revision: str = None,
89 | device="cuda",
90 | **kwargs,
91 | ):
92 |
93 | super(DreamboothLoraMixup, self).__init__()
94 |
95 | if DreamboothLoraMixup.pipe is None:
96 |
97 | PipelineClass = StableDiffusionImg2ImgPipeline
98 |
99 | DreamboothLoraMixup.pipe = PipelineClass.from_pretrained(
100 | model_path,
101 | use_auth_token=True,
102 | revision=revision,
103 | local_files_only=True,
104 | torch_dtype=torch.float16,
105 | ).to(device)
106 |
107 | scheduler = DPMSolverMultistepScheduler.from_config(
108 | DreamboothLoraMixup.pipe.scheduler.config, local_files_only=True
109 | )
110 | self.placeholder2name = {}
111 | self.name2placeholder = {}
112 | if embed_path is not None:
113 | self.name2placeholder, self.placeholder2name = load_diffmix_embeddings(
114 | embed_path,
115 | DreamboothLoraMixup.pipe.text_encoder,
116 | DreamboothLoraMixup.pipe.tokenizer,
117 | )
118 | if lora_path is not None:
119 | DreamboothLoraMixup.pipe.load_lora_weights(lora_path)
120 | DreamboothLoraMixup.pipe.scheduler = scheduler
121 |
122 |             print(f"successfully loaded LoRA weights from {lora_path}!")
123 |
124 | logging.disable_progress_bar()
125 | self.pipe.set_progress_bar_config(disable=True)
126 |
127 | if disable_safety_checker:
128 | self.pipe.safety_checker = None
129 |
130 | self.prompt = prompt
131 | self.guidance_scale = guidance_scale
132 | self.format_name = format_name
133 |
134 | def forward(
135 | self,
136 | image: Image.Image,
137 | label: int,
138 | metadata: dict,
139 | strength: float = 0.5,
140 | resolution=512,
141 | ) -> Tuple[Image.Image, int]:
142 |
143 | canvas = [img.resize((resolution, resolution), Image.BILINEAR) for img in image]
144 | name = metadata.get("name", "")
145 |
146 | if self.name2placeholder is not None:
147 | name = self.name2placeholder[name]
148 | if metadata.get("super_class", None) is not None:
149 | name = name + " " + metadata.get("super_class", "")
150 | prompt = self.prompt.format(name=name)
151 |
152 | print(prompt)
153 |
154 | kwargs = dict(
155 | image=canvas,
156 | prompt=[prompt],
157 | strength=strength,
158 | guidance_scale=self.guidance_scale,
159 | num_inference_steps=25,
160 | num_images_per_prompt=len(canvas),
161 | )
162 |
163 | has_nsfw_concept = True
164 | while has_nsfw_concept:
165 | with autocast("cuda"):
166 | outputs = self.pipe(**kwargs)
167 |
168 | has_nsfw_concept = (
169 | self.pipe.safety_checker is not None
170 | and outputs.nsfw_content_detected[0]
171 | )
172 | canvas = []
173 | for orig, out in zip(image, outputs.images):
174 | canvas.append(out.resize(orig.size, Image.BILINEAR))
175 | return canvas, label
176 |
177 |
178 | class DreamboothLoraGeneration(GenerativeMixup):
179 |
180 | pipe = None # global sharing is a hack to avoid OOM
181 |
182 | def __init__(
183 | self,
184 | lora_path: str,
185 | model_path: str = "runwayml/stable-diffusion-v1-5",
186 | embed_path: str = None,
187 | prompt: str = "a photo of a {name}",
188 | format_name: Callable = format_name,
189 | guidance_scale: float = 7.5,
190 | disable_safety_checker: bool = True,
191 | revision: str = None,
192 | device="cuda",
193 | **kwargs,
194 | ):
195 |
196 | super(DreamboothLoraGeneration, self).__init__()
197 |
198 | if DreamboothLoraGeneration.pipe is None:
199 |
200 | PipelineClass = StableDiffusionPipeline
201 |
202 | DreamboothLoraGeneration.pipe = PipelineClass.from_pretrained(
203 | model_path,
204 | use_auth_token=True,
205 | revision=revision,
206 | local_files_only=True,
207 | torch_dtype=torch.float16,
208 | ).to(device)
209 |
210 | scheduler = DPMSolverMultistepScheduler.from_config(
211 | DreamboothLoraGeneration.pipe.scheduler.config, local_files_only=True
212 | )
213 | self.placeholder2name = None
214 | self.name2placeholder = None
215 | if embed_path is not None:
216 | self.name2placeholder, self.placeholder2name = load_diffmix_embeddings(
217 | embed_path,
218 | DreamboothLoraGeneration.pipe.text_encoder,
219 | DreamboothLoraGeneration.pipe.tokenizer,
220 | )
221 | if lora_path is not None:
222 | DreamboothLoraGeneration.pipe.load_lora_weights(lora_path)
223 | DreamboothLoraGeneration.pipe.scheduler = scheduler
224 |
225 |             print(f"successfully loaded LoRA weights from {lora_path}!")
226 |
227 | logging.disable_progress_bar()
228 | self.pipe.set_progress_bar_config(disable=True)
229 |
230 | if disable_safety_checker:
231 | self.pipe.safety_checker = None
232 |
233 | self.prompt = prompt
234 | self.guidance_scale = guidance_scale
235 | self.format_name = format_name
236 |
237 | def forward(
238 | self,
239 | image: Image.Image,
240 | label: int,
241 | metadata: dict,
242 | strength: float = 0.5,
243 | resolution=512,
244 | ) -> Tuple[Image.Image, int]:
245 |
246 | name = metadata.get("name", "")
247 |
248 | if self.name2placeholder is not None:
249 | name = self.name2placeholder[name]
250 | if metadata.get("super_class", None) is not None:
251 | name = name + " " + metadata.get("super_class", "")
252 | prompt = self.prompt.format(name=name)
253 |
254 | print(prompt)
255 |
256 | kwargs = dict(
257 | prompt=[prompt],
258 | guidance_scale=self.guidance_scale,
259 | num_inference_steps=25,
260 | num_images_per_prompt=len(image),
261 | height=resolution,
262 | width=resolution,
263 | )
264 |
265 | has_nsfw_concept = True
266 | while has_nsfw_concept:
267 | with autocast("cuda"):
268 | outputs = self.pipe(**kwargs)
269 |
270 | has_nsfw_concept = (
271 | self.pipe.safety_checker is not None
272 | and outputs.nsfw_content_detected[0]
273 | )
274 |
275 | canvas = []
276 | for out in outputs.images:
277 | canvas.append(out)
278 | return canvas, label
279 |
--------------------------------------------------------------------------------
/augmentation/real_mix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Any, Callable, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image, ImageOps
8 | from scipy.ndimage import maximum_filter
9 | from torch import autocast
10 |
11 | from augmentation.base_augmentation import GenerativeAugmentation, GenerativeMixup
12 | from diffusers import (
13 | DPMSolverMultistepScheduler,
14 | StableDiffusionImg2ImgPipeline,
15 | StableDiffusionInpaintPipeline,
16 | StableDiffusionPipeline,
17 | )
18 | from diffusers.utils import logging
19 |
20 |
21 | def format_name(name):
22 | return f"<{name.replace(' ', '_')}>"
23 |
24 |
25 | class RealGeneration(GenerativeMixup):
26 |
27 | pipe = None # global sharing is a hack to avoid OOM
28 |
29 | def __init__(
30 | self,
31 | model_path: str = "runwayml/stable-diffusion-v1-5",
32 | prompt: str = "a photo of a {name}",
33 | format_name: Callable = format_name,
34 | guidance_scale: float = 7.5,
35 | mask: bool = False,
36 | inverted: bool = False,
37 | mask_grow_radius: int = 16,
38 | disable_safety_checker: bool = True,
39 | revision: str = None,
40 | device="cuda",
41 | **kwargs,
42 | ):
43 |
44 | super(RealGeneration, self).__init__()
45 |
46 | if RealGeneration.pipe is None:
47 |
48 | PipelineClass = StableDiffusionPipeline
49 |
50 | RealGeneration.pipe = PipelineClass.from_pretrained(
51 | model_path,
52 | use_auth_token=True,
53 | revision=revision,
54 | local_files_only=True,
55 | torch_dtype=torch.float16,
56 | ).to(device)
57 | scheduler = DPMSolverMultistepScheduler.from_config(
58 | RealGeneration.pipe.scheduler.config
59 | )
60 | RealGeneration.pipe.scheduler = scheduler
61 | logging.disable_progress_bar()
62 | self.pipe.set_progress_bar_config(disable=True)
63 |
64 | if disable_safety_checker:
65 | self.pipe.safety_checker = None
66 |
67 | self.prompt = prompt
68 | self.guidance_scale = guidance_scale
69 | self.format_name = format_name
70 |
71 | self.mask = mask
72 | self.inverted = inverted
73 | self.mask_grow_radius = mask_grow_radius
74 |
75 | def forward(
76 | self,
77 | image: Image.Image,
78 | label: int,
79 | metadata: dict,
80 | strength: float = 0.5,
81 | resolution=512,
82 | ) -> Tuple[Image.Image, int]:
83 |
84 | name = self.format_name(metadata.get("name", ""))
85 | prompt = self.prompt.format(name=name)
86 |
87 | if self.mask:
88 | assert "mask" in metadata, "mask=True but no mask present in metadata"
89 |
90 | # word_name = metadata.get("name", "").replace(" ", "")
91 |
92 | kwargs = dict(
93 | prompt=[prompt],
94 | guidance_scale=self.guidance_scale,
95 | num_inference_steps=25,
96 | num_images_per_prompt=len(image),
97 | height=resolution,
98 | width=resolution,
99 | )
100 |
101 | if self.mask: # use focal object mask
102 | # TODO
103 | mask_image = metadata["mask"].resize((512, 512), Image.NEAREST)
104 | mask_image = Image.fromarray(
105 | maximum_filter(np.array(mask_image), size=self.mask_grow_radius)
106 | )
107 |
108 | if self.inverted:
109 |
110 | mask_image = ImageOps.invert(mask_image.convert("L")).convert("1")
111 |
112 | kwargs["mask_image"] = mask_image
113 |
114 | has_nsfw_concept = True
115 | while has_nsfw_concept:
116 | with autocast("cuda"):
117 | outputs = self.pipe(**kwargs)
118 |
119 | has_nsfw_concept = (
120 | self.pipe.safety_checker is not None
121 | and outputs.nsfw_content_detected[0]
122 | )
123 |
124 | canvas = []
125 | for orig, out in zip(image, outputs.images):
126 | canvas.append(out.resize(orig.size, Image.BILINEAR))
127 |
128 | return canvas, label
129 |
130 |
131 | class RealGuidance(GenerativeAugmentation):
132 |
133 | pipe = None # global sharing is a hack to avoid OOM
134 |
135 | def __init__(
136 | self,
137 | model_path: str = "runwayml/stable-diffusion-v1-5",
138 | prompt: str = "a photo of a {name}",
139 | strength: float = 0.5,
140 | guidance_scale: float = 7.5,
141 | mask: bool = False,
142 | inverted: bool = False,
143 | mask_grow_radius: int = 16,
144 | erasure_ckpt_path: str = None,
145 | disable_safety_checker: bool = True,
146 | **kwargs,
147 | ):
148 |
149 | super(RealGuidance, self).__init__()
150 |
151 | if RealGuidance.pipe is None:
152 |
153 | PipelineClass = (
154 | StableDiffusionInpaintPipeline
155 | if mask
156 | else StableDiffusionImg2ImgPipeline
157 | )
158 |
159 |             RealGuidance.pipe = PipelineClass.from_pretrained(
160 | model_path,
161 | use_auth_token=True,
162 | revision="fp16",
163 | torch_dtype=torch.float16,
164 | ).to("cuda")
165 |
166 | logging.disable_progress_bar()
167 | self.pipe.set_progress_bar_config(disable=True)
168 |
169 | if disable_safety_checker:
170 | self.pipe.safety_checker = None
171 |
172 | self.prompt = prompt
173 | self.strength = strength
174 | self.guidance_scale = guidance_scale
175 |
176 | self.mask = mask
177 | self.inverted = inverted
178 | self.mask_grow_radius = mask_grow_radius
179 |
180 | self.erasure_ckpt_path = erasure_ckpt_path
181 | self.erasure_word_name = None
182 |
183 | def forward(
184 | self, image: Image.Image, label: int, metadata: dict
185 | ) -> Tuple[Image.Image, int]:
186 |
187 | canvas = image.resize((512, 512), Image.BILINEAR)
188 | prompt = self.prompt.format(name=metadata.get("name", ""))
189 |
190 | if self.mask:
191 | assert "mask" in metadata, "mask=True but no mask present in metadata"
192 |
193 | word_name = metadata.get("name", "").replace(" ", "")
194 |
195 | if self.erasure_ckpt_path is not None and (
196 | self.erasure_word_name is None or self.erasure_word_name != word_name
197 | ):
198 |
199 | self.erasure_word_name = word_name
200 | ckpt_name = "method_full-sg_3-ng_1-iter_1000-lr_1e-05"
201 |
202 | ckpt_path = os.path.join(
203 | self.erasure_ckpt_path,
204 | f"compvis-word_{word_name}-{ckpt_name}",
205 | f"diffusers-word_{word_name}-{ckpt_name}.pt",
206 | )
207 |
208 | self.pipe.unet.load_state_dict(torch.load(ckpt_path, map_location="cuda"))
209 |
210 | kwargs = dict(
211 | image=canvas,
212 | prompt=[prompt],
213 | strength=self.strength,
214 | guidance_scale=self.guidance_scale,
215 | )
216 |
217 | if self.mask: # use focal object mask
218 |
219 | mask_image = Image.fromarray(
220 | (np.where(metadata["mask"], 255, 0)).astype(np.uint8)
221 | ).resize((512, 512), Image.NEAREST)
222 |
223 | mask_image = Image.fromarray(
224 | maximum_filter(np.array(mask_image), size=self.mask_grow_radius)
225 | )
226 |
227 | if self.inverted:
228 |
229 | mask_image = ImageOps.invert(mask_image.convert("L")).convert("1")
230 |
231 | kwargs["mask_image"] = mask_image
232 |
233 | has_nsfw_concept = True
234 | while has_nsfw_concept:
235 | with autocast("cuda"):
236 | outputs = self.pipe(**kwargs)
237 |
238 | has_nsfw_concept = (
239 | self.pipe.safety_checker is not None
240 | and outputs.nsfw_content_detected[0]
241 | )
242 |
243 | canvas = outputs.images[0].resize(image.size, Image.BILINEAR)
244 |
245 | return canvas, label
246 |
--------------------------------------------------------------------------------
/augmentation/ti_mix.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image, ImageOps
6 | from scipy.ndimage import maximum_filter
7 | from torch import autocast
8 | from transformers import CLIPTextModel, CLIPTokenizer
9 |
10 | from augmentation.base_augmentation import GenerativeMixup
11 | from diffusers import (
12 | DPMSolverMultistepScheduler,
13 | StableDiffusionImg2ImgPipeline,
14 | StableDiffusionInpaintPipeline,
15 | )
16 | from diffusers.utils import logging
17 |
18 | ERROR_MESSAGE = "Tokenizer already contains the token {token}. \
19 | Please pass a different `token` that is not already in the tokenizer."
20 |
21 |
22 | def load_embeddings(
23 | embed_path: str,
24 | model_path: str = "runwayml/stable-diffusion-v1-5",
25 | revision: str = "39593d5650112b4cc580433f6b0435385882d819",
26 | ):
27 |
28 | tokenizer = CLIPTokenizer.from_pretrained(
29 | model_path, use_auth_token=True, revision=revision, subfolder="tokenizer"
30 | )
31 |
32 | text_encoder = CLIPTextModel.from_pretrained(
33 | model_path, use_auth_token=True, revision=revision, subfolder="text_encoder"
34 | )
35 |
36 | for token, token_embedding in torch.load(embed_path, map_location="cpu").items():
37 |
38 | # add the token in tokenizer
39 | num_added_tokens = tokenizer.add_tokens(token)
40 | assert num_added_tokens > 0, ERROR_MESSAGE.format(token=token)
41 |
42 | # resize the token embeddings
43 | text_encoder.resize_token_embeddings(len(tokenizer))
44 | added_token_id = tokenizer.convert_tokens_to_ids(token)
45 |
46 | # get the old word embeddings
47 | embeddings = text_encoder.get_input_embeddings()
48 |
49 | # get the id for the token and assign new embeds
50 | embeddings.weight.data[added_token_id] = token_embedding.to(
51 | embeddings.weight.dtype
52 | )
53 |
54 | return tokenizer, text_encoder
55 |
56 |
57 | def format_name(name):
58 | return f"<{name.replace(' ', '_')}>"
59 |
60 |
61 | class TextualInversionMixup(GenerativeMixup):
62 |
63 | pipe = None # global sharing is a hack to avoid OOM
64 |
65 | def __init__(
66 | self,
67 | embed_path: str,
68 | model_path: str = "runwayml/stable-diffusion-v1-5",
69 | prompt: str = "a photo of a {name}",
70 | format_name: Callable = format_name,
71 | guidance_scale: float = 7.5,
72 | mask: bool = False,
73 | inverted: bool = False,
74 | mask_grow_radius: int = 16,
75 | disable_safety_checker: bool = True,
76 | revision: str = "39593d5650112b4cc580433f6b0435385882d819",
77 | device="cuda",
78 | **kwargs,
79 | ):
80 |
81 | super().__init__()
82 |
83 | if TextualInversionMixup.pipe is None:
84 |
85 | PipelineClass = (
86 | StableDiffusionInpaintPipeline
87 | if mask
88 | else StableDiffusionImg2ImgPipeline
89 | )
90 |
91 | tokenizer, text_encoder = load_embeddings(
92 | embed_path, model_path=model_path, revision=revision
93 | )
94 |
95 | TextualInversionMixup.pipe = PipelineClass.from_pretrained(
96 | model_path,
97 | use_auth_token=True,
98 | revision=revision,
99 | torch_dtype=torch.float16,
100 | ).to(device)
101 | scheduler = DPMSolverMultistepScheduler.from_config(
102 | TextualInversionMixup.pipe.scheduler.config
103 | )
104 | TextualInversionMixup.pipe.scheduler = scheduler
105 | self.pipe.tokenizer = tokenizer
106 | self.pipe.text_encoder = text_encoder.to(device)
107 |
108 | logging.disable_progress_bar()
109 | self.pipe.set_progress_bar_config(disable=True)
110 |
111 | if disable_safety_checker:
112 | self.pipe.safety_checker = None
113 |
114 | self.prompt = prompt
115 | self.guidance_scale = guidance_scale
116 | self.format_name = format_name
117 |
118 | self.mask = mask
119 | self.inverted = inverted
120 | self.mask_grow_radius = mask_grow_radius
121 |
122 | self.erasure_word_name = None
123 |
124 | def forward(
125 | self, image: Image.Image, label: int, metadata: dict, strength: float = 0.5
126 | ) -> Tuple[Image.Image, int]:
127 |
128 | canvas = image.resize((512, 512), Image.BILINEAR)
129 | name = self.format_name(metadata.get("name", ""))
130 | prompt = self.prompt.format(name=name)
131 |
132 | if self.mask:
133 | assert "mask" in metadata, "mask=True but no mask present in metadata"
134 |
135 | word_name = metadata.get("name", "").replace(" ", "")
136 |
137 | kwargs = dict(
138 | image=canvas,
139 | prompt=[prompt],
140 | strength=strength,
141 | guidance_scale=self.guidance_scale,
142 | )
143 |
144 | if self.mask: # use focal object mask
145 | # TODO
146 | mask_image = Image.fromarray(
147 | (np.where(metadata["mask"], 255, 0)).astype(np.uint8)
148 | ).resize((512, 512), Image.NEAREST)
149 |
150 | mask_image = Image.fromarray(
151 | maximum_filter(np.array(mask_image), size=self.mask_grow_radius)
152 | )
153 |
154 | if self.inverted:
155 |
156 | mask_image = ImageOps.invert(mask_image.convert("L")).convert("1")
157 |
158 | kwargs["mask_image"] = mask_image
159 |
160 | has_nsfw_concept = True
161 | while has_nsfw_concept:
162 | with autocast("cuda"):
163 | outputs = self.pipe(**kwargs)
164 |
165 | has_nsfw_concept = (
166 | self.pipe.safety_checker is not None
167 | and outputs.nsfw_content_detected[0]
168 | )
169 |
170 | canvas = outputs.images[0].resize(image.size, Image.BILINEAR)
171 |
172 | return canvas, label
173 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .instance.aircraft import *
2 | from .instance.car import *
3 | from .instance.cub import *
4 | from .instance.dog import *
5 | from .instance.flower import *
6 | from .instance.food import *
7 | from .instance.pascal import *
8 | from .instance.pet import *
9 | from .instance.waterbird import *
10 |
11 | DATASET_NAME_MAPPING = {
12 | "cub": CUBBirdHugDataset,
13 | "flower": Flowers102Dataset,
14 | "car": CarHugDataset,
15 | "pet": PetHugDataset,
16 | "aircraft": AircraftHugDataset,
17 | "food": FoodHugDataset,
18 | "pascal": PascalDataset,
19 | "dog": StanfordDogDataset,
20 | }
21 | IMBALANCE_DATASET_NAME_MAPPING = {
22 | "cub": CUBBirdHugImbalanceDataset,
23 | "flower": FlowersImbalanceDataset,
24 | }
25 | T2I_DATASET_NAME_MAPPING = {
26 | "cub": CUBBirdHugDatasetForT2I,
27 | "flower": FlowersDatasetForT2I,
28 | }
29 | T2I_IMBALANCE_DATASET_NAME_MAPPING = {
30 | "cub": CUBBirdHugImbalanceDatasetForT2I,
31 | "flower": FlowersImbalanceDatasetForT2I,
32 | }
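# Note (added): these mappings presumably resolve the --dataset_name / --dataset
# CLI arguments (e.g. "cub", "aircraft") to their dataset classes in train_lora.py
# and the sampling / classification scripts.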
33 |
--------------------------------------------------------------------------------
/dataset/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import math
3 | import os
4 | import random
5 | from typing import Any, List, Tuple, Union
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | import torchvision.transforms as transforms
11 | from PIL import Image
12 | from torch.utils.data import Dataset
13 |
14 |
15 | def onehot(size, target):
16 | vec = torch.zeros(size, dtype=torch.float32)
17 | vec[target] = 1.0
18 | return vec
19 |
20 |
21 | class SyntheticDataset(Dataset):
22 | def __init__(
23 | self,
24 | synthetic_dir: Union[str, List[str]] = None,
25 | gamma: int = 1,
26 | soft_scaler: float = 1,
27 | num_syn_seeds: int = 999,
28 | image_size: int = 512,
29 | crop_size: int = 448,
30 | class2label: dict = None,
31 | ) -> None:
32 | super().__init__()
33 |
34 | self.synthetic_dir = synthetic_dir
35 | self.num_syn_seeds = num_syn_seeds # number of seeds to generate synthetic data
36 | self.gamma = gamma
37 | self.soft_scaler = soft_scaler
38 | self.class_names = None
39 |
40 | self.parse_syn_data_pd(synthetic_dir)
41 |
42 | test_transform = transforms.Compose(
43 | [
44 | transforms.Resize((image_size, image_size)),
45 | transforms.CenterCrop((crop_size, crop_size)),
46 | transforms.ToTensor(),
47 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
48 | ]
49 | )
50 |
51 | self.transform = test_transform
52 | self.class2label = (
53 | {name: i for i, name in enumerate(self.class_names)}
54 | if class2label is None
55 | else class2label
56 | )
57 | self.num_classes = len(self.class2label.keys())
58 |
59 | def set_transform(self, transform) -> None:
60 | self.transform = transform
61 |
62 | def parse_syn_data_pd(self, synthetic_dir) -> None:
63 | if isinstance(synthetic_dir, list):
64 | pass
65 | elif isinstance(synthetic_dir, str):
66 | synthetic_dir = [synthetic_dir]
67 | else:
68 | raise NotImplementedError("Not supported type")
69 | meta_df_list = []
70 |
71 | for _dir in synthetic_dir:
72 | meta_dir = os.path.join(_dir, self.csv_file)
73 | meta_df = pd.read_csv(meta_dir)
74 | meta_df.loc[:, "Path"] = meta_df["Path"].apply(
75 | lambda x: os.path.join(_dir, "data", x)
76 | )
77 | meta_df_list.append(meta_df)
78 | self.meta_df = pd.concat(meta_df_list).reset_index(drop=True)
79 |
80 | self.syn_nums = len(self.meta_df)
81 | self.class_names = list(set(self.meta_df["First Directory"].values))
82 | print(f"Syn numbers: {self.syn_nums}\n")
83 |
84 | def get_syn_item_raw(self, idx: int):
85 | df_data = self.meta_df.iloc[idx]
86 | src_label = self.class2label[df_data["First Directory"]]
87 | tar_label = self.class2label[df_data["Second Directory"]]
88 | path = df_data["Path"]
89 | return path, src_label, tar_label
90 |
91 | def __len__(self) -> int:
92 | return self.syn_nums
93 |
94 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
95 | path, src_label, target_label = self.get_syn_item_raw(idx)
96 | image = Image.open(path).convert("RGB")
97 | return {
98 | "pixel_values": self.transform(image),
99 | "src_label": src_label,
100 | "tar_label": target_label,
101 | }
102 |
103 |
104 | class HugFewShotDataset(Dataset):
105 |
106 | num_classes: int = None
107 |     class_names: List[str] = None
108 | class2label: dict = None
109 | label2class: dict = None
110 |
111 | def __init__(
112 | self,
113 | split: str = "train",
114 | examples_per_class: int = None,
115 | synthetic_probability: float = 0.5,
116 | return_onehot: bool = False,
117 | soft_scaler: float = 1,
118 | synthetic_dir: Union[str, List[str]] = None,
119 | image_size: int = 512,
120 | crop_size: int = 448,
121 | gamma: int = 1,
122 | num_syn_seeds: int = 99999,
123 | clip_filtered_syn: bool = False,
124 | target_class_num: int = None,
125 | **kwargs,
126 | ):
127 |
128 | self.examples_per_class = examples_per_class
129 | self.num_syn_seeds = num_syn_seeds # number of seeds to generate synthetic data
130 |
131 | self.synthetic_dir = synthetic_dir
132 | self.clip_filtered_syn = clip_filtered_syn
133 | self.return_onehot = return_onehot
134 |
135 | if self.synthetic_dir is not None:
136 | assert self.return_onehot == True
137 | self.synthetic_probability = synthetic_probability
138 | self.soft_scaler = soft_scaler
139 | self.gamma = gamma
140 | self.target_class_num = target_class_num
141 | self.parse_syn_data_pd(synthetic_dir)
142 |
143 | train_transform = transforms.Compose(
144 | [
145 | transforms.Resize((image_size, image_size)),
146 | transforms.RandomCrop(crop_size, padding=8),
147 | transforms.RandomHorizontalFlip(),
148 | transforms.ToTensor(),
149 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
150 | ]
151 | )
152 | test_transform = transforms.Compose(
153 | [
154 | transforms.Resize((image_size, image_size)),
155 | transforms.CenterCrop((crop_size, crop_size)),
156 | transforms.ToTensor(),
157 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
158 | ]
159 | )
160 | self.transform = {"train": train_transform, "val": test_transform}[split]
161 |
162 | def set_transform(self, transform) -> None:
163 | self.transform = transform
164 |
165 | @abc.abstractmethod
166 | def get_image_by_idx(self, idx: int) -> Image.Image:
167 |
168 | return NotImplemented
169 |
170 | @abc.abstractmethod
171 | def get_label_by_idx(self, idx: int) -> int:
172 |
173 | return NotImplemented
174 |
175 | @abc.abstractmethod
176 | def get_metadata_by_idx(self, idx: int) -> dict:
177 |
178 | return NotImplemented
179 |
180 | def parse_syn_data_pd(self, synthetic_dir, filter=True) -> None:
181 | if isinstance(synthetic_dir, list):
182 | pass
183 | elif isinstance(synthetic_dir, str):
184 | synthetic_dir = [synthetic_dir]
185 | else:
186 | raise NotImplementedError("Not supported type")
187 | meta_df_list = []
188 | for _dir in synthetic_dir:
189 | df_basename = (
190 | "meta.csv" if not self.clip_filtered_syn else "remained_meta.csv"
191 | )
192 | meta_dir = os.path.join(_dir, df_basename)
193 | meta_df = self.filter_df(pd.read_csv(meta_dir))
194 | meta_df.loc[:, "Path"] = meta_df["Path"].apply(
195 | lambda x: os.path.join(_dir, "data", x)
196 | )
197 | meta_df_list.append(meta_df)
198 | self.meta_df = pd.concat(meta_df_list).reset_index(drop=True)
199 | self.syn_nums = len(self.meta_df)
200 |
201 | print(f"Syn numbers: {self.syn_nums}\n")
202 |
203 | def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
204 |
205 | if self.target_class_num is not None:
206 | selected_indexs = []
207 | for source_name in self.class_names:
208 | target_classes = random.sample(self.class_names, self.target_class_num)
209 | indexs = df[
210 | (df["First Directory"] == source_name)
211 | & (df["Second Directory"].isin(target_classes))
212 | ]
213 | selected_indexs.append(indexs)
214 |
215 | meta2 = pd.concat(selected_indexs, axis=0)
216 | total_num = min(len(meta2), 18000)
217 | idxs = random.sample(range(len(meta2)), total_num)
218 | meta2 = meta2.iloc[idxs]
219 | meta2.reset_index(drop=True, inplace=True)
220 | df = meta2
221 | print("filter_df", self.target_class_num, len(df))
222 | return df
223 |
224 | def get_syn_item(self, idx: int):
225 |
226 | df_data = self.meta_df.iloc[idx]
227 | src_label = self.class2label[df_data["First Directory"]]
228 | tar_label = self.class2label[df_data["Second Directory"]]
229 | path = df_data["Path"]
230 | strength = df_data["Strength"]
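        # Added note: the synthetic sample's soft label interpolates between the
        # source and target classes, putting weight strength**gamma on the target
        # class (scaled by soft_scaler).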
231 | onehot_label = torch.zeros(self.num_classes)
232 | onehot_label[src_label] += self.soft_scaler * (
233 | 1 - math.pow(strength, self.gamma)
234 | )
235 | onehot_label[tar_label] += self.soft_scaler * math.pow(strength, self.gamma)
236 |
237 | return path, onehot_label
238 |
239 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
240 |
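        # Added note: with probability `synthetic_probability`, a synthetic image
        # and its mixed soft label are returned; otherwise a real training image
        # with a hard (optionally one-hot) label is used.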
241 | if (
242 | self.synthetic_dir is not None
243 | and np.random.uniform() < self.synthetic_probability
244 | ):
245 | syn_idx = np.random.choice(self.syn_nums)
246 | path, label = self.get_syn_item(syn_idx)
247 | image = Image.open(path).convert("RGB")
248 | else:
249 | image = self.get_image_by_idx(idx)
250 | label = self.get_label_by_idx(idx)
251 |
252 | if self.return_onehot:
253 | if isinstance(label, (int, np.int64)):
254 | label = onehot(self.num_classes, label)
255 | return dict(pixel_values=self.transform(image), label=label)
256 |
--------------------------------------------------------------------------------
/dataset/instance/aircraft.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from typing import Any, Dict, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as transforms
8 | from datasets import load_dataset
9 | from PIL import Image
10 |
11 | from dataset.base import HugFewShotDataset
12 | from dataset.template import IMAGENET_TEMPLATES_TINY
13 |
14 | SUPER_CLASS_NAME = "aircraft"
15 | HUG_IMAGE_TRAIN_DIR = r"Multimodal-Fatima/FGVC_Aircraft_train"
16 | HUG_IMAGE_TEST_DIR = r"Multimodal-Fatima/FGVC_Aircraft_test"
17 | HUG_LOCAL_IMAGE_TRAIN_DIR = "/home/zhicai/.cache/huggingface/datasets/Multimodal-Fatima___parquet/Multimodal-Fatima--FGVC_Aircraft_train-d13c51225819e71f"
18 | HUG_LOCAL_IMAGE_TEST_DIR = "/home/zhicai/.cache/huggingface/datasets/Multimodal-Fatima___parquet/Multimodal-Fatima--FGVC_Aircraft_test-2d1cae4ba1777be1"
19 |
20 |
21 | class AircraftHugDataset(HugFewShotDataset):
22 |
23 | super_class_name = SUPER_CLASS_NAME
24 |
25 | def __init__(
26 | self,
27 | *args,
28 | split: str = "train",
29 | seed: int = 0,
30 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
31 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR,
32 | examples_per_class: int = -1,
33 | synthetic_probability: float = 0.5,
34 | return_onehot: bool = False,
35 | soft_scaler: float = 0.9,
36 | synthetic_dir: str = None,
37 | image_size: int = 512,
38 | crop_size: int = 448,
39 | **kwargs,
40 | ):
41 |
42 | super().__init__(
43 | *args,
44 | split=split,
45 | examples_per_class=examples_per_class,
46 | synthetic_probability=synthetic_probability,
47 | return_onehot=return_onehot,
48 | soft_scaler=soft_scaler,
49 | synthetic_dir=synthetic_dir,
50 | image_size=image_size,
51 | crop_size=crop_size,
52 | **kwargs,
53 | )
54 |
55 | if split == "train":
56 | dataset = load_dataset(HUG_LOCAL_IMAGE_TRAIN_DIR, split="train")
57 | else:
58 | dataset = load_dataset(HUG_LOCAL_IMAGE_TEST_DIR, split="test")
59 |
60 | # self.class_names = [name.replace('/',' ') for name in dataset.features['label'].names]
61 |
62 | random.seed(seed)
63 | np.random.seed(seed)
64 | if examples_per_class is not None and examples_per_class > 0:
65 | all_labels = dataset["label"]
66 | label_to_indices = defaultdict(list)
67 | for i, label in enumerate(all_labels):
68 | label_to_indices[label].append(i)
69 |
70 | _all_indices = []
71 | for key, items in label_to_indices.items():
72 | try:
73 | sampled_indices = random.sample(items, examples_per_class)
74 | except ValueError:
75 | print(
76 | f"{key}: Sample larger than population or is negative, use random.choices instead"
77 | )
78 | sampled_indices = random.choices(items, k=examples_per_class)
79 |
80 | label_to_indices[key] = sampled_indices
81 | _all_indices.extend(sampled_indices)
82 | dataset = dataset.select(_all_indices)
83 |
84 | self.dataset = dataset
85 | class2label = self.dataset.features["label"]._str2int
86 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
87 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
88 | self.class_names = [
89 | name.replace("/", " ") for name in dataset.features["label"].names
90 | ]
91 | self.num_classes = len(self.class_names)
92 |
93 | self.label_to_indices = defaultdict(list)
94 | for i, label in enumerate(self.dataset["label"]):
95 | self.label_to_indices[label].append(i)
96 |
97 | def __len__(self):
98 |
99 | return len(self.dataset)
100 |
101 | def get_image_by_idx(self, idx: int) -> Image.Image:
102 |
103 | return self.dataset[idx]["image"].convert("RGB")
104 |
105 | def get_label_by_idx(self, idx: int) -> int:
106 |
107 | return self.dataset[idx]["label"]
108 |
109 | def get_metadata_by_idx(self, idx: int) -> dict:
110 |
111 | return dict(
112 | name=self.label2class[self.get_label_by_idx(idx)],
113 | super_class=self.super_class_name,
114 | )
115 |
116 |
117 | class AircraftHugDatasetForT2I(torch.utils.data.Dataset):
118 |
119 | super_class_name = SUPER_CLASS_NAME
120 |
121 | def __init__(
122 | self,
123 | *args,
124 | split: str = "train",
125 | seed: int = 0,
126 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
127 | max_train_samples: int = -1,
128 | class_prompts_ratio: float = 0.5,
129 | resolution: int = 512,
130 | center_crop: bool = False,
131 | random_flip: bool = False,
132 | use_placeholder: bool = False,
133 | examples_per_class: int = -1,
134 | **kwargs,
135 | ):
136 |
137 | super().__init__()
138 |
139 | dataset = load_dataset(
140 |             HUG_LOCAL_IMAGE_TRAIN_DIR,
141 | split="train",
142 | )
143 |
144 | random.seed(seed)
145 | np.random.seed(seed)
146 | if max_train_samples is not None and max_train_samples > 0:
147 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples))
148 | if examples_per_class is not None and examples_per_class > 0:
149 | all_labels = dataset["label"]
150 | label_to_indices = defaultdict(list)
151 | for i, label in enumerate(all_labels):
152 | label_to_indices[label].append(i)
153 |
154 | _all_indices = []
155 | for key, items in label_to_indices.items():
156 | try:
157 | sampled_indices = random.sample(items, examples_per_class)
158 | except ValueError:
159 | print(
160 | f"{key}: Sample larger than population or is negative, use random.choices instead"
161 | )
162 | sampled_indices = random.choices(items, k=examples_per_class)
163 |
164 | label_to_indices[key] = sampled_indices
165 | _all_indices.extend(sampled_indices)
166 | dataset = dataset.select(_all_indices)
167 | self.dataset = dataset
168 | class2label = self.dataset.features["label"]._str2int
169 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
170 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
171 | self.class_names = [
172 | name.replace("/", " ") for name in dataset.features["label"].names
173 | ]
174 | self.num_classes = len(self.class_names)
175 | self.use_placeholder = use_placeholder
176 | self.name2placeholder = {}
177 | self.placeholder2name = {}
178 | self.label_to_indices = defaultdict(list)
179 | for i, label in enumerate(self.dataset["label"]):
180 | self.label_to_indices[label].append(i)
181 |
182 | self.transform = transforms.Compose(
183 | [
184 | transforms.Resize(
185 | resolution, interpolation=transforms.InterpolationMode.BILINEAR
186 | ),
187 | (
188 | transforms.CenterCrop(resolution)
189 | if center_crop
190 | else transforms.RandomCrop(resolution)
191 | ),
192 | (
193 | transforms.RandomHorizontalFlip()
194 | if random_flip
195 | else transforms.Lambda(lambda x: x)
196 | ),
197 | transforms.ToTensor(),
198 | transforms.Normalize([0.5], [0.5]),
199 | ]
200 | )
201 |
202 | def __len__(self):
203 |
204 | return len(self.dataset)
205 |
206 | def __getitem__(self, idx: int) -> dict:
207 |
208 | image = self.get_image_by_idx(idx)
209 | prompt = self.get_prompt_by_idx(idx)
210 |
211 | return dict(pixel_values=self.transform(image), caption=prompt)
212 |
213 | def get_image_by_idx(self, idx: int) -> Image.Image:
214 |
215 | return self.dataset[idx]["image"].convert("RGB")
216 |
217 | def get_label_by_idx(self, idx: int) -> int:
218 |
219 | return self.dataset[idx]["label"]
220 |
221 | def get_prompt_by_idx(self, idx: int) -> str:
222 | # randomly choose from class name or description
223 | if self.use_placeholder:
224 | content = (
225 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]]
226 | + f" {self.super_class_name}"
227 | )
228 | else:
229 | content = self.label2class[self.dataset[idx]["label"]]
230 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content)
231 |
232 | return prompt
233 |
234 | def get_metadata_by_idx(self, idx: int) -> dict:
235 |
236 | return dict(
237 | name=self.label2class[self.get_label_by_idx(idx)],
238 | super_class=self.super_class_name,
239 | )
240 |
--------------------------------------------------------------------------------
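A minimal usage sketch for the T2I dataset above (illustrative only; it assumes the hard-coded HuggingFace cache path in AircraftHugDatasetForT2I.__init__ resolves locally, and the DataLoader/collate wiring is not part of the repository):

    import torch
    from torch.utils.data import DataLoader

    from dataset.instance.aircraft import AircraftHugDatasetForT2I

    ds = AircraftHugDatasetForT2I(resolution=512, examples_per_class=5)
    loader = DataLoader(
        ds,
        batch_size=4,
        shuffle=True,
        # each item is a dict with a tensor and a string, so stack/collect manually
        collate_fn=lambda batch: {
            "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
            "caption": [x["caption"] for x in batch],
        },
    )
    batch = next(iter(loader))
    print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])
    print(batch["caption"][0])          # e.g. "a photo of a 707-320"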
/dataset/instance/car.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from typing import Any, Dict, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as transforms
8 | from datasets import load_dataset
9 | from PIL import Image
10 |
11 | from dataset.base import HugFewShotDataset
12 | from dataset.template import IMAGENET_TEMPLATES_TINY
13 |
14 | SUPER_CLASS_NAME = "car"
15 | HUG_LOCAL_IMAGE_TRAIN_DIR = r"Multimodal-Fatima/StanfordCars_train"
16 | HUG_LOCAL_IMAGE_TEST_DIR = r"Multimodal-Fatima/StanfordCars_test"
17 |
18 |
19 | class CarHugDataset(HugFewShotDataset):
20 |
21 | super_class_name = SUPER_CLASS_NAME
22 |
23 | def __init__(
24 | self,
25 | *args,
26 | split: str = "train",
27 | seed: int = 0,
28 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
29 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR,
30 | examples_per_class: int = -1,
31 | synthetic_probability: float = 0.5,
32 | return_onehot: bool = False,
33 | soft_scaler: float = 0.9,
34 | synthetic_dir: str = None,
35 | image_size: int = 512,
36 | crop_size: int = 448,
37 | **kwargs,
38 | ):
39 |
40 | super(CarHugDataset, self).__init__(
41 | *args,
42 | split=split,
43 | examples_per_class=examples_per_class,
44 | synthetic_probability=synthetic_probability,
45 | return_onehot=return_onehot,
46 | soft_scaler=soft_scaler,
47 | synthetic_dir=synthetic_dir,
48 | image_size=image_size,
49 | crop_size=crop_size,
50 | **kwargs,
51 | )
52 |
53 | if split == "train":
54 | dataset = load_dataset(
55 | "/data/zhicai/cache/huggingface/datasets/Multimodal-Fatima___stanford_cars_train",
56 | split="train",
57 | )
58 | else:
59 | dataset = load_dataset(
60 | "/data/zhicai/cache/huggingface/datasets/Multimodal-Fatima___stanford_cars_test",
61 | split="test",
62 | )
63 |
64 | self.class_names = [
65 | name.replace("/", " ") for name in dataset.features["label"].names
66 | ]
67 |
68 | random.seed(seed)
69 | np.random.seed(seed)
70 | if examples_per_class is not None and examples_per_class > 0:
71 | all_labels = dataset["label"]
72 | label_to_indices = defaultdict(list)
73 | for i, label in enumerate(all_labels):
74 | label_to_indices[label].append(i)
75 |
76 | _all_indices = []
77 | for key, items in label_to_indices.items():
78 | try:
79 | sampled_indices = random.sample(items, examples_per_class)
80 | except ValueError:
81 | print(
82 | f"{key}: Sample larger than population or is negative, use random.choices instead"
83 | )
84 | sampled_indices = random.choices(items, k=examples_per_class)
85 |
86 | label_to_indices[key] = sampled_indices
87 | _all_indices.extend(sampled_indices)
88 | dataset = dataset.select(_all_indices)
89 |
90 | self.dataset = dataset
91 | class2label = self.dataset.features["label"]._str2int
92 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
93 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
94 | self.class_names = [
95 | name.replace("/", " ") for name in dataset.features["label"].names
96 | ]
97 | self.num_classes = len(self.class_names)
98 |
99 | self.label_to_indices = defaultdict(list)
100 | for i, label in enumerate(self.dataset["label"]):
101 | self.label_to_indices[label].append(i)
102 |
103 | def __len__(self):
104 |
105 | return len(self.dataset)
106 |
107 | def get_image_by_idx(self, idx: int) -> Image.Image:
108 |
109 | return self.dataset[idx]["image"].convert("RGB")
110 |
111 | def get_label_by_idx(self, idx: int) -> int:
112 |
113 | return self.dataset[idx]["label"]
114 |
115 | def get_metadata_by_idx(self, idx: int) -> dict:
116 |
117 | return dict(
118 | name=self.label2class[self.get_label_by_idx(idx)],
119 | super_class=self.super_class_name,
120 | )
121 |
122 |
123 | class CarHugDatasetForT2I(torch.utils.data.Dataset):
124 |
125 | super_class_name = SUPER_CLASS_NAME
126 |
127 | def __init__(
128 | self,
129 | *args,
130 | split: str = "train",
131 | seed: int = 0,
132 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
133 | max_train_samples: int = -1,
134 | class_prompts_ratio: float = 0.5,
135 | resolution: int = 512,
136 | center_crop: bool = False,
137 | random_flip: bool = False,
138 | use_placeholder: bool = False,
139 | examples_per_class: int = -1,
140 | **kwargs,
141 | ):
142 |
143 | super().__init__()
144 |
145 | dataset = load_dataset(
146 | "/data/zhicai/cache/huggingface/datasets/Multimodal-Fatima___stanford_cars_train",
147 | split="train",
148 | )
149 |
150 | random.seed(seed)
151 | np.random.seed(seed)
152 | if max_train_samples is not None and max_train_samples > 0:
153 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples))
154 | if examples_per_class is not None and examples_per_class > 0:
155 | all_labels = dataset["label"]
156 | label_to_indices = defaultdict(list)
157 | for i, label in enumerate(all_labels):
158 | label_to_indices[label].append(i)
159 |
160 | _all_indices = []
161 | for key, items in label_to_indices.items():
162 | try:
163 | sampled_indices = random.sample(items, examples_per_class)
164 | except ValueError:
165 | print(
166 | f"{key}: Sample larger than population or is negative, use random.choices instead"
167 | )
168 | sampled_indices = random.choices(items, k=examples_per_class)
169 |
170 | label_to_indices[key] = sampled_indices
171 | _all_indices.extend(sampled_indices)
172 | dataset = dataset.select(_all_indices)
173 | self.dataset = dataset
174 | class2label = self.dataset.features["label"]._str2int
175 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
176 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
177 | self.class_names = [
178 | name.replace("/", " ") for name in dataset.features["label"].names
179 | ]
180 | self.num_classes = len(self.class_names)
181 | self.use_placeholder = use_placeholder
182 | self.name2placeholder = {}
183 | self.placeholder2name = {}
184 | self.label_to_indices = defaultdict(list)
185 | for i, label in enumerate(self.dataset["label"]):
186 | self.label_to_indices[label].append(i)
187 |
188 | self.transform = transforms.Compose(
189 | [
190 | transforms.Resize(
191 | resolution, interpolation=transforms.InterpolationMode.BILINEAR
192 | ),
193 | (
194 | transforms.CenterCrop(resolution)
195 | if center_crop
196 | else transforms.RandomCrop(resolution)
197 | ),
198 | (
199 | transforms.RandomHorizontalFlip()
200 | if random_flip
201 | else transforms.Lambda(lambda x: x)
202 | ),
203 | transforms.ToTensor(),
204 | transforms.Normalize([0.5], [0.5]),
205 | ]
206 | )
207 |
208 | def __len__(self):
209 |
210 | return len(self.dataset)
211 |
212 | def __getitem__(self, idx: int) -> dict:
213 |
214 | image = self.get_image_by_idx(idx)
215 | prompt = self.get_prompt_by_idx(idx)
216 | # label = self.get_label_by_idx(idx)
217 |
218 | return dict(pixel_values=self.transform(image), caption=prompt)
219 |
220 | def get_image_by_idx(self, idx: int) -> Image.Image:
221 |
222 | return self.dataset[idx]["image"].convert("RGB")
223 |
224 | def get_label_by_idx(self, idx: int) -> int:
225 |
226 | return self.dataset[idx]["label"]
227 |
228 | def get_prompt_by_idx(self, idx: int) -> str:
229 | # randomly choose from class name or description
230 | if self.use_placeholder:
231 | content = (
232 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]]
233 | + f"{self.super_class_name}"
234 | )
235 | else:
236 | content = self.label2class[self.dataset[idx]["label"]]
237 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content)
238 |
239 | return prompt
240 |
241 | def get_metadata_by_idx(self, idx: int) -> dict:
242 |
243 | return dict(name=self.label2class[self.get_label_by_idx(idx)])
244 |
--------------------------------------------------------------------------------
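The few-shot subsetting inside these __init__ methods (seeded random.sample per class, with a random.choices fallback when a class has fewer images than requested) can be exercised on its own; the helper below is an illustrative sketch, not code from the repository:

    import random
    from collections import defaultdict

    def subsample_per_class(labels, examples_per_class, seed=0):
        """Pick `examples_per_class` indices per label, sampling with
        replacement for classes that are too small."""
        random.seed(seed)
        label_to_indices = defaultdict(list)
        for i, label in enumerate(labels):
            label_to_indices[label].append(i)
        picked = []
        for key, items in label_to_indices.items():
            try:
                picked.extend(random.sample(items, examples_per_class))
            except ValueError:  # class smaller than examples_per_class
                picked.extend(random.choices(items, k=examples_per_class))
        return picked

    labels = [0] * 5 + [1] * 2 + [2] * 4        # class 1 has only 2 images
    print(len(subsample_per_class(labels, 3)))  # 9 indices, 3 per class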
/dataset/instance/dog.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from collections import defaultdict
4 | from typing import Tuple
5 |
6 | import numpy as np
7 | import torch
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 | from scipy.io import loadmat
11 |
12 | from dataset.base import HugFewShotDataset
13 | from dataset.template import IMAGENET_TEMPLATES_TINY
14 |
15 | SUPER_CLASS_NAME = "dog"
16 | DEFAULT_IMAGE_DIR = "/data/zhicai/datasets/fgvc_datasets/stanford_dogs/"
17 |
18 | CLASS_NAME = [
19 | "Chihuahua",
20 | "Japanese spaniel",
21 | "Maltese dog",
22 | "Pekinese",
23 | "Shih-Tzu",
24 | "Blenheim spaniel",
25 | "papillon",
26 | "toy terrier",
27 | "Rhodesian ridgeback",
28 | "Afghan hound",
29 | "basset",
30 | "beagle",
31 | "bloodhound",
32 | "bluetick",
33 | "black-and-tan coonhound",
34 | "Walker hound",
35 | "English foxhound",
36 | "redbone",
37 | "borzoi",
38 | "Irish wolfhound",
39 | "Italian greyhound",
40 | "whippet",
41 | "Ibizan hound",
42 | "Norwegian elkhound",
43 | "otterhound",
44 | "Saluki",
45 | "Scottish deerhound",
46 | "Weimaraner",
47 | "Staffordshire bullterrier",
48 | "American Staffordshire terrier",
49 | "Bedlington terrier",
50 | "Border terrier",
51 | "Kerry blue terrier",
52 | "Irish terrier",
53 | "Norfolk terrier",
54 | "Norwich terrier",
55 | "Yorkshire terrier",
56 | "wire-haired fox terrier",
57 | "Lakeland terrier",
58 | "Sealyham terrier",
59 | "Airedale",
60 | "cairn",
61 | "Australian terrier",
62 | "Dandie Dinmont",
63 | "Boston bull",
64 | "miniature schnauzer",
65 | "giant schnauzer",
66 | "standard schnauzer",
67 | "Scotch terrier",
68 | "Tibetan terrier",
69 | "silky terrier",
70 | "soft-coated wheaten terrier",
71 | "West Highland white terrier",
72 | "Lhasa",
73 | "flat-coated retriever",
74 | "curly-coated retriever",
75 | "golden retriever",
76 | "Labrador retriever",
77 | "Chesapeake Bay retriever",
78 | "German short-haired pointer",
79 | "vizsla",
80 | "English setter",
81 | "Irish setter",
82 | "Gordon setter",
83 | "Brittany spaniel",
84 | "clumber",
85 | "English springer",
86 | "Welsh springer spaniel",
87 | "cocker spaniel",
88 | "Sussex spaniel",
89 | "Irish water spaniel",
90 | "kuvasz",
91 | "schipperke",
92 | "groenendael",
93 | "malinois",
94 | "briard",
95 | "kelpie",
96 | "komondor",
97 | "Old English sheepdog",
98 | "Shetland sheepdog",
99 | "collie",
100 | "Border collie",
101 | "Bouvier des Flandres",
102 | "Rottweiler",
103 | "German shepherd",
104 | "Doberman",
105 | "miniature pinscher",
106 | "Greater Swiss Mountain dog",
107 | "Bernese mountain dog",
108 | "Appenzeller",
109 | "EntleBucher",
110 | "boxer",
111 | "bull mastiff",
112 | "Tibetan mastiff",
113 | "French bulldog",
114 | "Great Dane",
115 | "Saint Bernard",
116 | "Eskimo dog",
117 | "malamute",
118 | "Siberian husky",
119 | "affenpinscher",
120 | "basenji",
121 | "pug",
122 | "Leonberg",
123 | "Newfoundland",
124 | "Great Pyrenees",
125 | "Samoyed",
126 | "Pomeranian",
127 | "chow",
128 | "keeshond",
129 | "Brabancon griffon",
130 | "Pembroke",
131 | "Cardigan",
132 | "toy poodle",
133 | "miniature poodle",
134 | "standard poodle",
135 | "Mexican hairless",
136 | "dingo",
137 | "dhole",
138 | "African hunting dog",
139 | ]
140 |
141 |
142 | class StanfordDogDataset(HugFewShotDataset):
143 |
144 | class_names = CLASS_NAME
145 | super_class_name = SUPER_CLASS_NAME
146 | num_classes: int = len(class_names)
147 |
148 | def __init__(
149 | self,
150 | *args,
151 | split: str = "train",
152 | seed: int = 0,
153 | image_dir: str = DEFAULT_IMAGE_DIR,
154 | examples_per_class: int = None,
155 | synthetic_probability: float = 0.5,
156 | return_onehot: bool = False,
157 | soft_scaler: float = 0.9,
158 | synthetic_dir: str = None,
159 | image_size: int = 512,
160 | crop_size: int = 448,
161 | **kwargs,
162 | ):
163 |
164 | super().__init__(
165 | *args,
166 | split=split,
167 | examples_per_class=examples_per_class,
168 | synthetic_probability=synthetic_probability,
169 | return_onehot=return_onehot,
170 | soft_scaler=soft_scaler,
171 | synthetic_dir=synthetic_dir,
172 | image_size=image_size,
173 | crop_size=crop_size,
174 | **kwargs,
175 | )
176 |
177 | if split == "train":
178 | data_mat = loadmat(os.path.join(image_dir, "train_list.mat"))
179 | else:
180 | data_mat = loadmat(os.path.join(image_dir, "test_list.mat"))
181 |
182 | image_files = [
183 | os.path.join(image_dir, "Images", i[0][0]) for i in data_mat["file_list"]
184 | ]
185 | imagelabels = data_mat["labels"].squeeze()
186 | class_to_images = defaultdict(list)
187 |
188 | for image_idx, image_path in enumerate(image_files):
189 | class_name = self.class_names[imagelabels[image_idx] - 1]
190 | class_to_images[class_name].append(image_path)
191 |
192 | rng = np.random.default_rng(seed)
193 | class_to_ids = {
194 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names
195 | }
196 |
197 | if examples_per_class is not None and examples_per_class > 0:
198 | class_to_ids = {
199 | key: ids[:examples_per_class] for key, ids in class_to_ids.items()
200 | }
201 |
202 | self.class_to_images = {
203 | key: [class_to_images[key][i] for i in ids]
204 | for key, ids in class_to_ids.items()
205 | }
206 | self.class2label = {key: i for i, key in enumerate(self.class_names)}
207 | self.label2class = {v: k for k, v in self.class2label.items()}
208 | self.all_images = sum(
209 | [self.class_to_images[key] for key in self.class_names], []
210 | )
211 | self.all_labels = [
212 | i
213 | for i, key in enumerate(self.class_names)
214 | for _ in self.class_to_images[key]
215 | ]
216 |
217 | self.label_to_indices = defaultdict(list)
218 | for i, label in enumerate(self.all_labels):
219 | self.label_to_indices[label].append(i)
220 |
221 | def __len__(self):
222 |
223 | return len(self.all_images)
224 |
225 | def get_image_by_idx(self, idx: int) -> Image.Image:
226 |
227 | return Image.open(self.all_images[idx]).convert("RGB")
228 |
229 | def get_label_by_idx(self, idx: int) -> int:
230 |
231 | return self.all_labels[idx]
232 |
233 | def get_metadata_by_idx(self, idx: int) -> dict:
234 |
235 | return dict(
236 | name=self.class_names[self.all_labels[idx]],
237 | super_class=self.super_class_name,
238 | )
239 |
240 |
241 | class StanfordDogDatasetForT2I(torch.utils.data.Dataset):
242 |
243 | class_names = CLASS_NAME
244 | super_class_name = SUPER_CLASS_NAME
245 |
246 | def __init__(
247 | self,
248 | *args,
249 | split: str = "train",
250 | seed: int = 0,
251 | image_dir: str = DEFAULT_IMAGE_DIR,
252 | max_train_samples: int = -1,
253 | class_prompts_ratio: float = 0.5,
254 | resolution: int = 512,
255 | center_crop: bool = False,
256 | random_flip: bool = False,
257 | use_placeholder: bool = False,
258 | examples_per_class: int = -1,
259 | **kwargs,
260 | ):
261 |
262 | super().__init__()
263 | if split == "train":
264 | data_mat = loadmat(os.path.join(image_dir, "train_list.mat"))
265 | else:
266 | data_mat = loadmat(os.path.join(image_dir, "test_list.mat"))
267 |
268 | image_files = [
269 | os.path.join(image_dir, "Images", i[0][0]) for i in data_mat["file_list"]
270 | ]
271 | imagelabels = data_mat["labels"].squeeze()
272 | class_to_images = defaultdict(list)
273 | random.seed(seed)
274 | np.random.seed(seed)
275 | # NOTE: `max_train_samples` is not applicable here; images are enumerated
276 | # from the .mat file lists above rather than a HuggingFace dataset object,
277 | # so there is nothing to shuffle or truncate at this point.
278 |
279 | for image_idx, image_path in enumerate(image_files):
280 | class_name = self.class_names[imagelabels[image_idx] - 1]
281 | class_to_images[class_name].append(image_path)
282 |
283 | rng = np.random.default_rng(seed)
284 | class_to_ids = {
285 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names
286 | }
287 |
288 | # Split each class 50/50 into train/test subsets (first half -> train)
289 | class_to_ids = {
290 | key: np.array_split(class_to_ids[key], [int(0.5 * len(class_to_ids[key]))])[
291 | 0 if split == "train" else 1
292 | ]
293 | for key in self.class_names
294 | }
295 | if examples_per_class is not None and examples_per_class > 0:
296 | class_to_ids = {
297 | key: ids[:examples_per_class] for key, ids in class_to_ids.items()
298 | }
299 | self.class_to_images = {
300 | key: [class_to_images[key][i] for i in ids]
301 | for key, ids in class_to_ids.items()
302 | }
303 | self.all_images = sum(
304 | [self.class_to_images[key] for key in self.class_names], []
305 | )
306 | self.class2label = {key: i for i, key in enumerate(self.class_names)}
307 | self.label2class = {v: k for k, v in self.class2label.items()}
308 | self.all_labels = [
309 | i
310 | for i, key in enumerate(self.class_names)
311 | for _ in self.class_to_images[key]
312 | ]
313 |
314 | self.label_to_indices = defaultdict(list)
315 | for i, label in enumerate(self.all_labels):
316 | self.label_to_indices[label].append(i)
317 |
318 | self.num_classes = len(self.class_names)
319 | self.class_prompts_ratio = class_prompts_ratio
320 | self.use_placeholder = use_placeholder
321 | self.name2placeholder = {}
322 | self.placeholder2name = {}
323 |
324 | self.transform = transforms.Compose(
325 | [
326 | transforms.Resize(
327 | resolution, interpolation=transforms.InterpolationMode.BILINEAR
328 | ),
329 | (
330 | transforms.CenterCrop(resolution)
331 | if center_crop
332 | else transforms.RandomCrop(resolution)
333 | ),
334 | (
335 | transforms.RandomHorizontalFlip()
336 | if random_flip
337 | else transforms.Lambda(lambda x: x)
338 | ),
339 | transforms.ToTensor(),
340 | transforms.Normalize([0.5], [0.5]),
341 | ]
342 | )
343 |
344 | def __len__(self):
345 |
346 | return len(self.all_images)
347 |
348 | def __getitem__(self, idx: int) -> dict:
349 |
350 | image = self.get_image_by_idx(idx)
351 | prompt = self.get_prompt_by_idx(idx)
352 | # label = self.get_label_by_idx(idx)
353 |
354 | return dict(pixel_values=self.transform(image), caption=prompt)
355 |
356 | def get_image_by_idx(self, idx: int) -> Image.Image:
357 |
358 | return Image.open(self.all_images[idx]).convert("RGB")
359 |
360 | def get_label_by_idx(self, idx: int) -> int:
361 |
362 | return self.all_labels[idx]
363 |
364 | def get_prompt_by_idx(self, idx: int) -> str:
365 | # randomly choose from class name or description
366 |
367 | if self.use_placeholder:
368 | content = (
369 | self.name2placeholder[self.label2class[self.get_label_by_idx(idx)]]
370 | + f"{self.super_class_name}"
371 | )
372 | else:
373 | content = self.label2class[self.get_label_by_idx(idx)]
374 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content)
375 |
376 | return prompt
377 |
378 | def get_metadata_by_idx(self, idx: int) -> dict:
379 | return dict(name=self.class_names[self.all_labels[idx]])
380 |
381 |
382 | if __name__ == "__main__":
383 |
384 | ds2 = StanfordDogDataset()
385 | # ds = StanfordDogDatasetForT2I()
386 | print(ds2.num_classes, len(ds2))
387 | print(ds2.get_metadata_by_idx(0))
388 |
--------------------------------------------------------------------------------
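The train/test selection in StanfordDogDatasetForT2I is a seeded per-class permutation followed by np.array_split at the midpoint; a small sketch with made-up sizes:

    import numpy as np

    rng = np.random.default_rng(0)
    ids = rng.permutation(10)                # shuffled indices for one class
    train_ids, test_ids = np.array_split(ids, [int(0.5 * len(ids))])
    print(len(train_ids), len(test_ids))     # 5 5
    # examples_per_class (when > 0) then truncates the chosen split:
    print(train_ids[:3])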
/dataset/instance/food.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from typing import Any, Dict, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as transforms
8 | from datasets import load_dataset
9 | from PIL import Image
10 |
11 | from dataset.base import HugFewShotDataset
12 | from dataset.template import IMAGENET_TEMPLATES_TINY
13 |
14 | SUPER_CLASS_NAME = "food"
15 | HUG_LOCAL_IMAGE_TRAIN_DIR = r"Multimodal-Fatima/Food101_train"
16 | HUG_LOCAL_IMAGE_TEST_DIR = r"Multimodal-Fatima/Food101_test"
17 |
18 |
19 | class FoodHugDataset(HugFewShotDataset):
20 | super_class_name = SUPER_CLASS_NAME
21 |
22 | def __init__(
23 | self,
24 | *args,
25 | split: str = "train",
26 | seed: int = 0,
27 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
28 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR,
29 | examples_per_class: int = -1,
30 | synthetic_probability: float = 0.5,
31 | return_onehot: bool = False,
32 | soft_scaler: float = 0.9,
33 | synthetic_dir: str = None,
34 | image_size: int = 512,
35 | crop_size: int = 448,
36 | **kwargs,
37 | ):
38 |
39 | super(FoodHugDataset, self).__init__(
40 | *args,
41 | split=split,
42 | examples_per_class=examples_per_class,
43 | synthetic_probability=synthetic_probability,
44 | return_onehot=return_onehot,
45 | soft_scaler=soft_scaler,
46 | synthetic_dir=synthetic_dir,
47 | image_size=image_size,
48 | crop_size=crop_size,
49 | **kwargs,
50 | )
51 |
52 | if split == "train":
53 | dataset = load_dataset(image_train_dir, split="train")
54 | else:
55 | dataset = load_dataset(image_test_dir, split="test")
56 |
57 | self.class_names = [
58 | name.replace("/", " ") for name in dataset.features["label"].names
59 | ]
60 |
61 | random.seed(seed)
62 | np.random.seed(seed)
63 | if examples_per_class is not None and examples_per_class > 0:
64 | all_labels = dataset["label"]
65 | label_to_indices = defaultdict(list)
66 | for i, label in enumerate(all_labels):
67 | label_to_indices[label].append(i)
68 |
69 | _all_indices = []
70 | for key, items in label_to_indices.items():
71 | try:
72 | sampled_indices = random.sample(items, examples_per_class)
73 | except ValueError:
74 | print(
75 | f"{key}: Sample larger than population or is negative, use random.choices instead"
76 | )
77 | sampled_indices = random.choices(items, k=examples_per_class)
78 |
79 | label_to_indices[key] = sampled_indices
80 | _all_indices.extend(sampled_indices)
81 | dataset = dataset.select(_all_indices)
82 |
83 | self.dataset = dataset
84 | class2label = self.dataset.features["label"]._str2int
85 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
86 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
87 | self.class_names = [
88 | name.replace("/", " ") for name in dataset.features["label"].names
89 | ]
90 | self.num_classes = len(self.class_names)
91 |
92 | self.label_to_indices = defaultdict(list)
93 | for i, label in enumerate(self.dataset["label"]):
94 | self.label_to_indices[label].append(i)
95 |
96 | def __len__(self):
97 |
98 | return len(self.dataset)
99 |
100 | def get_image_by_idx(self, idx: int) -> Image.Image:
101 |
102 | return self.dataset[idx]["image"].convert("RGB")
103 |
104 | def get_label_by_idx(self, idx: int) -> int:
105 |
106 | return self.dataset[idx]["label"]
107 |
108 | def get_metadata_by_idx(self, idx: int) -> dict:
109 |
110 | return dict(
111 | name=self.label2class[self.get_label_by_idx(idx)],
112 | super_class=self.super_class_name,
113 | )
114 |
115 |
116 | class FoodHugDatasetForT2I(torch.utils.data.Dataset):
117 | super_class_name = SUPER_CLASS_NAME
118 |
119 | def __init__(
120 | self,
121 | *args,
122 | split: str = "train",
123 | seed: int = 0,
124 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
125 | max_train_samples: int = -1,
126 | class_prompts_ratio: float = 0.5,
127 | resolution: int = 512,
128 | center_crop: bool = False,
129 | random_flip: bool = False,
130 | use_placeholder: bool = False,
131 | **kwargs,
132 | ):
133 |
134 | super().__init__()
135 |
136 | dataset = load_dataset(image_train_dir, split="train")
137 |
138 | random.seed(seed)
139 | np.random.seed(seed)
140 | if max_train_samples is not None and max_train_samples > 0:
141 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples))
142 |
143 | self.dataset = dataset
144 | class2label = self.dataset.features["label"]._str2int
145 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
146 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
147 | self.class_names = [
148 | name.replace("/", " ") for name in dataset.features["label"].names
149 | ]
150 | self.num_classes = len(self.class_names)
151 | self.use_placeholder = use_placeholder
152 | self.name2placeholder = {}
153 | self.placeholder2name = {}
154 | self.label_to_indices = defaultdict(list)
155 | for i, label in enumerate(self.dataset["label"]):
156 | self.label_to_indices[label].append(i)
157 |
158 | self.transform = transforms.Compose(
159 | [
160 | transforms.Resize(
161 | resolution, interpolation=transforms.InterpolationMode.BILINEAR
162 | ),
163 | (
164 | transforms.CenterCrop(resolution)
165 | if center_crop
166 | else transforms.RandomCrop(resolution)
167 | ),
168 | (
169 | transforms.RandomHorizontalFlip()
170 | if random_flip
171 | else transforms.Lambda(lambda x: x)
172 | ),
173 | transforms.ToTensor(),
174 | transforms.Normalize([0.5], [0.5]),
175 | ]
176 | )
177 |
178 | def __len__(self):
179 |
180 | return len(self.dataset)
181 |
182 | def __getitem__(self, idx: int) -> dict:
183 |
184 | image = self.get_image_by_idx(idx)
185 | prompt = self.get_prompt_by_idx(idx)
186 | # label = self.get_label_by_idx(idx)
187 |
188 | return dict(pixel_values=self.transform(image), caption=prompt)
189 |
190 | def get_image_by_idx(self, idx: int) -> Image.Image:
191 |
192 | return self.dataset[idx]["image"].convert("RGB")
193 |
194 | def get_label_by_idx(self, idx: int) -> int:
195 |
196 | return self.dataset[idx]["label"]
197 |
198 | def get_prompt_by_idx(self, idx: int) -> str:
199 | # randomly choose from class name or description
200 | if self.use_placeholder:
201 | content = (
202 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]]
203 | + f" {self.super_class_name}"
204 | )
205 | else:
206 | content = self.label2class[self.dataset[idx]["label"]]
207 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content)
208 |
209 | return prompt
210 |
211 | def get_metadata_by_idx(self, idx: int) -> dict:
212 |
213 | return dict(name=self.label2class[self.get_label_by_idx(idx)])
214 |
215 |
216 | if __name__ == "__main__":
217 | ds_train = FoodHugDataset()
218 | print(ds_train[0])
219 |
--------------------------------------------------------------------------------
/dataset/instance/pascal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from collections import defaultdict
4 | from typing import Any, Dict, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | import torchvision.transforms as transforms
9 | from datasets import load_dataset
10 | from PIL import Image
11 |
12 | from dataset.base import HugFewShotDataset
13 | from dataset.template import IMAGENET_TEMPLATES_TINY
14 |
15 | PASCAL_DIR = "/data/zhicai/datasets/VOCdevkit/VOC2012/"
16 |
17 | TRAIN_IMAGE_SET = os.path.join(PASCAL_DIR, "ImageSets/Segmentation/train.txt")
18 | VAL_IMAGE_SET = os.path.join(PASCAL_DIR, "ImageSets/Segmentation/val.txt")
19 | DEFAULT_IMAGE_DIR = os.path.join(PASCAL_DIR, "JPEGImages")
20 | DEFAULT_LABEL_DIR = os.path.join(PASCAL_DIR, "SegmentationClass")
21 | DEFAULT_INSTANCE_DIR = os.path.join(PASCAL_DIR, "SegmentationObject")
22 |
23 | SUPER_CLASS_NAME = ""
24 | CLASS_NAME = [
25 | "airplane",
26 | "bicycle",
27 | "bird",
28 | "boat",
29 | "bottle",
30 | "bus",
31 | "car",
32 | "cat",
33 | "chair",
34 | "cow",
35 | "dining table",
36 | "dog",
37 | "horse",
38 | "motorcycle",
39 | "person",
40 | "potted plant",
41 | "sheep",
42 | "sofa",
43 | "train",
44 | "television",
45 | ]
46 |
47 |
48 | class PascalDataset(HugFewShotDataset):
49 |
50 | class_names = CLASS_NAME
51 | num_classes: int = len(class_names)
52 | super_class_name = SUPER_CLASS_NAME
53 |
54 | def __init__(
55 | self,
56 | *args,
57 | split: str = "train",
58 | seed: int = 0,
59 | image_dir: str = DEFAULT_IMAGE_DIR,
60 | examples_per_class: int = None,
61 | synthetic_probability: float = 0.5,
62 | return_onehot: bool = False,
63 | soft_scaler: float = 0.9,
64 | synthetic_dir: str = None,
65 | image_size: int = 512,
66 | crop_size: int = 448,
67 | **kwargs,
68 | ):
69 |
70 | super().__init__(
71 | *args,
72 | split=split,
73 | examples_per_class=examples_per_class,
74 | synthetic_probability=synthetic_probability,
75 | return_onehot=return_onehot,
76 | soft_scaler=soft_scaler,
77 | synthetic_dir=synthetic_dir,
78 | image_size=image_size,
79 | crop_size=crop_size,
80 | **kwargs,
81 | )
82 |
83 | image_set = {"train": TRAIN_IMAGE_SET, "val": VAL_IMAGE_SET}[split]
84 |
85 | with open(image_set, "r") as f:
86 | image_set_lines = [x.strip() for x in f.readlines()]
87 |
88 | class_to_images = defaultdict(list)
89 | class_to_annotations = defaultdict(list)
90 |
91 | for image_id in image_set_lines:
92 |
93 | labels = os.path.join(DEFAULT_LABEL_DIR, image_id + ".png")
94 | instances = os.path.join(DEFAULT_INSTANCE_DIR, image_id + ".png")
95 |
96 | labels = np.asarray(Image.open(labels))
97 | instances = np.asarray(Image.open(instances))
98 |
99 | instance_ids, pixel_loc, counts = np.unique(
100 | instances, return_index=True, return_counts=True
101 | )
102 |
103 | counts[0] = counts[-1] = 0  # zero out background (0) and the void border (255)
104 |
105 | argmax_index = counts.argmax()
106 |
107 | mask = np.equal(instances, instance_ids[argmax_index])
108 | class_name = self.class_names[labels.flat[pixel_loc[argmax_index]] - 1]
109 |
110 | class_to_images[class_name].append(
111 | os.path.join(image_dir, image_id + ".jpg")
112 | )
113 | class_to_annotations[class_name].append(dict(mask=mask))
114 |
115 | rng = np.random.default_rng(seed)
116 | class_to_ids = {
117 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names
118 | }
119 |
120 | if examples_per_class is not None and examples_per_class > 0:
121 | class_to_ids = {
122 | key: ids[:examples_per_class] for key, ids in class_to_ids.items()
123 | }
124 |
125 | self.class_to_images = {
126 | key: [class_to_images[key][i] for i in ids]
127 | for key, ids in class_to_ids.items()
128 | }
129 |
130 | self.class_to_annotations = {
131 | key: [class_to_annotations[key][i] for i in ids]
132 | for key, ids in class_to_ids.items()
133 | }
134 |
135 | self.class2label = {key: i for i, key in enumerate(self.class_names)}
136 | self.label2class = {v: k for k, v in self.class2label.items()}
137 | self.all_images = sum(
138 | [self.class_to_images[key] for key in self.class_names], []
139 | )
140 | self.all_labels = [
141 | i
142 | for i, key in enumerate(self.class_names)
143 | for _ in self.class_to_images[key]
144 | ]
145 |
146 | self.label_to_indices = defaultdict(list)
147 | for i, label in enumerate(self.all_labels):
148 | self.label_to_indices[label].append(i)
149 |
150 | def __len__(self):
151 |
152 | return len(self.all_images)
153 |
154 | def get_image_by_idx(self, idx: int) -> Image.Image:
155 |
156 | return Image.open(self.all_images[idx]).convert("RGB")
157 |
158 | def get_label_by_idx(self, idx: int) -> int:
159 |
160 | return self.all_labels[idx]
161 |
162 | def get_metadata_by_idx(self, idx: int) -> dict:
163 |
164 | return dict(
165 | name=self.class_names[self.all_labels[idx]],
166 | super_class=self.super_class_name,
167 | )
168 |
169 |
170 | class PascalDatasetForT2I(torch.utils.data.Dataset):
171 | super_class_name = SUPER_CLASS_NAME
172 | class_names = CLASS_NAME
173 |
174 | def __init__(
175 | self,
176 | *args,
177 | split: str = "train",
178 | seed: int = 0,
179 | image_dir: str = DEFAULT_IMAGE_DIR,
180 | max_train_samples: int = -1,
181 | class_prompts_ratio: float = 0.5,
182 | resolution: int = 512,
183 | center_crop: bool = False,
184 | random_flip: bool = False,
185 | use_placeholder: bool = False,
186 | examples_per_class: int = -1,
187 | **kwargs,
188 | ):
189 |
190 | super().__init__()
191 | image_set = {"train": TRAIN_IMAGE_SET, "val": VAL_IMAGE_SET}[split]
192 |
193 | with open(image_set, "r") as f:
194 | image_set_lines = [x.strip() for x in f.readlines()]
195 |
196 | class_to_images = defaultdict(list)
197 | class_to_annotations = defaultdict(list)
198 |
199 | for image_id in image_set_lines:
200 |
201 | labels = os.path.join(DEFAULT_LABEL_DIR, image_id + ".png")
202 | instances = os.path.join(DEFAULT_INSTANCE_DIR, image_id + ".png")
203 |
204 | labels = np.asarray(Image.open(labels))
205 | instances = np.asarray(Image.open(instances))
206 |
207 | instance_ids, pixel_loc, counts = np.unique(
208 | instances, return_index=True, return_counts=True
209 | )
210 |
211 | counts[0] = counts[-1] = 0  # zero out background (0) and the void border (255)
212 |
213 | argmax_index = counts.argmax()
214 |
215 | mask = np.equal(instances, instance_ids[argmax_index])
216 | class_name = self.class_names[labels.flat[pixel_loc[argmax_index]] - 1]
217 |
218 | class_to_images[class_name].append(
219 | os.path.join(image_dir, image_id + ".jpg")
220 | )
221 | class_to_annotations[class_name].append(dict(mask=mask))
222 |
223 | rng = np.random.default_rng(seed)
224 | class_to_ids = {
225 | key: rng.permutation(len(class_to_images[key])) for key in self.class_names
226 | }
227 |
228 | if examples_per_class is not None and examples_per_class > 0:
229 | class_to_ids = {
230 | key: ids[:examples_per_class] for key, ids in class_to_ids.items()
231 | }
232 |
233 | self.class_to_images = {
234 | key: [class_to_images[key][i] for i in ids]
235 | for key, ids in class_to_ids.items()
236 | }
237 |
238 | self.class_to_annotations = {
239 | key: [class_to_annotations[key][i] for i in ids]
240 | for key, ids in class_to_ids.items()
241 | }
242 |
243 | self.class2label = {key: i for i, key in enumerate(self.class_names)}
244 | self.label2class = {v: k for k, v in self.class2label.items()}
245 | self.all_images = sum(
246 | [self.class_to_images[key] for key in self.class_names], []
247 | )
248 | self.all_labels = [
249 | i
250 | for i, key in enumerate(self.class_names)
251 | for _ in self.class_to_images[key]
252 | ]
253 |
254 | self.label_to_indices = defaultdict(list)
255 | for i, label in enumerate(self.all_labels):
256 | self.label_to_indices[label].append(i)
257 | self.num_classes = len(self.class_names)
258 | self.class_prompts_ratio = class_prompts_ratio
259 | self.use_placeholder = use_placeholder
260 | self.name2placeholder = {}
261 | self.placeholder2name = {}
262 |
263 | self.transform = transforms.Compose(
264 | [
265 | transforms.Resize(
266 | resolution, interpolation=transforms.InterpolationMode.BILINEAR
267 | ),
268 | (
269 | transforms.CenterCrop(resolution)
270 | if center_crop
271 | else transforms.RandomCrop(resolution)
272 | ),
273 | (
274 | transforms.RandomHorizontalFlip()
275 | if random_flip
276 | else transforms.Lambda(lambda x: x)
277 | ),
278 | transforms.ToTensor(),
279 | transforms.Normalize([0.5], [0.5]),
280 | ]
281 | )
282 |
283 | def __len__(self):
284 |
285 | return len(self.all_images)
286 |
287 | def __getitem__(self, idx: int) -> dict:
288 |
289 | image = self.get_image_by_idx(idx)
290 | prompt = self.get_prompt_by_idx(idx)
291 | # label = self.get_label_by_idx(idx)
292 |
293 | return dict(pixel_values=self.transform(image), caption=prompt)
294 |
295 | def get_image_by_idx(self, idx: int) -> Image.Image:
296 |
297 | return Image.open(self.all_images[idx]).convert("RGB")
298 |
299 | def get_label_by_idx(self, idx: int) -> int:
300 |
301 | return self.all_labels[idx]
302 |
303 | def get_prompt_by_idx(self, idx: int) -> str:
304 | # randomly choose from class name or description
305 |
306 | if self.use_placeholder:
307 | content = (
308 | self.name2placeholder[self.label2class[self.get_label_by_idx(idx)]]
309 | + f"{self.super_class_name}"
310 | )
311 | else:
312 | content = self.label2class[self.get_label_by_idx(idx)]
313 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content)
314 |
315 | return prompt
316 |
317 | def get_metadata_by_idx(self, idx: int) -> dict:
318 | return dict(name=self.class_names[self.all_labels[idx]])
319 |
320 |
321 | if __name__ == "__main__":
322 | ds = PascalDataset(return_onehot=True)
323 | ds2 = PascalDatasetForT2I()
324 | print(ds.class_to_images == ds2.class_to_images)
325 | print("l")
326 |
--------------------------------------------------------------------------------
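PascalDataset keeps, for every image, only the largest annotated instance; below is a toy sketch of that np.unique bookkeeping (VOC convention: 0 is background, 255 the void border), illustrative rather than repository code:

    import numpy as np

    instances = np.array([[0,   0, 1, 1, 1],
                          [0,   2, 2, 1, 1],
                          [255, 2, 2, 1, 255]])
    instance_ids, pixel_loc, counts = np.unique(
        instances, return_index=True, return_counts=True
    )
    counts[0] = counts[-1] = 0           # drop background and void border
    argmax_index = counts.argmax()       # largest remaining instance
    mask = np.equal(instances, instance_ids[argmax_index])
    print(instance_ids[argmax_index], mask.sum())  # 1 6
    # pixel_loc[argmax_index] indexes the flattened class-label mask to recover
    # the class name, exactly as done in __init__ above.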
/dataset/instance/pet.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from typing import Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms as transforms
8 | from datasets import load_dataset
9 | from PIL import Image
10 |
11 | from dataset.base import HugFewShotDataset
12 | from dataset.template import IMAGENET_TEMPLATES_TINY
13 |
14 | SUPER_CLASS_NAME = "animal"
15 | HUG_LOCAL_IMAGE_TRAIN_DIR = r"jonathancui/oxford-pets"
16 | HUG_LOCAL_IMAGE_TEST_DIR = r"jonathancui/oxford-pets"
17 | DEFAULT_IMAGE_LOCAL_DIR = r"/home/zhicai/.cache/huggingface/local/pet"
18 | HUB_LOCAL_DIR = "/home/zhicai/.cache/huggingface/datasets/pcuenq___oxford-pets/default-7fcafd63f4da1c6c"
19 |
20 |
21 | class PetHugDataset(HugFewShotDataset):
22 |
23 | super_class_name = SUPER_CLASS_NAME
24 |
25 | def __init__(
26 | self,
27 | *args,
28 | split: str = "train",
29 | seed: int = 0,
30 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
31 | image_test_dir: str = HUG_LOCAL_IMAGE_TEST_DIR,
32 | examples_per_class: int = -1,
33 | synthetic_probability: float = 0.5,
34 | return_onehot: bool = False,
35 | soft_scaler: float = 0.9,
36 | synthetic_dir: str = None,
37 | image_size: int = 512,
38 | crop_size: int = 448,
39 | **kwargs,
40 | ):
41 |
42 | super(PetHugDataset, self).__init__(
43 | *args,
44 | split=split,
45 | examples_per_class=examples_per_class,
46 | synthetic_probability=synthetic_probability,
47 | return_onehot=return_onehot,
48 | soft_scaler=soft_scaler,
49 | synthetic_dir=synthetic_dir,
50 | image_size=image_size,
51 | crop_size=crop_size,
52 | **kwargs,
53 | )
54 |
55 | if split == "train":
56 | dataset = load_dataset(HUB_LOCAL_DIR, split="train")
57 | else:
58 | dataset = load_dataset(HUB_LOCAL_DIR, split="test")
59 |
60 | random.seed(seed)
61 | np.random.seed(seed)
62 | if examples_per_class is not None and examples_per_class > 0:
63 | all_labels = dataset["label"]
64 | label_to_indices = defaultdict(list)
65 | for i, label in enumerate(all_labels):
66 | label_to_indices[label].append(i)
67 |
68 | _all_indices = []
69 | for key, items in label_to_indices.items():
70 | try:
71 | sampled_indices = random.sample(items, examples_per_class)
72 | except ValueError:
73 | print(
74 | f"{key}: Sample larger than population or is negative, use random.choices instead"
75 | )
76 | sampled_indices = random.choices(items, k=examples_per_class)
77 |
78 | label_to_indices[key] = sampled_indices
79 | _all_indices.extend(sampled_indices)
80 | dataset = dataset.select(_all_indices)
81 |
82 | self.dataset = dataset
83 | class2label = self.dataset.features["label"]._str2int
84 | self.class2label = {
85 | k.replace("/", " ").replace("_", " "): v for k, v in class2label.items()
86 | }
87 | self.label2class = {
88 | v: k.replace("/", " ").replace("_", " ") for k, v in class2label.items()
89 | }
90 | self.class_names = [
91 | name.replace("/", " ").replace("_", " ")
92 | for name in dataset.features["label"].names
93 | ]
94 | self.num_classes = len(self.class_names)
95 |
96 | self.label_to_indices = defaultdict(list)
97 | for i, label in enumerate(self.dataset["label"]):
98 | self.label_to_indices[label].append(i)
99 |
100 | def __len__(self):
101 |
102 | return len(self.dataset)
103 |
104 | def get_image_by_idx(self, idx: int) -> Image.Image:
105 |
106 | return self.dataset[idx]["image"].convert("RGB")
107 |
108 | def get_label_by_idx(self, idx: int) -> int:
109 |
110 | return self.dataset[idx]["label"]
111 |
112 | def get_metadata_by_idx(self, idx: int) -> dict:
113 |
114 | return dict(
115 | name=self.label2class[self.get_label_by_idx(idx)],
116 | super_class=self.super_class_name,
117 | )
118 |
119 |
120 | class PetHugDatasetForT2I(torch.utils.data.Dataset):
121 |
122 | super_class_name = SUPER_CLASS_NAME
123 |
124 | def __init__(
125 | self,
126 | *args,
127 | split: str = "train",
128 | seed: int = 0,
129 | image_train_dir: str = HUG_LOCAL_IMAGE_TRAIN_DIR,
130 | max_train_samples: int = -1,
131 | class_prompts_ratio: float = 0.5,
132 | resolution: int = 512,
133 | center_crop: bool = False,
134 | random_flip: bool = False,
135 | use_placeholder: bool = False,
136 | **kwargs,
137 | ):
138 |
139 | super().__init__()
140 |
141 | dataset = load_dataset(HUB_LOCAL_DIR, split="train")
142 | dataset = dataset.class_encode_column("label")
143 |
144 | random.seed(seed)
145 | np.random.seed(seed)
146 | if max_train_samples is not None and max_train_samples > 0:
147 | dataset = dataset.shuffle(seed=seed).select(range(max_train_samples))
148 |
149 | self.dataset = dataset
150 | class2label = self.dataset.features["label"]._str2int
151 | self.class2label = {k.replace("/", " "): v for k, v in class2label.items()}
152 | self.label2class = {v: k.replace("/", " ") for k, v in class2label.items()}
153 | self.class_names = [
154 | name.replace("/", " ") for name in dataset.features["label"].names
155 | ]
156 | self.num_classes = len(self.class_names)
157 | self.use_placeholder = use_placeholder
158 | self.name2placeholder = {}
159 | self.placeholder2name = {}
160 | self.label_to_indices = defaultdict(list)
161 | for i, label in enumerate(self.dataset["label"]):
162 | self.label_to_indices[label].append(i)
163 |
164 | self.transform = transforms.Compose(
165 | [
166 | transforms.Resize(
167 | resolution, interpolation=transforms.InterpolationMode.BILINEAR
168 | ),
169 | (
170 | transforms.CenterCrop(resolution)
171 | if center_crop
172 | else transforms.RandomCrop(resolution)
173 | ),
174 | (
175 | transforms.RandomHorizontalFlip()
176 | if random_flip
177 | else transforms.Lambda(lambda x: x)
178 | ),
179 | transforms.ToTensor(),
180 | transforms.Normalize([0.5], [0.5]),
181 | ]
182 | )
183 |
184 | def __len__(self):
185 |
186 | return len(self.dataset)
187 |
188 | def __getitem__(self, idx: int) -> dict:
189 |
190 | image = self.get_image_by_idx(idx)
191 | prompt = self.get_prompt_by_idx(idx)
192 | # label = self.get_label_by_idx(idx)
193 |
194 | return dict(pixel_values=self.transform(image), caption=prompt)
195 |
196 | def get_image_by_idx(self, idx: int) -> Image.Image:
197 |
198 | return self.dataset[idx]["image"].convert("RGB")
199 |
200 | def get_label_by_idx(self, idx: int) -> int:
201 |
202 | return self.dataset[idx]["label"]
203 |
204 | def get_prompt_by_idx(self, idx: int) -> str:
205 | # randomly choose from class name or description
206 | if self.use_placeholder:
207 | content = (
208 | self.name2placeholder[self.label2class[self.dataset[idx]["label"]]]
209 | + f" {self.super_class_name}"
210 | )
211 | else:
212 | content = self.label2class[self.dataset[idx]["label"]]
213 | prompt = random.choice(IMAGENET_TEMPLATES_TINY).format(content)
214 |
215 | return prompt
216 |
217 | def get_metadata_by_idx(self, idx: int) -> dict:
218 |
219 | return dict(name=self.label2class[self.get_label_by_idx(idx)])
220 |
--------------------------------------------------------------------------------
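PetHugDatasetForT2I calls class_encode_column("label"), presumably because the underlying split stores labels as plain strings rather than a ClassLabel feature; the call converts the column into integer ids with an attached name list, which the class then reads. A tiny in-memory illustration (toy data, not the Oxford-IIIT Pet set):

    from datasets import Dataset

    toy = Dataset.from_dict({"label": ["Abyssinian", "beagle", "Abyssinian"]})
    toy = toy.class_encode_column("label")   # str column -> ClassLabel ints
    print(toy.features["label"].names)       # e.g. ['Abyssinian', 'beagle']
    print(toy["label"])                      # e.g. [0, 1, 0]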
/dataset/instance/waterbird.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 |
10 | DATA_DIR = "/data/zhicai/datasets/waterbird_complete95_forest2water2/"
11 |
12 |
13 | def onehot(size: int, target: int):
14 | vec = torch.zeros(size, dtype=torch.float32)
15 | vec[target] = 1.0
16 | return vec
17 |
18 |
19 | class WaterBird(Dataset):
20 | def __init__(self, split=2, image_size=256, crop_size=224, return_onehot=False):
21 | self.root_dir = DATA_DIR
22 | dataframe = pd.read_csv(os.path.join(self.root_dir, "metadata.csv"))
23 | dataframe = dataframe[dataframe["split"] == split].reset_index()
24 | self.labels = list(
25 | map(lambda x: int(x.split(".")[0]) - 1, dataframe["img_filename"])
26 | )
27 | self.dataframe = dataframe
28 | self.image_paths = dataframe["img_filename"]
29 | self.groups = dataframe.apply(
30 | lambda row: f"{row['y']}{row['place']}", axis=1
31 | ).tolist()
32 | self.return_onehot = return_onehot
33 | self.num_classes = len(set(self.labels))
34 | self.transform = transforms.Compose(
35 | [
36 | transforms.Resize((image_size, image_size)),
37 | transforms.CenterCrop((crop_size, crop_size)),
38 | transforms.ToTensor(),
39 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
40 | ]
41 | )
42 |
43 | def __len__(self):
44 | return len(self.dataframe)
45 |
46 | def __getitem__(self, idx):
47 |
48 | img_path = self.image_paths[idx]
49 | # Load image
50 | img = Image.open(os.path.join(self.root_dir, img_path)).convert("RGB")
51 |
52 | label = self.labels[idx]
53 | group = self.groups[idx]
54 | if self.transform:
55 | img = self.transform(img)
56 | if self.return_onehot:
57 | if isinstance(label, (int, np.int64)):
58 | label = onehot(self.num_classes, label)
59 | return img, label, group
60 |
--------------------------------------------------------------------------------
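In WaterBird, each item's group is the string concatenation of the class label y and the background place, which (in the usual Waterbirds convention, y=1 waterbird, place=1 water) yields the four majority/minority groups; onehot produces the soft-label vector used when return_onehot=True. A short sketch:

    import torch

    def onehot(size: int, target: int):
        vec = torch.zeros(size, dtype=torch.float32)
        vec[target] = 1.0
        return vec

    y, place = 1, 0              # waterbird photographed on a land background
    group = f"{y}{place}"
    print(group, onehot(2, y))   # "10" tensor([0., 1.])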
/dataset/template.py:
--------------------------------------------------------------------------------
1 | IMAGENET_TEMPLATES_TINY = [
2 | "a photo of a {}",
3 | ]
4 |
5 | IMAGENET_TEMPLATES_SMALL = [
6 | "a photo of a {}",
7 | "a rendering of a {}",
8 | "a cropped photo of the {}",
9 | "the photo of a {}",
10 | "a photo of a clean {}",
11 | "a photo of a dirty {}",
12 | "a dark photo of the {}",
13 | "a photo of my {}",
14 | "a photo of the cool {}",
15 | "a close-up photo of a {}",
16 | "a bright photo of the {}",
17 | "a cropped photo of a {}",
18 | "a photo of the {}",
19 | "a good photo of the {}",
20 | "a photo of one {}",
21 | "a close-up photo of the {}",
22 | "a rendition of the {}",
23 | "a photo of the clean {}",
24 | "a rendition of a {}",
25 | "a photo of a nice {}",
26 | "a good photo of a {}",
27 | "a photo of the nice {}",
28 | "a photo of the small {}",
29 | "a photo of the weird {}",
30 | "a photo of the large {}",
31 | "a photo of a cool {}",
32 | "a photo of a small {}",
33 | ]
34 |
35 | IMAGENET_STYLE_TEMPLATES_SMALL = [
36 | "a painting in the style of {}",
37 | "a rendering in the style of {}",
38 | "a cropped painting in the style of {}",
39 | "the painting in the style of {}",
40 | "a clean painting in the style of {}",
41 | "a dirty painting in the style of {}",
42 | "a dark painting in the style of {}",
43 | "a picture in the style of {}",
44 | "a cool painting in the style of {}",
45 | "a close-up painting in the style of {}",
46 | "a bright painting in the style of {}",
47 | "a cropped painting in the style of {}",
48 | "a good painting in the style of {}",
49 | "a close-up painting in the style of {}",
50 | "a rendition in the style of {}",
51 | "a nice painting in the style of {}",
52 | "a small painting in the style of {}",
53 | "a weird painting in the style of {}",
54 | "a large painting in the style of {}",
55 | ]
56 |
--------------------------------------------------------------------------------
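These templates are consumed with str.format in the datasets' get_prompt_by_idx; a two-line sketch (the "<new1>" placeholder token is only illustrative of the use_placeholder branch, not a value defined in this file):

    import random

    from dataset.template import IMAGENET_TEMPLATES_TINY

    print(random.choice(IMAGENET_TEMPLATES_TINY).format("golden retriever"))
    # -> "a photo of a golden retriever"
    print(random.choice(IMAGENET_TEMPLATES_TINY).format("<new1> dog"))
    # -> "a photo of a <new1> dog"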
/downstream_tasks/imb_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhicaiwww/Diff-Mix/a81337b1492bcc7f7a8a61921836e94191a3a0ef/downstream_tasks/imb_utils/__init__.py
--------------------------------------------------------------------------------
/downstream_tasks/imb_utils/autoaug.py:
--------------------------------------------------------------------------------
1 | """
2 | Code from: https://github.com/dvlab-research/Parametric-Contrastive-Learning/blob/main/LT/autoaug.py
3 | """
4 |
5 | from PIL import Image, ImageEnhance, ImageOps
6 | import numpy as np
7 | import random
8 | import torch
9 |
10 |
11 | class Cutout(object):
12 | def __init__(self, n_holes, length):
13 | self.n_holes = n_holes
14 | self.length = length
15 |
16 | def __call__(self, img):
17 | h = img.size(1)
18 | w = img.size(2)
19 |
20 | mask = np.ones((h, w), np.float32)
21 |
22 | for n in range(self.n_holes):
23 | y = np.random.randint(h)
24 | x = np.random.randint(w)
25 |
26 | y1 = np.clip(y - self.length // 2, 0, h)
27 | y2 = np.clip(y + self.length // 2, 0, h)
28 | x1 = np.clip(x - self.length // 2, 0, w)
29 | x2 = np.clip(x + self.length // 2, 0, w)
30 |
31 | mask[y1: y2, x1: x2] = 0.
32 |
33 | mask = torch.from_numpy(mask)
34 | mask = mask.expand_as(img)
35 | img = img * mask
36 |
37 | return img
38 |
39 | class ImageNetPolicy(object):
40 | """ Randomly choose one of the best 24 Sub-policies on ImageNet.
41 | Example:
42 | >>> policy = ImageNetPolicy()
43 | >>> transformed = policy(image)
44 | Example as a PyTorch Transform:
45 | >>> transform=transforms.Compose([
46 | >>> transforms.Resize(256),
47 | >>> ImageNetPolicy(),
48 | >>> transforms.ToTensor()])
49 | """
50 | def __init__(self, fillcolor=(128, 128, 128)):
51 | self.policies = [
52 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
53 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
54 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
55 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
56 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
57 |
58 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
59 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
60 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
61 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
62 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
63 |
64 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
65 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
66 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
67 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
68 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
69 |
70 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
71 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
72 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
73 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
74 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
75 |
76 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
77 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
78 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
79 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
80 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
81 | ]
82 |
83 |
84 | def __call__(self, img):
85 | policy_idx = random.randint(0, len(self.policies) - 1)
86 | return self.policies[policy_idx](img)
87 |
88 | def __repr__(self):
89 | return "AutoAugment ImageNet Policy"
90 |
91 |
92 | class CIFAR10Policy(object):
93 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
94 | Example:
95 | >>> policy = CIFAR10Policy()
96 | >>> transformed = policy(image)
97 | Example as a PyTorch Transform:
98 | >>> transform=transforms.Compose([
99 | >>> transforms.Resize(256),
100 | >>> CIFAR10Policy(),
101 | >>> transforms.ToTensor()])
102 | """
103 | def __init__(self, fillcolor=(128, 128, 128)):
104 | self.policies = [
105 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
106 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
107 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
108 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
109 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
110 |
111 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
112 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
113 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
114 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
115 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
116 |
117 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
118 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
119 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
120 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
121 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
122 |
123 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
124 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
125 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
126 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
127 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
128 |
129 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
130 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
131 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
132 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
133 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
134 | ]
135 |
136 |
137 | def __call__(self, img):
138 | policy_idx = random.randint(0, len(self.policies) - 1)
139 | return self.policies[policy_idx](img)
140 |
141 | def __repr__(self):
142 | return "AutoAugment CIFAR10 Policy"
143 |
144 |
145 | class SVHNPolicy(object):
146 | """ Randomly choose one of the best 25 Sub-policies on SVHN.
147 | Example:
148 | >>> policy = SVHNPolicy()
149 | >>> transformed = policy(image)
150 | Example as a PyTorch Transform:
151 | >>> transform=transforms.Compose([
152 | >>> transforms.Resize(256),
153 | >>> SVHNPolicy(),
154 | >>> transforms.ToTensor()])
155 | """
156 | def __init__(self, fillcolor=(128, 128, 128)):
157 | self.policies = [
158 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
159 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
160 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
161 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
162 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
163 |
164 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
165 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
166 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
167 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
168 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
169 |
170 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
171 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
172 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
173 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
174 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
175 |
176 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
177 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
178 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
179 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
180 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
181 |
182 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
183 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
184 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
185 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
186 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
187 | ]
188 |
189 |
190 | def __call__(self, img):
191 | policy_idx = random.randint(0, len(self.policies) - 1)
192 | return self.policies[policy_idx](img)
193 |
194 | def __repr__(self):
195 | return "AutoAugment SVHN Policy"
196 |
197 |
198 | class SubPolicy(object):
199 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
200 | ranges = {
201 | "shearX": np.linspace(0, 0.3, 10),
202 | "shearY": np.linspace(0, 0.3, 10),
203 | "translateX": np.linspace(0, 150 / 331, 10),
204 | "translateY": np.linspace(0, 150 / 331, 10),
205 | "rotate": np.linspace(0, 30, 10),
206 | "color": np.linspace(0.0, 0.9, 10),
207 |             "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
208 | "solarize": np.linspace(256, 0, 10),
209 | "contrast": np.linspace(0.0, 0.9, 10),
210 | "sharpness": np.linspace(0.0, 0.9, 10),
211 | "brightness": np.linspace(0.0, 0.9, 10),
212 | "autocontrast": [0] * 10,
213 | "equalize": [0] * 10,
214 | "invert": [0] * 10
215 | }
216 |
217 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
218 | def rotate_with_fill(img, magnitude):
219 | rot = img.convert("RGBA").rotate(magnitude)
220 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
221 |
222 | func = {
223 | "shearX": lambda img, magnitude: img.transform(
224 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
225 | Image.BICUBIC, fillcolor=fillcolor),
226 | "shearY": lambda img, magnitude: img.transform(
227 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
228 | Image.BICUBIC, fillcolor=fillcolor),
229 | "translateX": lambda img, magnitude: img.transform(
230 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
231 | fillcolor=fillcolor),
232 | "translateY": lambda img, magnitude: img.transform(
233 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
234 | fillcolor=fillcolor),
235 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
236 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
237 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
238 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
239 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
240 | 1 + magnitude * random.choice([-1, 1])),
241 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
242 | 1 + magnitude * random.choice([-1, 1])),
243 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
244 | 1 + magnitude * random.choice([-1, 1])),
245 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
246 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
247 | "invert": lambda img, magnitude: ImageOps.invert(img)
248 | }
249 |
250 | self.p1 = p1
251 | self.operation1 = func[operation1]
252 | self.magnitude1 = ranges[operation1][magnitude_idx1]
253 | self.p2 = p2
254 | self.operation2 = func[operation2]
255 | self.magnitude2 = ranges[operation2][magnitude_idx2]
256 |
257 |
258 | def __call__(self, img):
259 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
260 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
261 | return img
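262 |
263 |
264 | # --- Illustrative usage sketch ---
265 | # A minimal example of applying the policies above directly to a PIL image;
266 | # the blank 32x32 image is just a placeholder for a real training sample.
267 | if __name__ == "__main__":
268 |     img = Image.new("RGB", (32, 32))
269 |     print(CIFAR10Policy())                # "AutoAugment CIFAR10 Policy"
270 |     out = CIFAR10Policy()(img)            # applies one randomly chosen sub-policy
271 |     print(out.size, SVHNPolicy()(img).size)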
--------------------------------------------------------------------------------
/downstream_tasks/imb_utils/moco_loader.py:
--------------------------------------------------------------------------------
1 | """https://github.com/facebookresearch/moco"""
2 |
3 | import random
4 |
5 | from PIL import ImageFilter
6 |
7 |
8 | class TwoCropsTransform:
9 | """Take two random crops of one image as the query and key."""
10 |
11 | def __init__(self, base_transform):
12 | self.base_transform = base_transform
13 |
14 | def __call__(self, x):
15 | q = self.base_transform(x)
16 | k = self.base_transform(x)
17 | return [q, k]
18 |
19 |
20 | class GaussianBlur(object):
21 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
22 |
23 | def __init__(self, sigma=[0.1, 2.0]):
24 | self.sigma = sigma
25 |
26 | def __call__(self, x):
27 | sigma = random.uniform(self.sigma[0], self.sigma[1])
28 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
29 | return x
30 |
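31 |
32 | # --- Illustrative usage sketch ---
33 | # A minimal example of combining TwoCropsTransform and GaussianBlur with
34 | # torchvision transforms, MoCo v2 style; this exact augmentation recipe is an
35 | # assumption for illustration, not necessarily the one used elsewhere in this repo.
36 | if __name__ == "__main__":
37 |     import torchvision.transforms as transforms
38 |     from PIL import Image
39 |
40 |     base_transform = transforms.Compose([
41 |         transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
42 |         transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
43 |         transforms.RandomHorizontalFlip(),
44 |         transforms.ToTensor(),
45 |     ])
46 |     two_crops = TwoCropsTransform(base_transform)
47 |     q, k = two_crops(Image.new("RGB", (256, 256)))  # two independent views of one image
48 |     print(q.shape, k.shape)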
--------------------------------------------------------------------------------
/downstream_tasks/imb_utils/randaugment.py:
--------------------------------------------------------------------------------
1 | # original code: https://github.com/dvlab-research/Parametric-Contrastive-Learning/blob/main/LT/randaugment.py
2 | """ AutoAugment and RandAugment
3 | Implementation adapted from:
4 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
5 | Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
6 | Hacked together by Ross Wightman
7 | """
8 | import random
9 | import math
10 | import re
11 | from PIL import Image, ImageOps, ImageEnhance
12 | import PIL
13 | import numpy as np
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torchvision.transforms as transforms
18 |
19 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
20 |
21 | _FILL = (128, 128, 128)
22 |
23 | # This signifies the max integer that the controller RNN could predict for the
24 | # augmentation scheme.
25 | _MAX_LEVEL = 10.
26 |
27 | _HPARAMS_DEFAULT = dict(
28 | translate_const=250,
29 | img_mean=_FILL,
30 | )
31 |
32 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
33 |
34 |
35 | def _interpolation(kwargs):
36 | interpolation = kwargs.pop('resample', Image.BILINEAR)
37 | if isinstance(interpolation, (list, tuple)):
38 | return random.choice(interpolation)
39 | else:
40 | return interpolation
41 |
42 |
43 | def _check_args_tf(kwargs):
44 | if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
45 | kwargs.pop('fillcolor')
46 | kwargs['resample'] = _interpolation(kwargs)
47 |
48 |
49 | def shear_x(img, factor, **kwargs):
50 | _check_args_tf(kwargs)
51 | return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
52 |
53 |
54 | def shear_y(img, factor, **kwargs):
55 | _check_args_tf(kwargs)
56 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
57 |
58 |
59 | def translate_x_rel(img, pct, **kwargs):
60 | pixels = pct * img.size[0]
61 | _check_args_tf(kwargs)
62 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
63 |
64 |
65 | def translate_y_rel(img, pct, **kwargs):
66 | pixels = pct * img.size[1]
67 | _check_args_tf(kwargs)
68 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
69 |
70 |
71 | def translate_x_abs(img, pixels, **kwargs):
72 | _check_args_tf(kwargs)
73 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
74 |
75 |
76 | def translate_y_abs(img, pixels, **kwargs):
77 | _check_args_tf(kwargs)
78 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
79 |
80 |
81 | def rotate(img, degrees, **kwargs):
82 | _check_args_tf(kwargs)
83 | if _PIL_VER >= (5, 2):
84 | return img.rotate(degrees, **kwargs)
85 | elif _PIL_VER >= (5, 0):
86 | w, h = img.size
87 | post_trans = (0, 0)
88 | rotn_center = (w / 2.0, h / 2.0)
89 | angle = -math.radians(degrees)
90 | matrix = [
91 | round(math.cos(angle), 15),
92 | round(math.sin(angle), 15),
93 | 0.0,
94 | round(-math.sin(angle), 15),
95 | round(math.cos(angle), 15),
96 | 0.0,
97 | ]
98 |
99 | def transform(x, y, matrix):
100 | (a, b, c, d, e, f) = matrix
101 | return a * x + b * y + c, d * x + e * y + f
102 |
103 | matrix[2], matrix[5] = transform(
104 | -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
105 | )
106 | matrix[2] += rotn_center[0]
107 | matrix[5] += rotn_center[1]
108 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
109 | else:
110 | return img.rotate(degrees, resample=kwargs['resample'])
111 |
112 |
113 | def auto_contrast(img, **__):
114 | return ImageOps.autocontrast(img)
115 |
116 |
117 | def invert(img, **__):
118 | return ImageOps.invert(img)
119 |
120 |
121 | def identity(img, **__):
122 | return img
123 |
124 |
125 | def equalize(img, **__):
126 | return ImageOps.equalize(img)
127 |
128 |
129 | def solarize(img, thresh, **__):
130 | return ImageOps.solarize(img, thresh)
131 |
132 |
133 | def solarize_add(img, add, thresh=128, **__):
134 | lut = []
135 | for i in range(256):
136 | if i < thresh:
137 | lut.append(min(255, i + add))
138 | else:
139 | lut.append(i)
140 | if img.mode in ("L", "RGB"):
141 | if img.mode == "RGB" and len(lut) == 256:
142 | lut = lut + lut + lut
143 | return img.point(lut)
144 | else:
145 | return img
146 |
147 |
148 | def posterize(img, bits_to_keep, **__):
149 | if bits_to_keep >= 8:
150 | return img
151 | return ImageOps.posterize(img, bits_to_keep)
152 |
153 |
154 | def contrast(img, factor, **__):
155 | return ImageEnhance.Contrast(img).enhance(factor)
156 |
157 |
158 | def color(img, factor, **__):
159 | return ImageEnhance.Color(img).enhance(factor)
160 |
161 |
162 | def brightness(img, factor, **__):
163 | return ImageEnhance.Brightness(img).enhance(factor)
164 |
165 |
166 | def sharpness(img, factor, **__):
167 | return ImageEnhance.Sharpness(img).enhance(factor)
168 |
169 |
170 | def _randomly_negate(v):
171 | """With 50% prob, negate the value"""
172 | return -v if random.random() > 0.5 else v
173 |
174 |
175 | def _rotate_level_to_arg(level, _hparams):
176 | # range [-30, 30]
177 | level = (level / _MAX_LEVEL) * 30.
178 | level = _randomly_negate(level)
179 | return level,
180 |
181 |
182 | def _enhance_level_to_arg(level, _hparams):
183 | # range [0.1, 1.9]
184 | return (level / _MAX_LEVEL) * 1.8 + 0.1,
185 |
186 |
187 | def _shear_level_to_arg(level, _hparams):
188 | # range [-0.3, 0.3]
189 | level = (level / _MAX_LEVEL) * 0.3
190 | level = _randomly_negate(level)
191 | return level,
192 |
193 |
194 | def _translate_abs_level_to_arg(level, hparams):
195 | translate_const = hparams['translate_const']
196 | level = (level / _MAX_LEVEL) * float(translate_const)
197 | level = _randomly_negate(level)
198 | return level,
199 |
200 |
201 | def _translate_rel_level_to_arg(level, _hparams):
202 | # range [-0.45, 0.45]
203 | level = (level / _MAX_LEVEL) * 0.45
204 | level = _randomly_negate(level)
205 | return level,
206 |
207 |
208 | def _posterize_original_level_to_arg(level, _hparams):
209 | # As per original AutoAugment paper description
210 | # range [4, 8], 'keep 4 up to 8 MSB of image'
211 | return int((level / _MAX_LEVEL) * 4) + 4,
212 |
213 |
214 | def _posterize_research_level_to_arg(level, _hparams):
215 | # As per Tensorflow models research and UDA impl
216 | # range [4, 0], 'keep 4 down to 0 MSB of original image'
217 | return 4 - int((level / _MAX_LEVEL) * 4),
218 |
219 |
220 | def _posterize_tpu_level_to_arg(level, _hparams):
221 | # As per Tensorflow TPU EfficientNet impl
222 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
223 | return int((level / _MAX_LEVEL) * 4),
224 |
225 |
226 | def _solarize_level_to_arg(level, _hparams):
227 | # range [0, 256]
228 | return int((level / _MAX_LEVEL) * 256),
229 |
230 |
231 | def _solarize_add_level_to_arg(level, _hparams):
232 | # range [0, 110]
233 | return int((level / _MAX_LEVEL) * 110),
234 |
235 |
236 | LEVEL_TO_ARG = {
237 | 'AutoContrast': None,
238 | 'Equalize': None,
239 | 'Invert': None,
240 | 'Identity': None,
241 | 'Rotate': _rotate_level_to_arg,
242 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
243 | 'PosterizeOriginal': _posterize_original_level_to_arg,
244 | 'PosterizeResearch': _posterize_research_level_to_arg,
245 | 'PosterizeTpu': _posterize_tpu_level_to_arg,
246 | 'Solarize': _solarize_level_to_arg,
247 | 'SolarizeAdd': _solarize_add_level_to_arg,
248 | 'Color': _enhance_level_to_arg,
249 | 'Contrast': _enhance_level_to_arg,
250 | 'Brightness': _enhance_level_to_arg,
251 | 'Sharpness': _enhance_level_to_arg,
252 | 'ShearX': _shear_level_to_arg,
253 | 'ShearY': _shear_level_to_arg,
254 | 'TranslateX': _translate_abs_level_to_arg,
255 | 'TranslateY': _translate_abs_level_to_arg,
256 | 'TranslateXRel': _translate_rel_level_to_arg,
257 | 'TranslateYRel': _translate_rel_level_to_arg,
258 | }
259 |
260 |
261 | NAME_TO_OP = {
262 | 'AutoContrast': auto_contrast,
263 | 'Equalize': equalize,
264 | 'Invert': invert,
265 | 'Identity': identity,
266 | 'Rotate': rotate,
267 | 'PosterizeOriginal': posterize,
268 | 'PosterizeResearch': posterize,
269 | 'PosterizeTpu': posterize,
270 | 'Solarize': solarize,
271 | 'SolarizeAdd': solarize_add,
272 | 'Color': color,
273 | 'Contrast': contrast,
274 | 'Brightness': brightness,
275 | 'Sharpness': sharpness,
276 | 'ShearX': shear_x,
277 | 'ShearY': shear_y,
278 | 'TranslateX': translate_x_abs,
279 | 'TranslateY': translate_y_abs,
280 | 'TranslateXRel': translate_x_rel,
281 | 'TranslateYRel': translate_y_rel,
282 | }
283 |
284 |
285 | class AutoAugmentOp:
286 |
287 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
288 | hparams = hparams or _HPARAMS_DEFAULT
289 | self.aug_fn = NAME_TO_OP[name]
290 | self.level_fn = LEVEL_TO_ARG[name]
291 | self.prob = prob
292 | self.magnitude = magnitude
293 | self.hparams = hparams.copy()
294 | self.kwargs = dict(
295 | fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
296 | resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
297 | )
298 |
299 | # If magnitude_std is > 0, we introduce some randomness
300 | # in the usually fixed policy and sample magnitude from a normal distribution
301 | # with mean `magnitude` and std-dev of `magnitude_std`.
302 | # NOTE This is my own hack, being tested, not in papers or reference impls.
303 | self.magnitude_std = self.hparams.get('magnitude_std', 0)
304 |
305 | def __call__(self, img):
306 | if random.random() > self.prob:
307 | return img
308 | magnitude = self.magnitude
309 | if self.magnitude_std and self.magnitude_std > 0:
310 | magnitude = random.gauss(magnitude, self.magnitude_std)
311 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
312 | level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
313 | return self.aug_fn(img, *level_args, **self.kwargs)
314 |
315 |
316 | _RAND_TRANSFORMS = [
317 | 'AutoContrast',
318 | 'Equalize',
319 | 'Invert',
320 | 'Rotate',
321 | 'PosterizeTpu',
322 | 'Solarize',
323 | 'SolarizeAdd',
324 | 'Color',
325 | 'Contrast',
326 | 'Brightness',
327 | 'Sharpness',
328 | 'ShearX',
329 | 'ShearY',
330 | 'TranslateXRel',
331 | 'TranslateYRel',
332 | #'Cutout' # FIXME I implement this as random erasing separately
333 | ]
334 |
335 | _RAND_TRANSFORMS_CMC = [
336 | 'AutoContrast',
337 | 'Identity',
338 | 'Rotate',
339 | 'Sharpness',
340 | 'ShearX',
341 | 'ShearY',
342 | 'TranslateXRel',
343 | 'TranslateYRel',
344 | #'Cutout' # FIXME I implement this as random erasing separately
345 | ]
346 |
347 |
348 | # These experimental weights are based loosely on the relative improvements mentioned in paper.
349 | # They may not result in increased performance, but could likely be tuned to do so.
350 | _RAND_CHOICE_WEIGHTS_0 = {
351 | 'Rotate': 0.3,
352 | 'ShearX': 0.2,
353 | 'ShearY': 0.2,
354 | 'TranslateXRel': 0.1,
355 | 'TranslateYRel': 0.1,
356 | 'Color': .025,
357 | 'Sharpness': 0.025,
358 | 'AutoContrast': 0.025,
359 | 'Solarize': .005,
360 | 'SolarizeAdd': .005,
361 | 'Contrast': .005,
362 | 'Brightness': .005,
363 | 'Equalize': .005,
364 | 'PosterizeTpu': 0,
365 | 'Invert': 0,
366 | }
367 |
368 |
369 | def _select_rand_weights(weight_idx=0, transforms=None):
370 | transforms = transforms or _RAND_TRANSFORMS
371 | assert weight_idx == 0 # only one set of weights currently
372 | rand_weights = _RAND_CHOICE_WEIGHTS_0
373 | probs = [rand_weights[k] for k in transforms]
374 | probs /= np.sum(probs)
375 | return probs
376 |
377 |
378 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
379 | """rand augment ops for RGB images"""
380 | hparams = hparams or _HPARAMS_DEFAULT
381 | transforms = transforms or _RAND_TRANSFORMS
382 | return [AutoAugmentOp(
383 | name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
384 |
385 |
386 | def rand_augment_ops_cmc(magnitude=10, hparams=None, transforms=None):
387 | """rand augment ops for CMC images (removing color ops)"""
388 | hparams = hparams or _HPARAMS_DEFAULT
389 | transforms = transforms or _RAND_TRANSFORMS_CMC
390 | return [AutoAugmentOp(
391 | name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
392 |
393 |
394 | class RandAugment:
395 | def __init__(self, ops, num_layers=2, choice_weights=None):
396 | self.ops = ops
397 | self.num_layers = num_layers
398 | self.choice_weights = choice_weights
399 |
400 | def __call__(self, img):
401 | # no replacement when using weighted choice
402 | ops = np.random.choice(
403 | self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
404 | for op in ops:
405 | img = op(img)
406 | return img
407 |
408 |
409 | def rand_augment_transform(config_str, hparams, use_cmc=False):
410 | """
411 | Create a RandAugment transform
412 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
413 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
414 |     sections, in no particular order, determine
415 | 'm' - integer magnitude of rand augment
416 | 'n' - integer num layers (number of transform ops selected per image)
417 |         'w' - integer probability weight index (index of a set of weights to influence choice of op)
418 | 'mstd' - float std deviation of magnitude noise applied
419 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
420 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
421 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
422 |     :param use_cmc: Flag indicating that color-related augmentation ops should be removed.
423 | :return: A PyTorch compatible Transform
424 | """
425 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
426 | num_layers = 2 # default to 2 ops per image
427 | weight_idx = None # default to no probability weights for op choice
428 | config = config_str.split('-')
429 | assert config[0] == 'rand'
430 | config = config[1:]
431 | for c in config:
432 | cs = re.split(r'(\d.*)', c)
433 | if len(cs) < 2:
434 | continue
435 | key, val = cs[:2]
436 | if key == 'mstd':
437 | # noise param injected via hparams for now
438 | hparams.setdefault('magnitude_std', float(val))
439 | elif key == 'm':
440 | magnitude = int(val)
441 | elif key == 'n':
442 | num_layers = int(val)
443 | elif key == 'w':
444 | weight_idx = int(val)
445 | else:
446 | assert False, 'Unknown RandAugment config section'
447 | if use_cmc:
448 | ra_ops = rand_augment_ops_cmc(magnitude=magnitude, hparams=hparams)
449 | else:
450 | ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
451 | choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
452 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
453 |
454 |
455 | class GaussianBlur(object):
456 | """blur a single image on CPU"""
457 | def __init__(self, kernel_size):
458 |         radius = kernel_size // 2
459 |         kernel_size = radius * 2 + 1
460 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
461 | stride=1, padding=0, bias=False, groups=3)
462 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
463 | stride=1, padding=0, bias=False, groups=3)
464 | self.k = kernel_size
465 |         self.r = radius
466 |
467 | self.blur = nn.Sequential(
468 |             nn.ReflectionPad2d(radius),
469 | self.blur_h,
470 | self.blur_v
471 | )
472 |
473 | self.pil_to_tensor = transforms.ToTensor()
474 | self.tensor_to_pil = transforms.ToPILImage()
475 |
476 | def __call__(self, img):
477 | img = self.pil_to_tensor(img).unsqueeze(0)
478 |
479 | sigma = np.random.uniform(0.1, 2.0)
480 | x = np.arange(-self.r, self.r + 1)
481 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
482 | x = x / x.sum()
483 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
484 |
485 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
486 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
487 |
488 | with torch.no_grad():
489 | img = self.blur(img)
490 | img = img.squeeze()
491 |
492 | img = self.tensor_to_pil(img)
493 |
494 | return img
495 |
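496 |
497 | # --- Illustrative usage sketch ---
498 | # A minimal example of building a RandAugment transform from a config string,
499 | # following the format documented in rand_augment_transform; the magnitude,
500 | # layer count, and hparams below are example values, not this repo's settings.
501 | if __name__ == "__main__":
502 |     hparams = dict(translate_const=100, img_mean=_FILL)
503 |     ra = rand_augment_transform("rand-m9-n2-mstd0.5", hparams)
504 |     out = ra(Image.new("RGB", (224, 224)))  # applies 2 randomly chosen ops at magnitude ~9
505 |     print(type(ra).__name__, len(ra.ops), out.size)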
--------------------------------------------------------------------------------
/downstream_tasks/imb_utils/util.py:
--------------------------------------------------------------------------------
1 | # original code: https://github.com/kaidic/LDAM-DRW/blob/master/utils.py
2 | import torch
3 | import torch.distributed as dist
4 | import shutil
5 | import os
6 | import numpy as np
7 | import matplotlib
8 |
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 | from sklearn.metrics import confusion_matrix
12 | from sklearn.utils.multiclass import unique_labels
13 |
14 |
15 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
16 |
17 | def __init__(self, dataset, indices=None, num_samples=None):
18 | # if indices is not provided,
19 | # all elements in the dataset will be considered
20 | self.indices = list(range(len(dataset))) \
21 | if indices is None else indices
22 |
23 | # if num_samples is not provided,
24 | # draw `len(indices)` samples in each iteration
25 | self.num_samples = len(self.indices) \
26 | if num_samples is None else num_samples
27 |
28 | # distribution of classes in the dataset
29 | label_to_count = [0] * len(np.unique(dataset.targets))
30 | for idx in self.indices:
31 | label = self._get_label(dataset, idx)
32 | label_to_count[label] += 1
33 |
34 | beta = 0.9999
35 | effective_num = 1.0 - np.power(beta, label_to_count)
36 | per_cls_weights = (1.0 - beta) / np.array(effective_num)
37 |
38 | # weight for each sample
39 | weights = [per_cls_weights[self._get_label(dataset, idx)]
40 | for idx in self.indices]
41 | self.weights = torch.DoubleTensor(weights)
42 |
43 | def _get_label(self, dataset, idx):
44 | return dataset.targets[idx]
45 |
46 | def __iter__(self):
47 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())
48 |
49 | def __len__(self):
50 | return self.num_samples
51 |
52 |
53 | def calc_confusion_mat(val_loader, model, args):
54 | model.eval()
55 | all_preds = []
56 | all_targets = []
57 | with torch.no_grad():
58 | for i, (input, target) in enumerate(val_loader):
59 | if args.gpu is not None:
60 | input = input.cuda(args.gpu, non_blocking=True)
61 | target = target.cuda(args.gpu, non_blocking=True)
62 |
63 | # compute output
64 | output = model(input)
65 | _, pred = torch.max(output, 1)
66 | all_preds.extend(pred.cpu().numpy())
67 | all_targets.extend(target.cpu().numpy())
68 | cf = confusion_matrix(all_targets, all_preds).astype(float)
69 |
70 | cls_cnt = cf.sum(axis=1)
71 | cls_hit = np.diag(cf)
72 |
73 | cls_acc = cls_hit / cls_cnt
74 |
75 | print('Class Accuracy : ')
76 | print(cls_acc)
77 | classes = [str(x) for x in args.cls_num_list]
78 | plot_confusion_matrix(all_targets, all_preds, classes)
79 | plt.savefig(os.path.join(args.root_log, args.store_name, 'confusion_matrix.png'))
80 |
81 |
82 | def plot_confusion_matrix(y_true, y_pred, classes,
83 | normalize=False,
84 | title=None,
85 | cmap=plt.cm.Blues):
86 | if not title:
87 | if normalize:
88 | title = 'Normalized confusion matrix'
89 | else:
90 | title = 'Confusion matrix, without normalization'
91 |
92 | # Compute confusion matrix
93 | cm = confusion_matrix(y_true, y_pred)
94 |
95 | fig, ax = plt.subplots()
96 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
97 | ax.figure.colorbar(im, ax=ax)
98 | # We want to show all ticks...
99 | ax.set(xticks=np.arange(cm.shape[1]),
100 | yticks=np.arange(cm.shape[0]),
101 | # ... and label them with the respective list entries
102 | xticklabels=classes, yticklabels=classes,
103 | title=title,
104 | ylabel='True label',
105 | xlabel='Predicted label')
106 |
107 | # Rotate the tick labels and set their alignment.
108 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
109 | rotation_mode="anchor")
110 |
111 | # Loop over data dimensions and create text annotations.
112 | fmt = '.2f' if normalize else 'd'
113 | thresh = cm.max() / 2.
114 | for i in range(cm.shape[0]):
115 | for j in range(cm.shape[1]):
116 | ax.text(j, i, format(cm[i, j], fmt),
117 | ha="center", va="center",
118 | color="white" if cm[i, j] > thresh else "black")
119 | fig.tight_layout()
120 | return ax
121 |
122 |
123 | def prepare_folders(args):
124 | folders_util = [args.root_log, args.root_model,
125 | os.path.join(args.root_log, args.store_name),
126 | os.path.join(args.root_model, args.store_name)]
127 | for folder in folders_util:
128 | if not os.path.exists(folder):
129 | print('creating folder ' + folder)
130 | os.mkdir(folder)
131 |
132 |
133 | def save_checkpoint(args, state, is_best, epoch):
134 | filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name)
135 | torch.save(state, filename)
136 | if is_best:
137 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))
138 | if epoch % 20 == 0:
139 | filename = '%s/%s/%s_ckpt.pth.tar' % (args.root_model, args.store_name, str(epoch))
140 | torch.save(state, filename)
141 |
142 |
143 | class AverageMeter(object):
144 |
145 | def __init__(self, name, fmt=':f'):
146 | self.name = name
147 | self.fmt = fmt
148 | self.reset()
149 |
150 | def reset(self):
151 | self.val = 0
152 | self.avg = 0
153 | self.sum = 0
154 | self.count = 0
155 |
156 | def update(self, val, n=1):
157 | self.val = val
158 | self.sum += val * n
159 | self.count += n
160 | self.avg = self.sum / self.count
161 |
162 | def __str__(self):
163 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
164 | return fmtstr.format(**self.__dict__)
165 |
166 |
167 | def accuracy(output, target, topk=(1,)):
168 | # print("output", output.shape)
169 | # print("target", target.shape)
170 | with torch.no_grad():
171 | maxk = max(topk)
172 | batch_size = target.size(0)
173 |
174 | _, pred = output.topk(maxk, 1, True, True)
175 | pred = pred.t()
176 | if target.dim() == 1:
177 | target = target.reshape(1, -1)
178 | elif target.dim() == 2:
179 | target = target.argmax(dim=1).reshape(1, -1)
180 | correct = pred.eq(target.expand_as(pred))
181 |
182 | res = []
183 | for k in topk:
184 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
185 | res.append(correct_k.mul_(100.0 / batch_size))
186 | return res
187 |
188 |
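189 | # --- Illustrative usage sketch ---
190 | # A minimal example of plugging ImbalancedDatasetSampler into a DataLoader.
191 | # The sampler weights every example by the inverse "effective number" of its
192 | # class (beta = 0.9999), so minority classes are drawn more often; the toy
193 | # dataset below is hypothetical and exists only for illustration.
194 | if __name__ == "__main__":
195 |     class _ToyDataset(torch.utils.data.Dataset):
196 |         def __init__(self):
197 |             self.targets = [0] * 90 + [1] * 10  # 9:1 class imbalance
198 |
199 |         def __len__(self):
200 |             return len(self.targets)
201 |
202 |         def __getitem__(self, idx):
203 |             return torch.zeros(3, 8, 8), self.targets[idx]
204 |
205 |     ds = _ToyDataset()
206 |     loader = torch.utils.data.DataLoader(
207 |         ds, batch_size=16, sampler=ImbalancedDatasetSampler(ds)
208 |     )
209 |     _, labels = next(iter(loader))
210 |     print(labels.bincount())  # draws are roughly class-balanced despite the 9:1 ratio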
--------------------------------------------------------------------------------
/downstream_tasks/losses.py:
--------------------------------------------------------------------------------
1 | # reference code: https://github.com/kaidic/LDAM-DRW/blob/master/cifar_train.py
2 |
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 |
11 | def focal_loss(input_values, gamma):
12 | """Computes the focal loss"""
13 | p = torch.exp(-input_values)
14 | # loss = (1 - p) ** gamma * input_values
15 | loss = (1 - p) ** gamma * input_values * 10
16 | return loss.mean()
17 |
18 |
19 | class FocalLoss(nn.Module):
20 | def __init__(self, weight=None, gamma=0.0):
21 | super(FocalLoss, self).__init__()
22 | assert gamma >= 0
23 | self.gamma = gamma
24 | self.weight = weight
25 |
26 | def forward(self, input, target):
27 | return focal_loss(
28 | F.cross_entropy(input, target, reduction="none", weight=self.weight),
29 | self.gamma,
30 | )
31 |
32 |
33 | class LDAMLoss(nn.Module):
34 | def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
35 | super(LDAMLoss, self).__init__()
36 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
37 | m_list = m_list * (max_m / np.max(m_list))
38 | m_list = torch.cuda.FloatTensor(m_list)
39 | self.m_list = m_list
40 | assert s > 0
41 | self.s = s
42 | self.weight = weight
43 |
44 | def forward(self, x, target):
45 |         index = torch.zeros_like(x, dtype=torch.bool)  # boolean mask (uint8 masks are deprecated in torch.where)
46 | index.scatter_(1, target.data.view(-1, 1), 1)
47 |
48 | index_float = index.type(torch.cuda.FloatTensor)
49 | batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
50 | batch_m = batch_m.view((-1, 1))
51 | x_m = x - batch_m
52 |
53 | output = torch.where(index, x_m, x)
54 | return F.cross_entropy(self.s * output, target, weight=self.weight)
55 |
56 |
57 | class BalancedSoftmaxLoss(nn.Module):
58 | def __init__(self, cls_num_list):
59 | super().__init__()
60 |         cls_prior = torch.tensor(cls_num_list, dtype=torch.float) / sum(cls_num_list)
61 | self.log_prior = torch.log(cls_prior).unsqueeze(0)
62 | # self.min_prob = 1e-9
63 | # print(f'Use BalancedSoftmaxLoss, class_prior: {cls_prior}')
64 |
65 | def forward(self, logits, labels):
66 |         adjusted_logits = logits + self.log_prior.to(logits.device)
67 | label_loss = F.cross_entropy(adjusted_logits, labels)
68 |
69 | return label_loss
70 |
71 |
72 | class LabelSmoothing(nn.Module):
73 | # "Implement label smoothing."
74 |
75 | def __init__(self, size, smoothing=0.0):
76 | super(LabelSmoothing, self).__init__()
77 |         self.criterion = nn.KLDivLoss(reduction="sum")  # size_average=False is deprecated
78 | # self.padding_idx = padding_idx
79 | self.confidence = 1.0 - smoothing
80 | self.smoothing = smoothing
81 | self.size = size
82 | self.true_dist = None
83 |
84 | def forward(self, x, target):
85 | """
86 |         x: input of shape (N, C), the per-class probabilities for N samples over C = self.size classes
87 |         target: ground-truth labels of shape (N,)
88 |         """
89 |         assert x.size(1) == self.size
90 |         x = x.log()
91 |         true_dist = x.data.clone()  # deep copy of the predictions, reused only as a template
92 |         # print true_dist
93 |         true_dist.fill_(self.smoothing / (self.size - 1))  # the "otherwise" term of the smoothing formula
94 |         # print true_dist
95 |         # build the smoothed one-hot targets: scatter along dim=1 (the class axis),
96 |         # using target.data.unsqueeze(1) as indices and confidence as the fill value
97 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
98 |
99 | self.true_dist = true_dist
100 | print(x.shape, true_dist.shape)
101 |
102 | return self.criterion(x, Variable(true_dist, requires_grad=False))
103 |
104 |
105 | class LabelSmoothingLoss(nn.Module):
106 | def __init__(self, classes, smoothing=0.0, dim=-1):
107 | super(LabelSmoothingLoss, self).__init__()
108 | self.confidence = 1.0 - smoothing
109 | self.smoothing = smoothing
110 | self.cls = classes
111 | self.dim = dim
112 |
113 | def forward(self, pred, target):
114 | pred = pred.log_softmax(dim=self.dim)
115 | with torch.no_grad():
116 | # true_dist = pred.data.clone()
117 | true_dist = torch.zeros_like(pred)
118 | true_dist.fill_(self.smoothing / (self.cls - 1))
119 | if target.dim() == 1:
120 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
121 | else:
122 | true_dist = true_dist + target * self.confidence
123 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
124 |
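125 |
126 | # --- Illustrative usage sketch ---
127 | # A minimal example of constructing these long-tailed losses from a per-class
128 | # sample-count list; the counts, logits, and labels below are made up for
129 | # illustration. LDAMLoss builds its margins with torch.cuda.FloatTensor, so it
130 | # is only exercised when a GPU is available.
131 | if __name__ == "__main__":
132 |     cls_num_list = [100, 50, 10]  # hypothetical head/medium/tail class counts
133 |     logits = torch.randn(4, 3)
134 |     labels = torch.tensor([0, 1, 2, 0])
135 |
136 |     print(FocalLoss(gamma=1.0)(logits, labels))
137 |     print(BalancedSoftmaxLoss(cls_num_list)(logits, labels))
138 |     print(LabelSmoothingLoss(classes=3, smoothing=0.1)(logits, labels))
139 |     if torch.cuda.is_available():
140 |         print(LDAMLoss(cls_num_list)(logits.cuda(), labels.cuda()))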
--------------------------------------------------------------------------------
/downstream_tasks/mixup.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from cutmix.cutmix import CutMix
6 | from cutmix.utils import onehot, rand_bbox
7 | from torch.utils.data.dataset import Dataset
8 |
9 |
10 | def calculate_confusion_matrix(pred, target):
11 | """Calculate confusion matrix according to the prediction and target.
12 |
13 | Args:
14 | pred (torch.Tensor | np.array): The model prediction with shape (N, C).
15 | target (torch.Tensor | np.array): The target of each prediction with
16 | shape (N, 1) or (N,).
17 |
18 | Returns:
19 | torch.Tensor: Confusion matrix
20 | The shape is (C, C), where C is the number of classes.
21 | """
22 |
23 | if isinstance(pred, np.ndarray):
24 | pred = torch.from_numpy(pred)
25 | if isinstance(target, np.ndarray):
26 | target = torch.from_numpy(target)
27 | assert isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor), (
28 | f"pred and target should be torch.Tensor or np.ndarray, "
29 | f"but got {type(pred)} and {type(target)}."
30 | )
31 |
32 | # Modified from PyTorch-Ignite
33 | num_classes = pred.size(1)
34 | pred_label = torch.argmax(pred, dim=1).flatten()
35 | target_label = target.flatten()
36 | assert len(pred_label) == len(target_label)
37 |
38 | with torch.no_grad():
39 | indices = num_classes * target_label + pred_label
40 | matrix = torch.bincount(indices, minlength=num_classes**2)
41 | matrix = matrix.reshape(num_classes, num_classes)
42 | return matrix.detach().cpu().numpy()
43 |
44 |
45 | def calculate_accuracy(pred, target):
46 | _, predicted_labels = torch.max(pred, dim=1)
47 | correct_predictions = torch.sum(predicted_labels == target)
48 | total_samples = target.size(0)
49 |
50 | accuracy = correct_predictions.item() / total_samples
51 | return accuracy
52 |
53 |
54 | def is_vector_label(x):
55 | if isinstance(x, np.ndarray):
56 | return x.size > 1
57 | elif isinstance(x, torch.Tensor):
58 | return x.size().numel() > 1
59 | elif isinstance(x, int):
60 | return False
61 | else:
62 | raise TypeError(f"Unknown type {type(x)}")
63 |
64 |
65 | class CutMix(Dataset):
66 | def __init__(self, dataset, num_class, num_mix=1, beta=1.0, prob=1.0):
67 | self.dataset = dataset
68 | self.num_class = num_class
69 | self.num_mix = num_mix
70 | self.beta = beta
71 | self.prob = prob
72 |
73 | def __getitem__(self, index):
74 | example = self.dataset[index]
75 | img, lb = example["pixel_values"], example["label"]
76 | lb_onehot = lb if is_vector_label(lb) else onehot(self.num_class, lb)
77 |
78 | for _ in range(self.num_mix):
79 | r = np.random.rand(1)
80 | if self.beta <= 0 or r > self.prob:
81 | continue
82 |
83 | # generate mixed sample
84 | lam = np.random.beta(self.beta, self.beta)
85 | rand_index = random.choice(range(len(self)))
86 |
87 | rand_example = self.dataset[rand_index]
88 | img2, lb2 = rand_example["pixel_values"], rand_example["label"]
89 | lb2_onehot = lb2 if is_vector_label(lb2) else onehot(self.num_class, lb2)
90 | bbx1, bby1, bbx2, bby2 = rand_bbox(img.size(), lam)
91 | img[:, bbx1:bbx2, bby1:bby2] = img2[:, bbx1:bbx2, bby1:bby2]
92 | lam = 1 - (
93 | (bbx2 - bbx1) * (bby2 - bby1) / (img.size()[-1] * img.size()[-2])
94 | )
95 | lb_onehot = lb_onehot * lam + lb2_onehot * (1.0 - lam)
96 |
97 | return {"pixel_values": img, "label": lb_onehot}
98 |
99 | def __len__(self):
100 | return len(self.dataset)
101 |
102 |
103 | def mixup_data(x, y, alpha=1, num_classes=200):
104 | """Compute the mixup data. Return mixed inputs, pairs of targets, and lambda"""
105 | if alpha > 0.0:
106 | lam = np.random.beta(alpha, alpha)
107 | else:
108 | lam = 1.0
109 | batch_size = x.size()[0]
110 | index = torch.randperm(batch_size).to(x.device)
111 |
112 | mixed_x = lam * x + (1 - lam) * x[index, :]
113 | if is_vector_label(y):
114 | mixed_y = lam * y + (1 - lam) * y[index]
115 | else:
116 |         mixed_y = onehot(num_classes, y) * lam + onehot(num_classes, y[index]) * (
117 | 1 - lam
118 | )
119 | return mixed_x, mixed_y
120 |
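121 |
122 | # --- Illustrative usage sketch ---
123 | # A minimal example of mixup_data on a toy batch; the batch shape and class
124 | # count are made up for illustration. One-hot label vectors are passed in so
125 | # that is_vector_label treats them as soft labels and mixes them directly.
126 | if __name__ == "__main__":
127 |     x = torch.randn(8, 3, 224, 224)
128 |     y = torch.eye(200)[torch.randint(0, 200, (8,))]  # one-hot labels, shape (8, 200)
129 |     mixed_x, mixed_y = mixup_data(x, y, alpha=1, num_classes=200)
130 |     print(mixed_x.shape, mixed_y.shape)  # (8, 3, 224, 224) inputs, (8, 200) soft labels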
--------------------------------------------------------------------------------
/outputs:
--------------------------------------------------------------------------------
1 | ../../outputs/da-fusion/outputs/
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhicaiwww/Diff-Mix/a81337b1492bcc7f7a8a61921836e94191a3a0ef/requirement.txt
--------------------------------------------------------------------------------
/scripts/classification.sh:
--------------------------------------------------------------------------------
1 | #dataset (dog cub car aircraft flower pet food chest caltech pascal )
2 | #lr(resnet) (0.001 0.05 0.10 0.10 0.05 0.01 0.01 0.01 0.01 0.01 )
3 | #lr(vit) (0.00001 0.001 0.001 0.001 0.0005 0.0001 0.00005 0.00005 0.00005 0.00005 )
4 |
5 | GPU=1
6 | DATASET="cub"
7 | SHOT=-1
8 | # "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
9 | SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot
10 | SYNDATA_P=0.1
11 | GAMMA=0.8
12 |
13 | python downstream_tasks/train_hub.py \
14 | --dataset $DATASET \
15 | --syndata_dir $SYNDATA_DIR \
16 | --syndata_p $SYNDATA_P \
17 | --model "resnet50" \
18 | --gamma $GAMMA \
19 | --examples_per_class $SHOT \
20 | --gpu $GPU \
21 | --amp 2 \
22 | --note $(date +%m%d%H%M) \
23 | --group_note "fullshot" \
24 | --nepoch 120 \
25 | --res_mode 224 \
26 | --lr 0.05 \
27 | --seed 0 \
28 | --weight_decay 0.0005
--------------------------------------------------------------------------------
/scripts/classification_imb.sh:
--------------------------------------------------------------------------------
1 | # CMO
2 | gpu=1
3 | dataset='cub'
4 | imb_factor=0.01
5 | GAMMA=0.8
6 | # "imb{args.imbalance_factor}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
7 | SYNDATA_DIR="aug_samples/cub/imb${imb_factor}_diff-mix_fixed_0.7"
8 | SYNDATA_P=0.1
9 |
10 | python downstream_tasks/train_hub_imb.py \
11 |     --dataset $dataset \
12 | --loss_type CE \
13 | --lr 0.005 \
14 | --epochs 200 \
15 | --imb_factor $imb_factor \
16 | -b 128 \
17 | --gpu $gpu \
18 | --root_log outputs/results_cmo \
19 | --data_aug CMO
20 |
21 | # DRW
22 | python downstream_tasks/train_hub_imb.py \
23 |     --dataset $dataset \
24 | --loss_type CE \
25 | --lr 0.005 \
26 | --epochs 200 \
27 | --imb_factor $imb_factor \
28 | -b 128 \
29 | --gpu $gpu \
30 | --data_aug vanilla \
31 | --root_log outputs/results_cmo \
32 | --train_rule DRW
33 |
34 | # baseline
35 | python downstream_tasks/train_hub_imb.py \
36 |     --dataset $dataset \
37 | --loss_type CE \
38 | --lr 0.005 \
39 | --epochs 200 \
40 | --imb_factor $imb_factor \
41 | -b 128 \
42 | --gpu $gpu \
43 | --data_aug vanilla \
44 | --root_log outputs/results_cmo
45 |
46 | # weightedSyn
47 | python downstream_tasks/train_hub_imb.py \
48 |     --dataset $dataset \
49 | --loss_type CE \
50 | --lr 0.005 \
51 | --epochs 200 \
52 | --imb_factor $imb_factor \
53 | -b 128 \
54 | --gpu $gpu \
55 | --data_aug vanilla \
56 | --root_log outputs/results_cmo \
57 | --syndata_dir $SYNDATA_DIR \
58 | --syndata_p $SYNDATA_P \
59 | --gamma $GAMMA \
60 | --use_weighted_syn
61 |
62 |
--------------------------------------------------------------------------------
/scripts/classification_waterbird.sh:
--------------------------------------------------------------------------------
1 | GPU=1
2 | DATASET="cub"
3 | SHOT=-1
4 | SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot
5 | SYNDATA_P=0.1
6 | GAMMA=0.8
7 |
8 | python downstream_tasks/train_hub_waterbird.py \
9 | --dataset $DATASET \
10 | --syndata_dir $SYNDATA_DIR \
11 | --syndata_p $SYNDATA_P \
12 | --model "resnet50" \
13 | --gamma $GAMMA \
14 | --examples_per_class $SHOT \
15 | --gpu $GPU \
16 | --amp 2 \
17 | --note $(date +%m%d%H%M) \
18 | --group_note "robustness" \
19 | --nepoch 120 \
20 | --res_mode 224 \
21 | --lr 0.05 \
22 | --seed 0 \
23 | --weight_decay 0.0005
--------------------------------------------------------------------------------
/scripts/compose_syn_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import re
4 | import shutil
5 | from collections import defaultdict
6 |
7 | import pandas as pd
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 | from utils.misc import check_synthetic_dir_valid
12 |
13 |
14 | def generate_meta_csv(output_path):
15 | rootdir = os.path.join(output_path, "data")
16 | if os.path.exists(os.path.join(output_path, "meta.csv")):
17 | return
18 | pattern_level_1 = r"(.+)"
19 | pattern_level_2 = r"(.+)-(\d+)-(.+).png"
20 |
21 | # Generate meta.csv for indexing images
22 | data_dict = defaultdict(list)
23 | for dir in os.listdir(rootdir):
24 | if not os.path.isdir(os.path.join(rootdir, dir)):
25 | continue
26 | match_1 = re.match(pattern_level_1, dir)
27 | first_dir = match_1.group(1).replace("_", " ")
28 | for file in os.listdir(os.path.join(rootdir, dir)):
29 | match_2 = re.match(pattern_level_2, file)
30 | second_dir = match_2.group(1).replace("_", " ")
31 | num = int(match_2.group(2))
32 | floating_num = float(match_2.group(3))
33 | data_dict["First Directory"].append(first_dir)
34 | data_dict["Second Directory"].append(second_dir)
35 | data_dict["Number"].append(num)
36 | data_dict["Strength"].append(floating_num)
37 | data_dict["Path"].append(os.path.join(dir, file))
38 |
39 | df = pd.DataFrame(data_dict)
40 |
41 | # Validate generated images
42 | valid_rows = []
43 | for index, row in tqdm(df.iterrows(), total=len(df)):
44 | image_path = os.path.join(output_path, "data", row["Path"])
45 | try:
46 | img = Image.open(image_path)
47 | img.close()
48 | valid_rows.append(row)
49 | except Exception as e:
50 | os.remove(image_path)
51 | print(f"Deleted {image_path} due to error: {str(e)}")
52 |
53 | valid_df = pd.DataFrame(valid_rows)
54 | csv_path = os.path.join(output_path, "meta.csv")
55 | valid_df.to_csv(csv_path, index=False)
56 |
57 | print("DataFrame:")
58 | print(df)
59 |
60 |
61 | def main(source_directory_list, target_directory, num_samples):
62 |
63 | target_directory = os.path.join(target_directory, "data")
64 |
65 | os.makedirs(target_directory, exist_ok=True)
66 |
67 | image_files = []
68 | image_class_names = []
69 | for source_directory in source_directory_list:
70 | source_directory = os.path.join(source_directory, "data")
71 | for class_name in os.listdir(source_directory):
72 | class_directory = os.path.join(source_directory, class_name)
73 | target_class_directory = os.path.join(target_directory, class_name)
74 | os.makedirs(target_class_directory, exist_ok=True)
75 | for filename in os.listdir(class_directory):
76 | if filename.endswith(".png"):
77 | image_class_names.append(class_name)
78 | image_files.append(os.path.join(class_directory, filename))
79 | # random sample idx
80 | random_indices = random.sample(
81 | range(len(image_files)), min(int(num_samples), len(image_files))
82 | )
83 |
84 | selected_image_files = [image_files[i] for i in random_indices]
85 | selected_image_class_names = [image_class_names[i] for i in random_indices]
86 |
87 | for class_name, image_file in tqdm(
88 | zip(selected_image_class_names, selected_image_files),
89 | desc="Copying data",
90 | total=num_samples,
91 | ):
92 | shutil.copy(image_file, os.path.join(target_directory, class_name))
93 |
94 |
95 | # python scripts/compose_syn_data.py --num_samples 1000 --dataset cub --source_aug_dir realgen --target_directory outputs/aug_samples_1shot/cub/real-gen-Multi5
96 | # python scripts/compose_syn_data.py --num_samples 16670 --dataset aircraft --source_aug_dir mixup0.5 mixup0.7 mixup0.9 --target_directory outputs/aug_samples/aircraft/diff-mix-Uniform
97 | if __name__ == "__main__":
98 | import argparse
99 |
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument(
102 | "--num_samples", type=int, default=40000, help="Number of samples"
103 | )
104 | parser.add_argument("--dataset", type=str, default="cub", help="Dataset name")
105 | parser.add_argument("--source_aug_dir", type=str, nargs="+", default="mixup")
106 | parser.add_argument(
107 | "--target_directory",
108 | type=str,
109 | default="aug_samples/cub/diff-mix-Uniform",
110 | help="Target directory",
111 | )
112 | args = parser.parse_args()
113 | source_directory_list = []
114 | for synthetic_dir in args.source_aug_dir:
115 | check_synthetic_dir_valid(synthetic_dir)
116 | generate_meta_csv(synthetic_dir)
117 | source_directory_list.append(synthetic_dir)
118 | target_directory = args.target_directory
119 | main(source_directory_list, target_directory, args.num_samples)
120 | generate_meta_csv(target_directory)
121 |
--------------------------------------------------------------------------------
/scripts/filter_syn_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import tqdm
8 | from transformers import CLIPModel, CLIPProcessor
9 |
10 | from dataset.base import SyntheticDataset
11 |
12 | FILTER_CAPTIONS_MAPPING = {
13 | "cub": ["a photo with a bird on it", "a photo without a bird on it"],
14 | "aircraft": [
15 |         "a photo with an aircraft on it. ",
16 |         "a photo without an aircraft on it. ",
17 | ],
18 | "dog": [
19 | "a photo with a dog on it. ",
20 | "a photo without a dog on it. ",
21 | ],
22 | "flower": [
23 | "a photo with a flower on it. ",
24 | "a photo without a flower on it. ",
25 | ],
26 | "car": [
27 | "a photo with a car on it. ",
28 |         "a photo without a car on it. ",
29 | ],
30 | }
31 |
32 |
33 | def to_tensor(x):
34 | if isinstance(x, int):
35 | return torch.tensor(x)
36 | elif isinstance(x, torch.Tensor):
37 | return x
38 | else:
39 | raise NotImplementedError
40 |
41 |
42 | def syn_collate_fn(examples):
43 | pixel_values = [example["pixel_values"] for example in examples]
44 | src_labels = torch.stack([to_tensor(example["src_label"]) for example in examples])
45 | tar_labels = torch.stack([to_tensor(example["tar_label"]) for example in examples])
46 | dtype = torch.float32 if len(src_labels.size()) == 2 else torch.long
47 |     src_labels = src_labels.to(dtype=dtype)
48 |     tar_labels = tar_labels.to(dtype=dtype)
49 | return {
50 | "pixel_values": pixel_values,
51 | "src_labels": src_labels,
52 | "tar_labels": tar_labels,
53 | }
54 |
55 |
56 | def main(args):
57 | device = f"cuda:{args.gpu}"
58 | bs = args.batch_size
59 | model = CLIPModel.from_pretrained(
60 | "openai/clip-vit-base-patch32", local_files_only=True
61 | ).to(device)
62 | processor = CLIPProcessor.from_pretrained(
63 | "openai/clip-vit-base-patch32", local_files_only=True
64 | )
65 | ds_syn = SyntheticDataset(synthetic_dir=args.syndata_dir)
66 | ds_syn.transform = torch.nn.Identity()
67 | dataloader_syn = torch.utils.data.DataLoader(
68 | ds_syn, batch_size=bs, collate_fn=syn_collate_fn, shuffle=False, num_workers=4
69 | )
70 | positive_confidence = []
71 |
72 | for batch in tqdm.tqdm(dataloader_syn, total=len(dataloader_syn)):
73 |
74 | images = batch["pixel_values"]
75 | inputs = processor(
76 | text=FILTER_CAPTIONS_MAPPING[args.dataset],
77 | images=images,
78 | return_tensors="pt",
79 | padding=True,
80 | ).to(device)
81 |
82 | outputs = model(**inputs)
83 | logits_per_image = (
84 | outputs.logits_per_image
85 | ) # this is the image-text similarity score
86 | probs = logits_per_image.softmax(
87 | dim=1
88 | ) # we can take the softmax to get the label probabilities
89 | probs = probs.cpu().detach().numpy()
90 | positive_confidence = np.concatenate((positive_confidence, probs[:, 0]))
91 | # filter the least 10% confident samples
92 | positive_confidence = np.array(positive_confidence)
93 | bottom_threshold = np.percentile(positive_confidence, 10)
94 | up_threshold = np.percentile(positive_confidence, 90)
95 |     meta_df = pd.read_csv(os.path.join(args.syndata_dir, "meta.csv"))
96 | meta_df1 = meta_df[positive_confidence >= bottom_threshold]
97 | meta_df2 = meta_df[positive_confidence < bottom_threshold]
98 | meta_df3 = meta_df[positive_confidence >= up_threshold]
99 |
100 |     meta_df1.to_csv(os.path.join(args.syndata_dir, "meta_10-100per.csv"), index=False)
101 |     meta_df2.to_csv(os.path.join(args.syndata_dir, "meta_0-10per.csv"), index=False)
102 |     meta_df3.to_csv(os.path.join(args.syndata_dir, "meta_90-100per.csv"), index=False)
103 |
104 |
105 | if __name__ == "__main__":
106 | parsers = argparse.ArgumentParser()
107 | parsers.add_argument("--dataset", type=str, default="cub")
108 | parsers.add_argument("--syndata_dir", type=str, required=True)
109 | parsers.add_argument("-g", "--gpu", type=str, default="0")
110 | parsers.add_argument("-b", "--batch_size", type=int, default=200)
111 | args = parsers.parse_args()
112 | main(args)
113 |
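114 | # --- Illustrative usage sketch ---
115 | # Example invocation; the synthetic-data path below is hypothetical, and any
116 | # directory produced by scripts/sample_mp.py that contains a meta.csv should work:
117 | #   python scripts/filter_syn_data.py --dataset cub \
118 | #       --syndata_dir aug_samples/cub/shot-1_diff-mix_fixed_0.7 -g 0 -b 200
119 | # The script writes meta_10-100per.csv (bottom 10% of CLIP confidences removed),
120 | # meta_0-10per.csv (only that bottom 10%), and meta_90-100per.csv (top 10%)
121 | # alongside the original meta.csv in the same directory.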
--------------------------------------------------------------------------------
/scripts/finetune.sh:
--------------------------------------------------------------------------------
1 | MODEL_NAME="runwayml/stable-diffusion-v1-5"
2 | DATASET='cub'
3 | SHOT=-1 # set -1 for full shot
4 | OUTPUT_DIR="ckpts/${DATASET}/shot${SHOT}_lora_rank10"
5 |
6 | accelerate launch --mixed_precision='fp16' --main_process_port 29507 \
7 | train_lora.py \
8 | --pretrained_model_name_or_path=$MODEL_NAME \
9 | --dataset_name=$DATASET \
10 | --resolution=224 \
11 | --random_flip \
12 | --max_train_steps=35000 \
13 | --num_train_epochs=10 \
14 | --checkpointing_steps=5000 \
15 | --learning_rate=5e-05 \
16 | --lr_scheduler='constant' \
17 | --lr_warmup_steps=0 \
18 | --seed=42 \
19 | --rank=10 \
20 | --local_files_only \
21 | --examples_per_class $SHOT \
22 | --train_batch_size 2 \
23 | --output_dir=$OUTPUT_DIR \
24 |     --report_to='tensorboard'
25 |
--------------------------------------------------------------------------------
/scripts/finetune_imb.sh:
--------------------------------------------------------------------------------
1 | MODEL_NAME="runwayml/stable-diffusion-v1-5"
2 | DATASET='cub'
3 | SHOT=-1 # set -1 for full shot
4 | IMB_FACTOR=0.01
5 | OUTPUT_DIR="ckpts/${DATASET}/imb${IMB_FACTOR}_lora_rank10"
6 |
7 | accelerate launch --mixed_precision='fp16' --main_process_port 29507 \
8 | train_lora.py \
9 | --pretrained_model_name_or_path=$MODEL_NAME \
10 | --dataset_name=$DATASET \
11 | --resolution=224 \
12 | --random_flip \
13 | --max_train_steps=35000 \
14 | --num_train_epochs=10 \
15 | --checkpointing_steps=5000 \
16 | --learning_rate=5e-05 \
17 | --lr_scheduler='constant' \
18 | --lr_warmup_steps=0 \
19 | --seed=42 \
20 | --rank=10 \
21 | --local_files_only \
22 | --examples_per_class $SHOT \
23 | --train_batch_size 2 \
24 | --output_dir=$OUTPUT_DIR \
25 | --report_to='tensorboard' \
26 | --task='imbalanced' \
27 | --imbalance_factor $IMB_FACTOR
28 |
29 |
--------------------------------------------------------------------------------
/scripts/sample.sh:
--------------------------------------------------------------------------------
1 |
2 | DATASET='cub'
3 | # set -1 for full shot
4 | SHOT=-1
5 | FINETUNED_CKPT="ckpts/cub/shot${SHOT}-lora-rank10"
6 | # ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug']
7 | SAMPLE_STRATEGY='diff-mix'
8 | STRENGTH=0.8
9 | # ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9]
10 | STRENGTH_STRATEGY='fixed'
11 | # expand the dataset by 5 times
12 | MULTIPLIER=5
13 | # spawn 4 processes
14 | GPU_IDS=(0 1 2 3)
15 |
16 | python scripts/sample_mp.py \
17 | --model-path='runwayml/stable-diffusion-v1-5' \
18 | --output_root='outputs/aug_samples' \
19 | --dataset=$DATASET \
20 | --finetuned_ckpt=$FINETUNED_CKPT \
21 | --syn_dataset_mulitiplier=$MULTIPLIER \
22 | --strength_strategy=$STRENGTH_STRATEGY \
23 | --sample_strategy=$SAMPLE_STRATEGY \
24 | --examples_per_class=$SHOT \
25 | --resolution=512 \
26 | --batch_size=1 \
27 |     --aug_strength=$STRENGTH \
28 | --gpu-ids=${GPU_IDS[@]}
29 |
30 |
--------------------------------------------------------------------------------
/scripts/sample_imb.sh:
--------------------------------------------------------------------------------
1 |
2 | DATASET='cub'
3 | # set -1 for full shot
4 | SHOT=-1
5 | # ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug']
6 | SAMPLE_STRATEGY='diff-mix'
7 | STRENGTH=0.8
8 | # ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9]
9 | STRENGTH_STRATEGY='fixed'
10 | # expand the dataset by 5 times
11 | MULTIPLIER=5
12 | # spawn 4 processes
13 | IMB_FACTOR=0.01
14 | FINETUNED_CKPT="ckpts/cub/imb${IMB_FACTOR}-lora-rank10"
15 | GPU_IDS=(0 1 2 3)
16 |
17 | python scripts/sample_mp.py \
18 | --model-path='runwayml/stable-diffusion-v1-5' \
19 | --output_root='outputs/aug_samples' \
20 | --dataset=$DATASET \
21 | --finetuned_ckpt=$FINETUNED_CKPT \
22 | --syn_dataset_mulitiplier=$MULTIPLIER \
23 | --strength_strategy=$STRENGTH_STRATEGY \
24 | --sample_strategy=$SAMPLE_STRATEGY \
25 | --examples_per_class=$SHOT \
26 | --resolution=512 \
27 | --batch_size=1 \
28 |     --aug_strength=$STRENGTH \
29 | --gpu-ids=${GPU_IDS[@]} \
30 | --task='imbalanced' \
31 | --imbalance_factor=$IMB_FACTOR
32 |
33 |
--------------------------------------------------------------------------------
/scripts/sample_mp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import re
5 | import sys
6 | import time
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import yaml
12 |
13 | os.environ["CURL_CA_BUNDLE"] = ""
14 |
15 | from collections import defaultdict
16 | from multiprocessing import Process, Queue
17 | from queue import Empty
18 |
19 | from PIL import Image
20 | from tqdm import tqdm
21 |
22 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
23 | from augmentation import AUGMENT_METHODS
24 | from dataset import DATASET_NAME_MAPPING, IMBALANCE_DATASET_NAME_MAPPING
25 | from utils.misc import parse_finetuned_ckpt
26 |
27 |
28 | def check_args_valid(args):
29 |
30 | if args.sample_strategy == "real-gen":
31 | args.lora_path = None
32 | args.embed_path = None
33 | args.aug_strength = 1
34 | elif args.sample_strategy == "diff-gen":
35 | lora_path, embed_path = parse_finetuned_ckpt(args.finetuned_ckpt)
36 | args.lora_path = lora_path
37 | args.embed_path = embed_path
38 | args.aug_strength = 1
39 | elif args.sample_strategy in ["real-aug", "real-mix"]:
40 | args.lora_path = None
41 | args.embed_path = None
42 | elif args.sample_strategy in ["diff-aug", "diff-mix"]:
43 | lora_path, embed_path = parse_finetuned_ckpt(args.finetuned_ckpt)
44 | args.lora_path = lora_path
45 | args.embed_path = embed_path
46 |
47 |
48 | def sample_func(args, in_queue, out_queue, gpu_id, process_id):
49 |
50 | os.environ["CURL_CA_BUNDLE"] = ""
51 |
52 | random.seed(args.seed + process_id)
53 | np.random.seed(args.seed + process_id)
54 | torch.manual_seed(args.seed + process_id)
55 |
56 | if args.task == "imbalanced":
57 | train_dataset = IMBALANCE_DATASET_NAME_MAPPING[args.dataset](
58 | split="train",
59 | seed=args.seed,
60 | resolution=args.resolution,
61 | imbalance_factor=args.imbalance_factor,
62 | )
63 | else:
64 | train_dataset = DATASET_NAME_MAPPING[args.dataset](
65 | split="train",
66 | seed=args.seed, # dataset seed is fixed for all processes
67 | examples_per_class=args.examples_per_class,
68 | resolution=args.resolution,
69 | )
70 |
71 | model = AUGMENT_METHODS[args.sample_strategy](
72 | model_path=args.model_path,
73 | embed_path=args.embed_path,
74 | lora_path=args.lora_path,
75 | prompt=args.prompt,
76 | guidance_scale=args.guidance_scale,
77 | device=f"cuda:{gpu_id}",
78 | )
79 | batch_size = args.batch_size
80 |
81 | while True:
82 | index_list = []
83 | source_label_list = []
84 | target_label_list = []
85 | strength_list = []
86 | for _ in range(batch_size):
87 | try:
88 | index, source_label, target_label, strength = in_queue.get(timeout=1)
89 | index_list.append(index)
90 | source_label_list.append(source_label)
91 | target_label_list.append(target_label)
92 | strength_list.append(strength)
93 | except Empty:
94 | print("queue empty, exit")
95 | break
96 | target_label = target_label_list[0]
97 | target_indice = random.sample(train_dataset.label_to_indices[target_label], 1)[
98 | 0
99 | ]
100 | target_metadata = train_dataset.get_metadata_by_idx(target_indice)
101 | target_name = target_metadata["name"].replace(" ", "_").replace("/", "_")
102 |
103 | source_images = []
104 | save_paths = []
105 | if args.task == "vanilla":
106 | source_indices = [
107 | random.sample(train_dataset.label_to_indices[source_label], 1)[0]
108 | for source_label in source_label_list
109 | ]
110 | elif args.task == "imbalanced":
111 | source_indices = random.sample(range(len(train_dataset)), batch_size)
112 | for index, source_indice in zip(index_list, source_indices):
113 | source_images.append(train_dataset.get_image_by_idx(source_indice))
114 | source_metadata = train_dataset.get_metadata_by_idx(source_indice)
115 | source_name = source_metadata["name"].replace(" ", "_").replace("/", "_")
116 | save_name = os.path.join(
117 | source_name, f"{target_name}-{index:06d}-{strength}.png"
118 | )
119 | save_paths.append(os.path.join(args.output_path, "data", save_name))
120 |
121 | if os.path.exists(save_paths[0]):
122 | print(f"skip {save_paths[0]}")
123 | else:
124 | image, _ = model(
125 | image=source_images,
126 | label=target_label,
127 | strength=strength,
128 | metadata=target_metadata,
129 | resolution=args.resolution,
130 | )
131 |             for img, save_path in zip(image, save_paths):
132 |                 img.save(save_path)
133 |                 print(f"save {save_path}")
134 |
135 |
136 | def main(args):
137 |
138 | torch.multiprocessing.set_start_method("spawn")
139 |
140 | os.makedirs(os.path.join(args.output_root, args.dataset), exist_ok=True)
141 |
142 | check_args_valid(args)
143 | if args.task == "vanilla":
144 | output_name = f"shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
145 | else: # imbalanced
146 | output_name = f"imb{args.imbalance_factor}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
147 | args.output_path = os.path.join(args.output_root, args.dataset, output_name)
148 |
149 | os.makedirs(args.output_path, exist_ok=True)
150 | torch.manual_seed(args.seed)
151 | np.random.seed(args.seed)
152 | random.seed(args.seed)
153 |
154 | gpu_ids = args.gpu_ids
155 | in_queue = Queue()
156 | out_queue = Queue()
157 |
158 | if args.task == "imbalanced":
159 | train_dataset = IMBALANCE_DATASET_NAME_MAPPING[args.dataset](
160 | split="train",
161 | seed=args.seed,
162 | resolution=args.resolution,
163 | imbalance_factor=args.imbalance_factor,
164 | )
165 | else:
166 | train_dataset = DATASET_NAME_MAPPING[args.dataset](
167 | split="train",
168 | seed=args.seed,
169 | examples_per_class=args.examples_per_class,
170 | resolution=args.resolution,
171 | )
172 |
173 | num_classes = len(train_dataset.class_names)
174 |
175 | for name in train_dataset.class_names:
176 | name = name.replace(" ", "_").replace("/", "_")
177 | os.makedirs(os.path.join(args.output_path, "data", name), exist_ok=True)
178 |
179 | num_tasks = args.syn_dataset_mulitiplier * len(train_dataset)
180 |
181 | if args.sample_strategy in [
182 | "real-gen",
183 | "real-aug",
184 | "diff-aug",
185 | "diff-gen",
186 | "ti-aug",
187 | ]:
188 | source_classes = random.choices(
189 | range(len(train_dataset.class_names)), k=num_tasks
190 | )
191 | target_classes = source_classes
192 | elif args.sample_strategy in ["real-mix", "diff-mix", "ti-mix"]:
193 | source_classes = random.choices(
194 | range(len(train_dataset.class_names)), k=num_tasks
195 | )
196 | target_classes = random.choices(
197 | range(len(train_dataset.class_names)), k=num_tasks
198 | )
199 | else:
200 | raise ValueError(f"Augmentation strategy {args.sample_strategy} not supported")
201 |
202 | if args.strength_strategy == "fixed":
203 | strength_list = [args.aug_strength] * num_tasks
204 | elif args.strength_strategy == "uniform":
205 | strength_list = random.choices([0.3, 0.5, 0.7, 0.9], k=num_tasks)
206 |
207 | options = zip(range(num_tasks), source_classes, target_classes, strength_list)
208 |
209 | for option in options:
210 | in_queue.put(option)
211 |
212 | sample_config = vars(args)
213 | sample_config["num_classes"] = num_classes
214 | sample_config["total_tasks"] = num_tasks
215 | sample_config["sample_strategy"] = args.sample_strategy
216 |
217 | with open(
218 | os.path.join(args.output_path, "config.yaml"), "w", encoding="utf-8"
219 | ) as f:
220 | yaml.dump(sample_config, f)
221 |
222 | processes = []
223 | total_tasks = in_queue.qsize()
224 | print("Number of total tasks", total_tasks)
225 |
226 | with tqdm(total=total_tasks, desc="Processing") as pbar:
227 | for process_id, gpu_id in enumerate(gpu_ids):
228 | process = Process(
229 | target=sample_func,
230 | args=(args, in_queue, out_queue, gpu_id, process_id),
231 | )
232 | process.start()
233 | processes.append(process)
234 |
235 | while any(process.is_alive() for process in processes):
236 | current_queue_size = in_queue.qsize()
237 | pbar.n = total_tasks - current_queue_size
238 | pbar.refresh()
239 | time.sleep(1)
240 |
241 | for process in processes:
242 | process.join()
243 |
244 | # Generate meta.csv for indexing images
245 | rootdir = os.path.join(args.output_path, "data")
246 | pattern_level_1 = r"(.+)"
247 | pattern_level_2 = r"(.+)-(\d+)-(.+).png"
248 | data_dict = defaultdict(list)
249 | for dir in os.listdir(rootdir):
250 | if not os.path.isdir(os.path.join(rootdir, dir)):
251 | continue
252 | match_1 = re.match(pattern_level_1, dir)
253 | first_dir = match_1.group(1).replace("_", " ")
254 | for file in os.listdir(os.path.join(rootdir, dir)):
255 | match_2 = re.match(pattern_level_2, file)
256 | second_dir = match_2.group(1).replace("_", " ")
257 | num = int(match_2.group(2))
258 | floating_num = float(match_2.group(3))
259 | data_dict["First Directory"].append(first_dir)
260 | data_dict["Second Directory"].append(second_dir)
261 | data_dict["Number"].append(num)
262 | data_dict["Strength"].append(floating_num)
263 | data_dict["Path"].append(os.path.join(dir, file))
264 |
265 | df = pd.DataFrame(data_dict)
266 |
267 | # Validate generated images
268 | valid_rows = []
269 | for index, row in tqdm(df.iterrows(), total=len(df)):
270 | image_path = os.path.join(args.output_path, "data", row["Path"])
271 | try:
272 | img = Image.open(image_path)
273 | img.close()
274 | valid_rows.append(row)
275 | except Exception as e:
276 | os.remove(image_path)
277 | print(f"Deleted {image_path} due to error: {str(e)}")
278 |
279 | valid_df = pd.DataFrame(valid_rows)
280 | csv_path = os.path.join(args.output_path, "meta.csv")
281 | valid_df.to_csv(csv_path, index=False)
282 |
283 |     print("Valid synthetic samples:")
284 |     print(valid_df)
285 |
286 |
287 | if __name__ == "__main__":
288 |     parser = argparse.ArgumentParser(description="Inference script")
289 | parser.add_argument(
290 | "--finetuned_ckpt",
291 | type=str,
292 | required=True,
293 |         help="directory of the finetuned checkpoint, searched for LoRA weights and learned embeddings",
294 | )
295 | parser.add_argument(
296 | "--output_root",
297 | type=str,
298 | default="outputs/aug_samples",
299 | help="output root directory",
300 | )
301 | parser.add_argument(
302 | "--model_path", type=str, default="CompVis/stable-diffusion-v1-4"
303 | )
304 | parser.add_argument("--dataset", type=str, default="pascal", help="dataset name")
305 | parser.add_argument("--seed", type=int, default=0, help="random seed")
306 | parser.add_argument(
307 | "--examples_per_class",
308 | type=int,
309 | default=-1,
310 |         help="number of real training examples per class for the few-shot setting",
311 | )
312 | parser.add_argument("--resolution", type=int, default=512, help="image resolution")
313 | parser.add_argument("--batch_size", type=int, default=1, help="batch size")
314 | parser.add_argument(
315 | "--prompt", type=str, default="a photo of a {name}", help="prompt for synthesis"
316 | )
317 | parser.add_argument(
318 | "--sample_strategy",
319 | type=str,
320 | default="ti-mix",
321 | choices=[
322 | "real-gen",
323 | "real-aug", # real guidance
324 | "real-mix",
325 | "ti-aug",
326 | "ti-mix",
327 | "diff-aug",
328 | "diff-mix",
329 | "diff-gen",
330 | ],
331 | help="sampling strategy for synthetic data",
332 | )
333 | parser.add_argument(
334 | "--guidance-scale",
335 | type=float,
336 | default=7.5,
337 | help="classifier free guidance scale",
338 | )
339 | parser.add_argument("--gpu_ids", type=int, nargs="+", default=[0], help="gpu ids")
340 | parser.add_argument(
341 | "--task",
342 | type=str,
343 | default="vanilla",
344 | choices=["vanilla", "imbalanced"],
345 | help="task",
346 | )
347 | parser.add_argument(
348 | "--imbalance_factor",
349 | type=float,
350 | default=0.01,
351 | choices=[0.01, 0.02, 0.1],
352 |         help="imbalance factor, only for the imbalanced task",
353 | )
354 | parser.add_argument(
355 | "--syn_dataset_mulitiplier",
356 | type=int,
357 | default=5,
358 | help="multiplier for the number of synthetic images compared to the number of real images",
359 | )
360 | parser.add_argument(
361 | "--strength_strategy",
362 | type=str,
363 | default="fixed",
364 | choices=["fixed", "uniform"],
365 | )
366 | parser.add_argument(
367 | "--aug_strength", type=float, default=0.5, help="augmentation strength"
368 | )
369 | args = parser.parse_args()
370 |
371 | main(args)
372 |
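The synthetic images are written as {source_class}/{target_class}-{index:06d}-{strength}.png, and the meta.csv step recovers the target class, task index, and strength from that name with pattern_level_2. A minimal, self-contained sketch of that round trip (the class name below is illustrative only):

import re

file_name = "Sooty_Albatross-000042-0.7.png"  # hypothetical file inside a source-class folder
match = re.match(r"(.+)-(\d+)-(.+).png", file_name)

target_name = match.group(1).replace("_", " ")  # "Sooty Albatross"
task_index = int(match.group(2))                # 42
strength = float(match.group(3))                # 0.7
print(target_name, task_index, strength)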
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import re
4 | from typing import List, Union
5 |
6 | import pandas as pd
7 | import yaml
8 |
9 |
10 | def count_files_in_directory(directory):
11 | count = 0
12 | for root, dirs, files in os.walk(directory):
13 | count += len(files)
14 | return count
15 |
16 |
17 | def check_synthetic_dir_valid(synthetic_dir):
18 |
19 | if not os.path.exists(synthetic_dir):
20 | raise FileNotFoundError(f"Directory '{synthetic_dir}' does not exist.")
21 |
22 | total_files = count_files_in_directory(synthetic_dir)
23 | if total_files > 100:
24 | print(f"Directory '{synthetic_dir}' is valid with {total_files} files.")
25 | else:
26 | raise ValueError(
27 |             f"Directory '{synthetic_dir}' contains fewer than 100 files, which is insufficient."
28 | )
29 |
30 |
31 | def parse_finetuned_ckpt(finetuned_ckpt):
32 | lora_path = None
33 | embed_path = None
34 | for file in os.listdir(finetuned_ckpt):
35 | if "pytorch_lora_weights" in file:
36 | lora_path = os.path.join(finetuned_ckpt, file)
37 | elif "learned_embeds-steps-last" in file:
38 | embed_path = os.path.join(finetuned_ckpt, file)
39 | return lora_path, embed_path
40 |
41 |
42 | def checked_has_run(exp_dir, args):
43 | parent_dir = os.path.abspath(os.path.join(exp_dir, os.pardir))
44 | current_args = copy.deepcopy(args)
45 | current_args.pop("gpu", None)
46 | current_args.pop("note", None)
47 | current_args.pop("target_class_num", None)
48 |
49 | for dirpath, dirnames, filenames in os.walk(parent_dir):
50 | for dirname in dirnames:
51 | config_file = os.path.join(dirpath, dirname, "config.yaml")
52 | if os.path.exists(config_file):
53 | with open(config_file, "r") as file:
54 | saved_args = yaml.load(file, Loader=yaml.FullLoader)
55 |
56 | if (
57 | current_args["syndata_dir"] is None
58 | or "aug" in current_args["syndata_dir"]
59 | or "gen" in current_args["syndata_dir"]
60 | ):
61 | current_args.pop("gamma", None)
62 | saved_args.pop("gamma", None)
63 | saved_args.pop("gpu", None)
64 | saved_args.pop("note", None)
65 | saved_args.pop("target_class_num", None)
66 | if saved_args == current_args:
67 | print(
68 | f"This program has already been run in directory: {dirpath}/{dirname}"
69 | )
70 | return True
71 | return False
72 |
73 |
74 | def parse_result(target_dir, extra_column=[]):
75 | results = []
76 | for file in os.listdir(target_dir):
77 | config_file = os.path.join(target_dir, file, "config.yaml")
78 | config = yaml.safe_load(open(config_file, "r"))
79 | if isinstance(config["syndata_dir"], list):
80 | syndata_dir = config["syndata_dir"][0]
81 | else:
82 | syndata_dir = config["syndata_dir"]
83 |
84 | if syndata_dir is None:
85 | strategy = "baseline"
86 | strength = 0
87 | else:
88 | match = re.match(r"([a-zA-Z]+)([0-9.]*).*", syndata_dir)
89 | if match:
90 | strategy = match.group(1)
91 | strength = match.group(2)
92 | else:
93 | continue
94 | for basefile in os.listdir(os.path.join(target_dir, file)):
95 | if "acc_eval" in basefile:
96 | acc = float(basefile.split("_")[-1])
97 | results.append(
98 | (
99 | config["dir"],
100 | config["res_mode"],
101 | config["lr"],
102 | strategy,
103 | strength,
104 | config["gamma"],
105 | config["seed"],
106 | *[str(config.pop(key, "False")) for key in extra_column],
107 | acc,
108 | )
109 | )
110 | break
111 |
112 | df = pd.DataFrame(
113 | results,
114 | columns=[
115 | "dataset",
116 | "resolution",
117 | "lr",
118 | "strategy",
119 | "strength",
120 | "soft power",
121 | "seed",
122 | *extra_column,
123 | "acc",
124 | ],
125 | )
126 | df["acc"] = df["acc"].astype(float)
127 | result_seed = (
128 | df.groupby(
129 | [
130 | "dataset",
131 | "resolution",
132 | "lr",
133 | "strength",
134 | "strategy",
135 | "soft power",
136 | *extra_column,
137 | ]
138 | )
139 | .agg({"acc": ["mean", "var"]})
140 | .reset_index()
141 | )
142 | result_sorted = result_seed.sort_values(
143 | by=["dataset", "resolution", "lr", "strategy", "strength", *extra_column]
144 | )
145 | result_seed.columns = ["_".join(col).strip() for col in result_seed.columns.values]
146 |
147 | return result_sorted
148 |
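parse_finetuned_ckpt scans a single checkpoint directory and picks out the LoRA weight file and the last learned-embeddings file by substring match. A minimal usage sketch, assuming a hypothetical checkpoint directory produced by train_lora.py that contains both files (the path and file extensions are assumptions):

from utils.misc import parse_finetuned_ckpt

# Hypothetical layout:
#   ckpts/cub/shot-1/pytorch_lora_weights.safetensors
#   ckpts/cub/shot-1/learned_embeds-steps-last.bin
lora_path, embed_path = parse_finetuned_ckpt("ckpts/cub/shot-1")
print(lora_path)   # .../pytorch_lora_weights.safetensors
print(embed_path)  # .../learned_embeds-steps-last.bin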
--------------------------------------------------------------------------------
/utils/network.py:
--------------------------------------------------------------------------------
1 | def count_parameters(model):
2 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
3 | total_params = sum(p.numel() for p in model.parameters())
4 | return trainable_params, total_params
5 |
6 |
7 | def freeze_model(model, finetune_strategy="linear"):
8 | if finetune_strategy == "linear":
9 | for name, param in model.named_parameters():
10 | if "fc" in name:
11 | param.requires_grad = True
12 | else:
13 | param.requires_grad = False
14 | elif finetune_strategy == "stages4+linear":
15 | for name, param in model.named_parameters():
16 | if any(list(map(lambda x: x in name, ["layer4", "fc"]))):
17 | param.requires_grad = True
18 | else:
19 | param.requires_grad = False
20 | elif finetune_strategy == "stages3-4+linear":
21 | for name, param in model.named_parameters():
22 | if any(list(map(lambda x: x in name, ["layer3", "layer4", "fc"]))):
23 | param.requires_grad = True
24 | else:
25 | param.requires_grad = False
26 | elif finetune_strategy == "stages2-4+linear":
27 | for name, param in model.named_parameters():
28 | if any(
29 | list(map(lambda x: x in name, ["layer2", "layer3", "layer4", "fc"]))
30 | ):
31 | param.requires_grad = True
32 | else:
33 | param.requires_grad = False
34 | elif finetune_strategy == "stages1-4+linear":
35 | for name, param in model.named_parameters():
36 | if any(
37 | list(
38 | map(
39 | lambda x: x in name,
40 | ["layer1", "layer2", "layer3", "layer4", "fc"],
41 | )
42 | )
43 | ):
44 | param.requires_grad = True
45 | else:
46 | param.requires_grad = False
47 | elif finetune_strategy == "all":
48 | for name, param in model.named_parameters():
49 | param.requires_grad = True
50 | else:
51 | raise NotImplementedError(f"{finetune_strategy}")
52 |
53 | trainable_params, total_params = count_parameters(model)
54 | ratio = trainable_params / total_params
55 |
56 | print(f"{finetune_strategy}, Trainable / Total Parameter Ratio: {ratio:.4f}")
57 |
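freeze_model matches on the ResNet parameter names layer1-layer4 and fc, so it can be exercised directly on a torchvision ResNet. A quick sketch, assuming torchvision is installed:

import torchvision

from utils.network import count_parameters, freeze_model

model = torchvision.models.resnet50(weights=None)
freeze_model(model, finetune_strategy="stages3-4+linear")

# Only parameters whose names contain "layer3", "layer4" or "fc" remain trainable.
trainable, total = count_parameters(model)
print(f"trainable: {trainable}, total: {total}")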
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 | from einops import rearrange
8 | from PIL import Image
9 | from torchvision import transforms
10 | from torchvision.utils import make_grid
11 |
12 |
13 | def visualize_images(
14 | images: List[Union[Image.Image, torch.Tensor, np.ndarray]],
15 | nrow: int = 4,
16 | show=False,
17 | save=True,
18 | outpath=None,
19 | ):
20 |
21 | if isinstance(images[0], Image.Image):
22 | transform = transforms.ToTensor()
23 | images_ts = torch.stack([transform(image) for image in images])
24 | elif isinstance(images[0], torch.Tensor):
25 | images_ts = torch.stack(images)
26 | elif isinstance(images[0], np.ndarray):
27 | images_ts = torch.stack([torch.from_numpy(image) for image in images])
28 | # save images to a grid
29 | grid = make_grid(images_ts, nrow=nrow, normalize=True, scale_each=True)
30 | 
31 |     # optionally display the grid; figure size scales with the grid shape
32 | if show:
33 | plt.figure(
34 | figsize=(4 * nrow, 4 * (len(images) // nrow + (len(images) % nrow > 0)))
35 | )
36 | plt.imshow(grid.permute(1, 2, 0))
37 | plt.axis("off")
38 | plt.show()
39 |     # convert the grid tensor to a uint8 PIL image
40 | grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
41 | img = Image.fromarray(grid.astype(np.uint8))
42 | if save:
43 | assert outpath is not None
44 | if os.path.dirname(outpath) and not os.path.exists(os.path.dirname(outpath)):
45 | os.makedirs(os.path.dirname(outpath), exist_ok=True)
46 | img.save(f"{outpath}")
47 | return img
48 |
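visualize_images accepts PIL images, tensors, or numpy arrays, arranges them with make_grid, and optionally saves the grid as a single image. A minimal sketch with dummy PIL images (the output path is illustrative):

from PIL import Image

from utils.visualization import visualize_images

images = [Image.new("RGB", (224, 224), color=(30 * i, 80, 200 - 20 * i)) for i in range(8)]
visualize_images(images, nrow=4, show=False, save=True, outpath="figures/demo_grid.png")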
--------------------------------------------------------------------------------
/visualization/visualize_attn_map.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | from collections import defaultdict
5 |
6 | import matplotlib.pyplot as plt
7 | import torch
8 |
9 | import sys
10 | 
11 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
12 | 
13 | from augmentation import AUGMENT_METHODS
14 | from dataset import DATASET_NAME_MAPPING
15 | from utils.misc import finetuned_ckpt_dir
16 |
17 |
18 | class AttentionVisualizer:
19 | def __init__(self, model, hook_target_name):
20 | self.model = model
21 | self.hook_target_name = hook_target_name
22 | self.activation = defaultdict(list)
23 | self.hooks = []
24 |
25 | def get_attn_softmax(self, name):
26 | def hook(unet, input, kwargs, output):
27 | scale = 1.0
28 | with torch.no_grad():
29 | hidden_states = input[0]
30 | encoder_hidden_states = kwargs["encoder_hidden_states"]
31 | attention_mask = kwargs["attention_mask"]
32 | batch_size, sequence_length, _ = (
33 | hidden_states.shape
34 | if encoder_hidden_states is None
35 | else encoder_hidden_states.shape
36 | )
37 | attention_mask = unet.prepare_attention_mask(
38 | attention_mask, sequence_length, batch_size
39 | )
40 |                 if hasattr(unet.processor, "to_q_lora"):  # LoRA attention processor
41 | query = unet.to_q(hidden_states) + scale * unet.processor.to_q_lora(
42 | hidden_states
43 | )
44 | query = unet.head_to_batch_dim(query)
45 |
46 | if encoder_hidden_states is None:
47 | encoder_hidden_states = hidden_states
48 | elif unet.norm_cross:
49 | encoder_hidden_states = unet.norm_encoder_hidden_states(
50 | encoder_hidden_states
51 | )
52 |
53 | key = unet.to_k(
54 | encoder_hidden_states
55 | ) + scale * unet.processor.to_k_lora(encoder_hidden_states)
56 | value = unet.to_v(
57 | encoder_hidden_states
58 | ) + scale * unet.processor.to_v_lora(encoder_hidden_states)
59 | else:
60 | query = unet.to_q(hidden_states)
61 | query = unet.head_to_batch_dim(query)
62 |
63 | if encoder_hidden_states is None:
64 | encoder_hidden_states = hidden_states
65 | elif unet.norm_cross:
66 | encoder_hidden_states = unet.norm_encoder_hidden_states(
67 | encoder_hidden_states
68 | )
69 |
70 | key = unet.to_k(encoder_hidden_states)
71 | value = unet.to_v(encoder_hidden_states)
72 |
73 | key = unet.head_to_batch_dim(key)
74 | value = unet.head_to_batch_dim(value)
75 |
76 | attention_probs = unet.get_attention_scores(query, key, attention_mask)
77 |
78 | self.activation[name].append(attention_probs)
79 |
80 | return hook
81 |
82 | def __enter__(self):
83 | unet = self.model.pipe.unet
84 | for name, module in unet.named_modules():
85 | if self.hook_target_name is not None:
86 | if self.hook_target_name == name:
87 | print("Added hook to", name)
88 | hook = module.register_forward_hook(
89 | self.get_attn_softmax(name), with_kwargs=True
90 | )
91 | self.hooks.append(hook)
92 | return self
93 |
94 | def __exit__(self, exc_type, exc_value, traceback):
95 | for hook in self.hooks:
96 | hook.remove()
97 |
98 |
99 | def plot_attn_map(attn_map, path="figures/attn_map/"):
100 | # Ensure the output directory exists
101 | os.makedirs(path, exist_ok=True)
102 |
103 | for i, attention_map in enumerate(attn_map):
104 | # Attention map is of shape [8+8, 4096, 77]
105 | num_heads, hw, num_tokens = attention_map.size()
106 |
107 |         # Reshape to (num_heads, sqrt(hw), sqrt(hw), num_tokens)
108 | H = int(math.sqrt(hw))
109 | W = int(math.sqrt(hw))
110 | vis_map = attention_map.view(num_heads, H, W, -1)
111 |
112 | # Split into unconditional and conditional attention maps
113 | uncond_attn_map, cond_attn_map = torch.chunk(vis_map, 2, dim=0)
114 |
115 |         # Mean over heads -> [h, w, num_tokens]
116 | cond_attn_map = cond_attn_map.mean(0)
117 | uncond_attn_map = uncond_attn_map.mean(0)
118 |
119 | # Plot and save attention maps
120 | fig, ax = plt.subplots(1, 10, figsize=(20, 2))
121 | for j in range(10):
122 | attn_slice = cond_attn_map[:, :, j].unsqueeze(-1).cpu().numpy()
123 | ax[j].imshow(attn_slice)
124 | ax[j].axis("off")
125 |
126 | # Save the plot
127 | save_path = os.path.join(path, f"attn_map_{i:03d}.jpg")
128 | plt.tight_layout()
129 | plt.savefig(save_path)
130 | plt.close(fig)
131 |
132 | print(f"Saved attention map at: {save_path}")
133 |
134 |
135 | def synthesize_images(
136 | model,
137 | strength,
138 | dataset="cub",
139 | finetuned_ckpt="db_ti_latest",
140 | source_label=1,
141 | target_label=2,
142 | source_image=None,
143 | seed=0,
144 | hook_target_name: str = "up_blocks.2.attentions.1.transformer_blocks.0.attn2",
145 | ):
146 |
147 | random.seed(seed)
148 | train_dataset = DATASET_NAME_MAPPING[dataset](split="train")
149 | target_indice = random.sample(train_dataset.label_to_indices[target_label], 1)[0]
150 |
151 | if source_image is None:
152 | # source_indice = random.sample(train_dataset.label_to_indices[source_label], 1)[0]
153 | source_indice = train_dataset.label_to_indices[source_label][0]
154 | source_image = train_dataset.get_image_by_idx(source_indice)
155 | target_metadata = train_dataset.get_metadata_by_idx(target_indice)
156 | with AttentionVisualizer(model, hook_target_name) as visualizer:
157 | image, _ = model(
158 | image=[source_image],
159 | label=target_label,
160 | strength=strength,
161 | metadata=target_metadata,
162 | )
163 | attn_map = visualizer.activation[hook_target_name]
164 | path = os.path.join("figures/attn_map/", dataset, finetuned_ckpt)
165 | plot_attn_map(attn_map, path=path)
166 | return image
167 |
168 |
169 | if __name__ == "__main__":
170 | dataset_list = ["cub"]
171 |     aug = "diff-mix"  # other options: "diff-aug", "real-mix"
172 | finetuned_ckpt = "db_latest"
173 | guidance_scale = 7
174 | prompt = "a photo of a {name}"
175 |
176 | for dataset in dataset_list:
177 | lora_path, embed_path = finetuned_ckpt_dir(
178 | dataset=dataset, finetuned_ckpt=finetuned_ckpt
179 | )
180 | AUGMENT_METHODS[aug].pipe = None
181 | model = AUGMENT_METHODS[aug](
182 | embed_path=embed_path,
183 | lora_path=lora_path,
184 | prompt=prompt,
185 | guidance_scale=guidance_scale,
186 | mask=False,
187 | inverted=False,
188 |         device="cuda:1",
189 | )
190 | image = synthesize_images(
191 | model,
192 | 0.5,
193 | dataset,
194 | finetuned_ckpt=finetuned_ckpt,
195 | source_label=13,
196 | target_label=2,
197 | source_image=None,
198 | seed=0,
199 | )
200 |
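AttentionVisualizer captures cross-attention probabilities through PyTorch forward hooks registered with with_kwargs=True, which hands the module's keyword arguments (e.g. encoder_hidden_states, attention_mask) to the hook as well. A stripped-down sketch of that hook mechanism on a plain module, independent of diffusers (with_kwargs requires PyTorch >= 2.0):

import torch
import torch.nn as nn

captured = {}


class Toy(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale


def hook(module, args, kwargs, output):
    # with_kwargs=True delivers positional and keyword arguments separately.
    captured["scale"] = kwargs.get("scale")
    captured["output"] = output.detach()


module = Toy()
handle = module.register_forward_hook(hook, with_kwargs=True)
module(torch.ones(2), scale=0.5)
handle.remove()
print(captured["scale"], captured["output"])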
--------------------------------------------------------------------------------
/visualization/visualize_cases.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import sys
4 |
5 | os.environ["DISABLE_TELEMETRY"] = "YES"
6 | sys.path.append("../")
7 |
8 | from augmentation import AUGMENT_METHODS
9 | from dataset import DATASET_NAME_MAPPING
10 | from utils.misc import finetuned_ckpt_dir
11 | from utils.visualization import visualize_images
12 |
13 |
14 | def synthesize_images(
15 | model, strength, train_dataset, source_label=1, target_label=2, source_image=None
16 | ):
17 |     # relies on the module-level `seed` defined in the __main__ block below
18 |     random.seed(seed)
19 | target_indice = random.sample(train_dataset.label_to_indices[target_label], 1)[0]
20 |
21 | if source_image is None:
22 | source_indice = random.sample(train_dataset.label_to_indices[source_label], 1)[
23 | 0
24 | ]
25 | source_image = train_dataset.get_image_by_idx(source_indice)
26 | target_metadata = train_dataset.get_metadata_by_idx(target_indice)
27 |     # edit the source image toward the sampled target class
28 | image, _ = model(
29 | image=[source_image],
30 | label=target_label,
31 | strength=strength,
32 | metadata=target_metadata,
33 | )
34 | return image
35 |
36 |
37 | if __name__ == "__main__":
38 | device = "cuda:1"
39 | dataset = "pascal"
40 |     aug = "diff-mix"  # other options: "diff-aug", "real-mix"
41 | finetuned_ckpt = "db_latest_5shot"
42 | guidance_scale = 7
43 | strength_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
44 |
45 | seed = 0
46 | random.seed(seed)
47 | source_label = 5
48 | for target_label in [4, 7, 6]:
49 | for dataset in ["pascal"]:
50 | for aug in ["diff-mix"]:
51 | train_dataset = DATASET_NAME_MAPPING[dataset](
52 | split="train", examples_per_class=5
53 | )
54 | lora_path, embed_path = finetuned_ckpt_dir(
55 | dataset=dataset, finetuned_ckpt=finetuned_ckpt
56 | )
57 |
58 | AUGMENT_METHODS[aug].pipe = None
59 | model = AUGMENT_METHODS[aug](
60 | embed_path=embed_path,
61 | lora_path=lora_path,
62 | prompt="a photo of a {name}",
63 | guidance_scale=guidance_scale,
64 | mask=False,
65 | inverted=False,
66 | device=device,
67 | )
68 |
69 | image_list = []
70 | for strength in strength_list:
71 | source_image = train_dataset.get_image_by_idx(
72 | train_dataset.label_to_indices[source_label][3]
73 | )
74 | image_list.append(
75 | synthesize_images(
76 | model,
77 | strength,
78 | train_dataset,
79 | source_label=source_label,
80 | target_label=target_label,
81 | source_image=source_image,
82 | )[0]
83 | )
84 |
85 | outpath = (
86 | f"figures/cases/{dataset}/{aug}_{source_label}_{target_label}.png"
87 | )
88 | visualize_images(
89 | image_list, nrow=6, show=False, save=True, outpath=outpath
90 | )
91 |
--------------------------------------------------------------------------------
/visualization/visualize_filtered_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import sys
4 |
5 | os.environ["DISABLE_TELEMETRY"] = "YES"
6 | sys.path.append("../")
7 |
8 | from augmentation import AUGMENT_METHODS, load_embeddings
9 | from dataset import DATASET_NAME_MAPPING, IMBALANCE_DATASET_NAME_MAPPING
10 | from dataset.base import SyntheticDataset
11 | from utils.misc import parse_synthetic_dir
12 | from utils.visualization import visualize_images
13 |
14 | if __name__ == "__main__":
15 | device = "cuda:1"
16 | dataset = "aircraft"
17 |     # the two percentile splits of the filtered synthetic data are iterated over
18 |     # in the loop below (meta_0-10per.csv and meta_90-100per.csv)
19 |     row = 5
20 |     column = 2
21 |     synthetic_type = "mixup_uniform"
22 | synthetic_dir = parse_synthetic_dir(dataset, synthetic_type=synthetic_type)
23 |
24 | for csv_file in ["meta_0-10per.csv", "meta_90-100per.csv"]:
25 | ds = SyntheticDataset(synthetic_dir, csv_file=csv_file)
26 | num = row * column
27 | indices = random.sample(range(len(ds)), num)
28 |
29 | image_list = []
30 | for index in indices:
31 | image_list.append(ds.get_image_by_idx(index).resize((224, 224)))
32 |
33 | outpath = f"figures/cases_filtered/{dataset}/{synthetic_type}_{csv_file}.png"
34 | visualize_images(image_list, nrow=row, show=False, save=True, outpath=outpath)
35 |
--------------------------------------------------------------------------------