├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── figure.png
├── generation
│   ├── DeepAugment
│   │   ├── CAE_Model
│   │   │   ├── __pycache__
│   │   │   │   └── cae_32x32x32_zero_pad_bin.cpython-39.pyc
│   │   │   └── cae_32x32x32_zero_pad_bin.py
│   │   ├── CAE_Weights
│   │   │   └── model_final.state
│   │   ├── CAE_distort_imagenet.py
│   │   ├── EDSR_Model
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   └── edsr.py
│   │   ├── EDSR_Weights
│   │   │   └── edsr_baseline_x4.pt
│   │   ├── EDSR_distort_imagenet.py
│   │   ├── README.md
│   │   ├── deepaugment.png
│   │   ├── train.py
│   │   ├── train.sh
│   │   └── train_noise2net.py
│   ├── __pycache__
│   │   ├── conditional_ldm.cpython-39.pyc
│   │   ├── data.cpython-39.pyc
│   │   ├── generate_images.cpython-39.pyc
│   │   ├── generate_images_captions.cpython-39.pyc
│   │   └── generate_images_i2i.cpython-39.pyc
│   ├── cin256-v2.yaml
│   ├── classes.py
│   ├── conditional_ldm.py
│   ├── data.py
│   ├── folder_to_class.csv
│   ├── generate_images.py
│   ├── generate_images_captions.py
│   ├── generate_images_i2i.py
│   └── ldm
│       ├── __pycache__
│       │   └── util.cpython-39.pyc
│       ├── data
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── imagenet.py
│       │   └── lsun.py
│       ├── lr_scheduler.py
│       ├── models
│       │   ├── __pycache__
│       │   │   └── autoencoder.cpython-39.pyc
│       │   ├── autoencoder.py
│       │   └── diffusion
│       │       ├── __init__.py
│       │       ├── __pycache__
│       │       │   ├── __init__.cpython-39.pyc
│       │       │   ├── ddim.cpython-39.pyc
│       │       │   └── ddpm.cpython-39.pyc
│       │       ├── classifier.py
│       │       ├── ddim.py
│       │       ├── ddpm.py
│       │       └── plms.py
│       ├── modules
│       │   ├── __pycache__
│       │   │   ├── attention.cpython-39.pyc
│       │   │   ├── ema.cpython-39.pyc
│       │   │   └── x_transformer.cpython-39.pyc
│       │   ├── attention.py
│       │   ├── diffusionmodules
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-39.pyc
│       │   │   │   ├── model.cpython-39.pyc
│       │   │   │   ├── openaimodel.cpython-39.pyc
│       │   │   │   └── util.cpython-39.pyc
│       │   │   ├── model.py
│       │   │   ├── openaimodel.py
│       │   │   └── util.py
│       │   ├── distributions
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-39.pyc
│       │   │   │   └── distributions.cpython-39.pyc
│       │   │   └── distributions.py
│       │   ├── ema.py
│       │   ├── encoders
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-39.pyc
│       │   │   │   └── modules.cpython-39.pyc
│       │   │   └── modules.py
│       │   ├── image_degradation
│       │   │   ├── __init__.py
│       │   │   ├── bsrgan.py
│       │   │   ├── bsrgan_light.py
│       │   │   ├── utils
│       │   │   │   └── test.png
│       │   │   └── utils_image.py
│       │   ├── losses
│       │   │   ├── __init__.py
│       │   │   ├── contperceptual.py
│       │   │   └── vqperceptual.py
│       │   └── x_transformer.py
│       └── util.py
├── sd_finetune
│   ├── data.py
│   ├── finetune_sd.py
│   └── parser.py
├── train
│   ├── get_masks.py
│   ├── resnet_configs
│   │   ├── alexnet.yaml
│   │   ├── efficient.yaml
│   │   ├── inception.yaml
│   │   ├── mobilenet.yaml
│   │   ├── resnext101.yaml
│   │   ├── resnext50.yaml
│   │   ├── rn18_88_epochs.yaml
│   │   └── vgg16.yaml
│   ├── train_imagenet.py
│   ├── write_imagenet.py
│   └── write_imagenet.sh
└── utils
    ├── create_imagenet_subset.py
    ├── mappings
    │   ├── classes.py
    │   ├── folder_to_class.csv
    │   ├── folder_to_objectnet_label.json
    │   ├── imagenet100.txt
    │   ├── imagenet100_to_labels.json
    │   ├── imagenet_pytorch_id_to_objectnet_id.json
    │   ├── imagenet_to_label_2012_v2
    │   ├── imagenet_to_labels.json
    │   ├── objectnet_im100_folder.json
    │   ├── objectnet_im1k_folder.json
    │   ├── objectnet_to_im100.json
    │   ├── objectnet_to_imagenet_1k.json
    │   └── pytorch_to_imagenet_2012_id.json
    ├── rename_imagenet_v2.py
    └── subset_objectnet_im.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Build and Release Folders
2 | bin-debug/
3 | bin-release/
4 | [Oo]bj/
5 | [Bb]in/
6 |
7 | # Other files and folders
8 | .settings/
9 |
10 | # Executables
11 | *.swf
12 | *.air
13 | *.ipa
14 | *.apk
15 |
16 | # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties`
17 | # should NOT be excluded as they contain compiler settings and other important
18 | # information for Eclipse / Flash Builder.
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Hritikbansal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generative Robustness
2 |
3 | This repo contains the code for the experiments in the 'Leaving Reality to Imagination: Robust Classification via Generated Datasets' paper. Arxiv link: [https://arxiv.org/pdf/2302.02503.pdf](https://arxiv.org/pdf/2302.02503.pdf)
4 |
5 | [Colab](https://colab.research.google.com/drive/1I2IO8tD_l9JdCRJHOqlAP6ojMPq_BsoR?usp=sharing)
6 |
7 |
8 |
9 | Accepted as Oral Presentation at RTML ICLR 2023.
10 | Accepted at SPIGM ICML 2023.
11 |
12 | ## Link to Generated ImageNet-1K dataset
13 |
14 | You can download the `Base-Generated-ImageNet-1K` dataset from [here](https://drive.google.com/drive/folders/1-jLyiJ_S-VZMS5zQNR6e1xDOAkAVJxvs?usp=sharing). Even though we discuss three variants of generated data in the paper, we publicly release the generations conditioned on captions of the class labels for novel use cases by the community.
15 |
16 | You can download the `Finetuned-Generated-ImageNet-1K` dataset from [here](https://drive.google.com/drive/folders/1xVBcCrae1JrmOCoHHNWCw5UN6XAOoupk?usp=sharing).
17 |
18 | Structure of the dataset looks like:
19 |
20 | ```
21 | * train (1000 folders)
22 | * n01440764 (1300 images)
23 | * image1.jpeg
24 | * .
25 | * imageN.jpeg
26 | * .
27 | * .
28 | * val (1000 folders)
29 | * n01440764 (50 images)
30 | * image1.jpeg
31 | * .
32 | * imageN.jpeg
33 | * .
34 | * .
35 | ```
36 |
37 | ## Finetuning Stable Diffusion on the Real ImageNet-1K
38 |
39 | We provide the finetuning details as well as the finetuned Stable Diffusion model at [https://huggingface.co/hbXNov/ucla-mint-finetune-sd-im1k](https://huggingface.co/hbXNov/ucla-mint-finetune-sd-im1k).
40 |
41 | Colab Notebook: [here](https://colab.research.google.com/drive/1I2IO8tD_l9JdCRJHOqlAP6ojMPq_BsoR?usp=sharing)
42 |
43 | The finetuning code can be found in the [sd_finetune](sd_finetune) folder. Most of this code is adapted from the `diffusers` library's [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) example. We are really thankful to the authors!!
44 |
45 |
46 | *The rest of the README will focus on generating data from Stable Diffusion in a zero-shot manner, and training ImageNet classifiers efficiently using FFCV.*
47 |
48 | ## Data Generation Using Stable Diffusion
49 |
50 | [Stable Diffusion](https://github.com/CompVis/stable-diffusion) is a popular text-to-image generative model. Most of the code is adapted from the very popular [diffusers](https://github.com/huggingface/diffusers) library from HuggingFace.
51 |
52 | However, it might not be straightforward to generate images from Stable Diffusion on multiple GPUs. To that end, we use the [accelerate](https://huggingface.co/docs/accelerate/index) package from Huggingface.
53 |
54 | ### Requirements
55 |
56 | - Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
57 | - 64-bit Python 3.7+ installation.
58 | - We used 5 A6000 GPUs (24GB DRAM each) for generation.
59 |
60 | ### Setup
61 |
62 | ```
63 | 1. git clone https://github.com/Hritikbansal/leaving_reality_to_imagination.git
64 | 2. cd leaving_reality_to_imagination
65 | 3. conda env create -f environment.yml
66 | 4. pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html (Replace based on your computer's hardware)
67 | 5. accelerate config
68 | - This machine
69 | - multi-GPU
70 | - (How many machines?) 1
71 | - (optimize with dynamo?) NO
72 | - (Deepspeed?) NO
73 | - (FullyShardedParallel?) NO
74 | - (MegatronLM) NO
75 | - (Num of GPUs) 5
76 | - (device ids) 0,1,2,3,4
77 | - (no/fp16/bf16) no
78 | ```
79 |
80 | ### Files
81 |
82 | 1. [generate_images_captions](generation/generate_images_captions.py) generates the images conditioned on the diverse text prompts (Class Labels).
83 | 2. [generate_images](generation/generate_images.py) generates the images conditioned on the images (Real Images).
84 | 3. [generate_images_i2i](generation/generate_images_i2i.py) generates the images conditioned on the encoded images and text (Real Images and Class Labels).
85 | 4. [conditional ldm](generation/conditional_ldm.py) generates images from the class-conditional latent diffusion model. You can download the model ckpt from [Stable Diffusion](https://github.com/CompVis/stable-diffusion) repo.
86 |
87 | Move the [classes.py](generation/classes.py) and [folder_to_class.csv](generation/folder_to_class.csv) to the `imagenet_dir`.
88 |
89 | ### Commands
90 |
91 | ```bash
92 | accelerate launch --num_cpu_threads_per_process 8 -m generation.generate_images_captions --batch_size 8 --data_dir --save_image_gen --diversity --split val
93 | ```
94 |
95 | ```bash
96 | accelerate launch --num_cpu_threads_per_process 8 -m generation.generate_images --batch_size 2 --eval_test_data_dir --save_image_gen --split val
97 | ```
98 |
99 | ```bash
100 | accelerate launch --num_cpu_threads_per_process 8 -m generation.generate_images_i2i --batch_size 12 --data_dir --save_image_gen --split val --diversity
101 | ```
102 |
103 | ```bash
104 | accelerate launch --num_cpu_threads_per_process 8 conditional_ldm.py --config cin256-v2.yaml --checkpoint --save_image_gen
105 | ```
106 |
107 |
108 | ## Training ImageNet Models Using FFCV
109 |
110 | We suggest that users create a separate conda environment for FFCV when training ImageNet models.
111 |
112 | ### Preparing the Dataset
113 | Following the ImageNet training pipeline of [FFCV](https://github.com/libffcv/ffcv-imagenet) for ResNet50, generate the dataset with the following command (`IMAGENET_DIR` should point to a PyTorch style [ImageNet dataset](https://github.com/MadryLab/pytorch-imagenet-dataset)):
114 |
115 | ```bash
116 | # Required environmental variables for the script:
117 | cd train/
118 | export IMAGENET_DIR=/path/to/pytorch/format/imagenet/directory/
119 | export WRITE_DIR=/your/path/here/
120 |
121 | # Serialize images with:
122 | # - 500px side length maximum
123 | # - 50% JPEG encoded, 90% raw pixel values
124 | # - quality=90 JPEGs
125 | ./write_imagenet.sh 500 0.50 90
126 | ```
127 | Note that we prepare the dataset with the following FFCV configuration:
128 | * ResNet-50 training: 50% JPEG 500px side length (*train_500_0.50_90.ffcv*)
129 | * ResNet-50 evaluation: 0% JPEG 500px side length (*val_500_uncompressed.ffcv*)
130 |
131 | - We have made some custom edits to [write_imagenet.py](train/write_imagenet.py) to generate augmented ImageNet data.
132 |
133 | ### Training
134 |
135 | ```bash
136 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python train_imagenet.py --config-file resnet_configs/resnext50.yaml --data.train_dataset= --data.val_dataset= --data.num_workers=8 --logging.folder= --model.num_classes=100 (if imagenet 100) --training.distributed=1 --dist.world_size=5
137 | ```
138 |
139 | ### Evaluation
140 |
141 | ```bash
142 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python train_imagenet.py --config-file resnet_configs/rn18_88_epochs.yaml --data.train_dataset= --data.val_dataset= --data.num_workers=8 --training.path= --model.num_classes=1000 --training.distributed=1 --dist.world_size=5 --training.eval_only=1
143 | ```
144 |
145 | ### Note
146 |
147 | 1. Since ImageNet-R and ObjectNet do not share all the classes with ImageNet1K, we use an additional `validation.imr` or `validation.obj` flag while evaluating on these datasets.
148 | 2. [create_imagenet_subset](utils/create_imagenet_subset.py) is used to create the random subset containing 100 classes. [mappings](utils/mappings/) contains the relevant `imagenet100.txt` file.
149 |
150 | ## Natural Distribution Shift Datasets
151 |
152 | We use (a) ImageNet-Sketch, (b) ImageNet-R, (c) ImageNet-V2, and (d) ObjectNet in our work. The users can navigate to their respective sources to download the data.
153 |
154 | 1. [rename_imagenetv2](utils/rename_imagenet_v2.py) renames the ImageNet-V2 folders, which are originally named with class indices 0-1000, to the original ImageNet folder names (n0XXXXX).
155 | 2. [subset_objectnet_im](utils/subset_objectnet_im.py) is used to create a subset of ObjectNet classes that overlap with ImageNet-100/1000.
156 |
157 |
158 |
--------------------------------------------------------------------------------
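
As a companion to the Note section of the README above: `utils/create_imagenet_subset.py` itself is not reproduced in this dump, so the snippet below is only a minimal sketch of how a 100-class subset can be materialized from a PyTorch-style ImageNet directory using `utils/mappings/imagenet100.txt`. The paths, flag names, and the symlink strategy are assumptions for illustration, not the script's actual behaviour.

```python
import os
import argparse

# Illustrative sketch only (not utils/create_imagenet_subset.py): link the class
# folders listed in imagenet100.txt out of a full PyTorch-style ImageNet root.
parser = argparse.ArgumentParser()
parser.add_argument("--imagenet_dir", required=True, help="ImageNet-1K root containing train/ and val/")
parser.add_argument("--subset_dir", required=True, help="output root for the ImageNet-100 subset")
parser.add_argument("--class_list", default="utils/mappings/imagenet100.txt")
args = parser.parse_args()

with open(args.class_list) as f:
    wnids = [line.strip() for line in f if line.strip()]  # e.g. n01440764

for split in ("train", "val"):
    os.makedirs(os.path.join(args.subset_dir, split), exist_ok=True)
    for wnid in wnids:
        src = os.path.join(args.imagenet_dir, split, wnid)
        dst = os.path.join(args.subset_dir, split, wnid)
        if os.path.isdir(src) and not os.path.exists(dst):
            os.symlink(src, dst)  # symlinks keep the subset cheap; copying works too
```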
/environment.yml:
--------------------------------------------------------------------------------
1 | name: lrti
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - ca-certificates=2022.10.11=h06a4308_0
10 | - certifi=2022.9.24=py39h06a4308_0
11 | - cuda=11.6.2=0
12 | - cuda-cccl=11.6.55=hf6102b2_0
13 | - cuda-command-line-tools=11.6.2=0
14 | - cuda-compiler=11.6.2=0
15 | - cuda-cudart=11.6.55=he381448_0
16 | - cuda-cudart-dev=11.6.55=h42ad0f4_0
17 | - cuda-cuobjdump=11.6.124=h2eeebcb_0
18 | - cuda-cupti=11.6.124=h86345e5_0
19 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0
20 | - cuda-driver-dev=11.6.55=0
21 | - cuda-gdb=12.0.90=0
22 | - cuda-libraries=11.6.2=0
23 | - cuda-libraries-dev=11.6.2=0
24 | - cuda-memcheck=11.8.86=0
25 | - cuda-nsight=12.0.78=0
26 | - cuda-nsight-compute=12.0.0=0
27 | - cuda-nvcc=11.6.124=hbba6d2d_0
28 | - cuda-nvdisasm=12.0.76=0
29 | - cuda-nvml-dev=11.6.55=haa9ef22_0
30 | - cuda-nvprof=12.0.90=0
31 | - cuda-nvprune=11.6.124=he22ec0a_0
32 | - cuda-nvrtc=11.6.124=h020bade_0
33 | - cuda-nvrtc-dev=11.6.124=h249d397_0
34 | - cuda-nvtx=11.6.124=h0630a44_0
35 | - cuda-nvvp=12.0.90=0
36 | - cuda-runtime=11.6.2=0
37 | - cuda-samples=11.6.101=h8efea70_0
38 | - cuda-sanitizer-api=12.0.90=0
39 | - cuda-toolkit=11.6.2=0
40 | - cuda-tools=11.6.2=0
41 | - cuda-visual-tools=11.6.2=0
42 | - gds-tools=1.5.0.59=0
43 | - ld_impl_linux-64=2.38=h1181459_1
44 | - libcublas=12.0.1.189=0
45 | - libcublas-dev=12.0.1.189=0
46 | - libcufft=11.0.0.21=0
47 | - libcufft-dev=11.0.0.21=0
48 | - libcufile=1.5.0.59=0
49 | - libcufile-dev=1.5.0.59=0
50 | - libcurand=10.3.1.50=0
51 | - libcurand-dev=10.3.1.50=0
52 | - libcusolver=11.4.2.57=0
53 | - libcusolver-dev=11.4.2.57=0
54 | - libcusparse=12.0.0.76=0
55 | - libcusparse-dev=12.0.0.76=0
56 | - libffi=3.4.2=h6a678d5_6
57 | - libgcc-ng=11.2.0=h1234567_1
58 | - libgomp=11.2.0=h1234567_1
59 | - libnpp=12.0.0.30=0
60 | - libnpp-dev=12.0.0.30=0
61 | - libnvjpeg=12.0.0.28=0
62 | - libnvjpeg-dev=12.0.0.28=0
63 | - libstdcxx-ng=11.2.0=h1234567_1
64 | - ncurses=6.3=h5eee18b_3
65 | - nsight-compute=2022.4.0.15=0
66 | - openssl=1.1.1s=h7f8727e_0
67 | - pip=22.3.1=py39h06a4308_0
68 | - python=3.9.15=h7a1cb2a_2
69 | - readline=8.2=h5eee18b_0
70 | - setuptools=65.5.0=py39h06a4308_0
71 | - sqlite=3.40.0=h5082296_0
72 | - tk=8.6.12=h1ccaba5_0
73 | - tzdata=2022g=h04d1e81_0
74 | - wheel=0.37.1=pyhd3eb1b0_0
75 | - xz=5.2.8=h5eee18b_0
76 | - zlib=1.2.13=h5eee18b_0
77 | - pip:
78 | - accelerate==0.15.0
79 | - beautifulsoup4==4.11.1
80 | - charset-normalizer==2.1.1
81 | - clip==0.2.0
82 | - contourpy==1.0.6
83 | - cycler==0.11.0
84 | - filelock==3.8.2
85 | - fonttools==4.38.0
86 | - ftfy==6.1.1
87 | - gdown==4.6.0
88 | - huggingface-hub==0.11.1
89 | - idna==3.4
90 | - joblib==1.2.0
91 | - kiwisolver==1.4.4
92 | - littleutils==0.2.2
93 | - matplotlib==3.6.2
94 | - numpy==1.23.5
95 | - ogb==1.3.5
96 | - outdated==0.2.2
97 | - packaging==22.0
98 | - pandas==1.5.2
99 | - pillow==9.3.0
100 | - psutil==5.9.4
101 | - pyparsing==3.0.9
102 | - pysocks==1.7.1
103 | - python-dateutil==2.8.2
104 | - pytz==2022.6
105 | - pyyaml==6.0
106 | - regex==2022.10.31
107 | - requests==2.28.1
108 | - scikit-learn==1.2.0
109 | - scipy==1.9.3
110 | - seaborn==0.12.2
111 | - six==1.16.0
112 | - soupsieve==2.3.2.post1
113 | - threadpoolctl==3.1.0
114 | - timm==0.3.4
115 | - tokenizers==0.13.2
116 | - tqdm==4.64.1
117 | - transformers==4.25.1
118 | - typing-extensions==4.4.0
119 | - urllib3==1.26.13
120 | - wcwidth==0.2.5
121 | - wilds==2.0.0
122 |
--------------------------------------------------------------------------------
/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/figure.png
--------------------------------------------------------------------------------
/generation/DeepAugment/CAE_Model/__pycache__/cae_32x32x32_zero_pad_bin.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/CAE_Model/__pycache__/cae_32x32x32_zero_pad_bin.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/DeepAugment/CAE_Model/cae_32x32x32_zero_pad_bin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import random
6 |
7 |
8 | class CAE(nn.Module):
9 | """
10 | This AE module will be fed 3x128x128 patches from the original image
11 | Shapes are (batch_size, channels, height, width)
12 | Latent representation: 32x32x32 bits per patch => 240KB per image (for 720p)
13 | """
14 |
15 | def __init__(self):
16 | super(CAE, self).__init__()
17 |
18 | self.encoded = None
19 |
20 | # ENCODER
21 |
22 | # 64x64x64
23 | self.e_conv_1 = nn.Sequential(
24 | nn.ZeroPad2d((1, 2, 1, 2)),
25 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(5, 5), stride=(2, 2)),
26 | nn.LeakyReLU()
27 | )
28 |
29 | # 128x32x32
30 | self.e_conv_2 = nn.Sequential(
31 | nn.ZeroPad2d((1, 2, 1, 2)),
32 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(5, 5), stride=(2, 2)),
33 | nn.LeakyReLU()
34 | )
35 |
36 | # 128x32x32
37 | self.e_block_1 = nn.Sequential(
38 | nn.ZeroPad2d((1, 1, 1, 1)),
39 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
40 | nn.LeakyReLU(),
41 |
42 | nn.ZeroPad2d((1, 1, 1, 1)),
43 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
44 | )
45 |
46 | # 128x32x32
47 | self.e_block_2 = nn.Sequential(
48 | nn.ZeroPad2d((1, 1, 1, 1)),
49 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
50 | nn.LeakyReLU(),
51 |
52 | nn.ZeroPad2d((1, 1, 1, 1)),
53 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
54 | )
55 |
56 | # 128x32x32
57 | self.e_block_3 = nn.Sequential(
58 | nn.ZeroPad2d((1, 1, 1, 1)),
59 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
60 | nn.LeakyReLU(),
61 |
62 | nn.ZeroPad2d((1, 1, 1, 1)),
63 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
64 | )
65 |
66 | # 32x32x32
67 | self.e_conv_3 = nn.Sequential(
68 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
69 | nn.Tanh()
70 | )
71 |
72 | # DECODER
73 |
74 | # a
75 | self.d_up_conv_1 = nn.Sequential(
76 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(1, 1)),
77 | nn.LeakyReLU(),
78 |
79 | nn.ZeroPad2d((1, 1, 1, 1)),
80 | nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2))
81 | )
82 |
83 | # 128x64x64
84 | self.d_block_1 = nn.Sequential(
85 | nn.ZeroPad2d((1, 1, 1, 1)),
86 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
87 | nn.LeakyReLU(),
88 |
89 | nn.ZeroPad2d((1, 1, 1, 1)),
90 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
91 | )
92 |
93 | # 128x64x64
94 | self.d_block_2 = nn.Sequential(
95 | nn.ZeroPad2d((1, 1, 1, 1)),
96 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
97 | nn.LeakyReLU(),
98 |
99 | nn.ZeroPad2d((1, 1, 1, 1)),
100 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
101 | )
102 |
103 | # 128x64x64
104 | self.d_block_3 = nn.Sequential(
105 | nn.ZeroPad2d((1, 1, 1, 1)),
106 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
107 | nn.LeakyReLU(),
108 |
109 | nn.ZeroPad2d((1, 1, 1, 1)),
110 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)),
111 | )
112 |
113 | # 256x128x128
114 | self.d_up_conv_2 = nn.Sequential(
115 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(3, 3), stride=(1, 1)),
116 | nn.LeakyReLU(),
117 |
118 | nn.ZeroPad2d((1, 1, 1, 1)),
119 | nn.ConvTranspose2d(in_channels=32, out_channels=256, kernel_size=(2, 2), stride=(2, 2))
120 | )
121 |
122 | # 3x128x128
123 | self.d_up_conv_3 = nn.Sequential(
124 | nn.Conv2d(in_channels=256, out_channels=16, kernel_size=(3, 3), stride=(1, 1)),
125 | nn.LeakyReLU(),
126 |
127 | nn.ReflectionPad2d((2, 2, 2, 2)),
128 | nn.Conv2d(in_channels=16, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
129 | nn.Tanh()
130 | )
131 |
132 | def forward(self, x):
133 | ec1 = self.e_conv_1(x)
134 | ec2 = self.e_conv_2(ec1)
135 | eblock1 = self.e_block_1(ec2) + ec2
136 | eblock2 = self.e_block_2(eblock1) + eblock1
137 | eblock3 = self.e_block_3(eblock2) + eblock2
138 | ec3 = self.e_conv_3(eblock3) # in [-1, 1] from tanh activation
139 | option = np.random.choice(range(9))
140 | if option == 1:
141 | # set some weights to zero
142 | H = ec3.size()[2]
143 | W = ec3.size()[3]
144 | mask = (torch.cuda.FloatTensor(H, W).uniform_() > 0.2).float().cuda()
145 | ec3 = ec3 * mask
146 | del mask
147 | elif option == 2:
148 |             # negate some of the weights
149 | H = ec3.size()[2]
150 | W = ec3.size()[3]
151 | mask = (((torch.cuda.FloatTensor(H, W).uniform_() > 0.1).float() * 2) - 1).cuda()
152 | ec3 = ec3 * mask
153 | del mask
154 | elif option == 3:
155 | num_channels = 10
156 | perm = np.array(list(np.random.permutation(num_channels)) + list(range(num_channels, ec3.size()[1])))
157 | ec3 = ec3[:, perm, :, :]
158 | elif option == 4:
159 | num_channels = ec3.shape[1]
160 | num_channels_transform = 5
161 |
162 | _k = random.randint(1,3)
163 | _dims = [0, 1, 2]
164 | random.shuffle(_dims)
165 | _dims = _dims[:2]
166 |
167 | for i in range(num_channels_transform):
168 | filter_select = random.choice(list(range(num_channels)))
169 | ec3[:,filter_select] = torch.flip(ec3[:,filter_select], dims=_dims)
170 | elif option == 5:
171 | num_channels = ec3.shape[1]
172 | num_channels_transform = num_channels
173 |
174 | _k = random.randint(1,3)
175 | _dims = [0, 1, 2]
176 | random.shuffle(_dims)
177 | _dims = _dims[:2]
178 |
179 | for i in range(num_channels_transform):
180 | if i == num_channels_transform / 2:
181 | _dims = [_dims[1], _dims[0]]
182 | ec3[:,i] = torch.flip(ec3[:,i], dims=_dims)
183 | elif option == 6:
184 | with torch.no_grad():
185 | c, h, w = ec3.shape[1], ec3.shape[2], ec3.shape[3]
186 | z = torch.zeros(c, c, 3, 3).cuda()
187 | for j in range(z.size(0)):
188 | shift_x, shift_y = 1, 1# np.random.randint(3, size=(2,))
189 | z[j,j,shift_x,shift_y] = 1 # np.random.choice([1.,-1.])
190 |
191 | # Without this line, z would be the identity convolution
192 | z = z + ((torch.rand_like(z) - 0.5) * 0.2)
193 | ec3 = F.conv2d(ec3, z, padding=1)
194 | del z
195 | elif option == 7:
196 | with torch.no_grad():
197 | c, h, w = ec3.shape[1], ec3.shape[2], ec3.shape[3]
198 | z = torch.zeros(c, c, 3, 3).cuda()
199 | for j in range(z.size(0)):
200 | shift_x, shift_y = 1, 1# np.random.randint(3, size=(2,))
201 | z[j,j,shift_x,shift_y] = 1 # np.random.choice([1.,-1.])
202 |
203 | if random.random() < 0.5:
204 | rand_layer = random.randint(0, c - 1)
205 | z[j, rand_layer, random.randint(-1, 1), random.randint(-1, 1)] = 1
206 |
207 | ec3 = F.conv2d(ec3, z, padding=1)
208 | del z
209 | elif option == 8:
210 | with torch.no_grad():
211 | c, h, w = ec3.shape[1], ec3.shape[2], ec3.shape[3]
212 | z = torch.zeros(c, c, 3, 3).cuda()
213 | shift_x, shift_y = np.random.randint(3, size=(2,))
214 | for j in range(z.size(0)):
215 | if random.random() < 0.2:
216 | shift_x, shift_y = np.random.randint(3, size=(2,))
217 |
218 | z[j,j,shift_x,shift_y] = 1 # np.random.choice([1.,-1.])
219 |
220 | # Without this line, z would be the identity convolution
221 | # z = z + ((torch.rand_like(z) - 0.5) * 0.2)
222 | ec3 = F.conv2d(ec3, z, padding=1)
223 | del z
224 |
225 | # stochastic binarization
226 | with torch.no_grad():
227 | rand = torch.rand(ec3.shape).cuda()
228 | prob = (1 + ec3) / 2
229 | eps = torch.zeros(ec3.shape).cuda()
230 | eps[rand <= prob] = (1 - ec3)[rand <= prob]
231 | eps[rand > prob] = (-ec3 - 1)[rand > prob]
232 |
233 | # encoded tensor
234 | self.encoded = 0.5 * (ec3 + eps + 1) # (-1|1) -> (0|1)
235 | if option == 0:
236 | self.encoded = self.encoded *\
237 | (3 + 2 * np.float32(np.random.uniform()) * (2*torch.rand_like(self.encoded-1)))
238 | return self.decode(self.encoded)
239 |
240 | def decode(self, encoded):
241 | y = encoded * 2.0 - 1 # (0|1) -> (-1|1)
242 |
243 | uc1 = self.d_up_conv_1(y)
244 | dblock1 = self.d_block_1(uc1) + uc1
245 | dblock2 = self.d_block_2(dblock1) + dblock1
246 | dblock3 = self.d_block_3(dblock2) + dblock2
247 | uc2 = self.d_up_conv_2(dblock3)
248 | dec = self.d_up_conv_3(uc2)
249 |
250 | return dec
251 |
--------------------------------------------------------------------------------
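
For orientation, here is a minimal usage sketch for the CAE defined above; it is not the repository's `CAE_distort_imagenet.py`. The 128x128 input size follows the module docstring, while the [0, 1] input range, the assumption that `CAE_Weights/model_final.state` holds a plain `state_dict`, and the final rescaling are all hypotheticals.

```python
import torch
from PIL import Image
from torchvision import transforms

from cae_32x32x32_zero_pad_bin import CAE  # assumes CAE_Model/ is on PYTHONPATH

# CUDA is required: forward() allocates masks with torch.cuda.FloatTensor.
model = CAE().cuda().eval()
state = torch.load("CAE_Weights/model_final.state", map_location="cuda")
model.load_state_dict(state)  # assumes the .state file is a plain state_dict

to_tensor = transforms.Compose([
    transforms.Resize((128, 128)),  # module docstring: 3x128x128 patches
    transforms.ToTensor(),          # assumed [0, 1] input range
])

img = to_tensor(Image.open("example.jpg").convert("RGB")).unsqueeze(0).cuda()
with torch.no_grad():
    out = model(img)  # forward() applies one randomly chosen distortion option

# The decoder ends in Tanh, so map [-1, 1] back to [0, 1] before saving
# (the exact post-processing in the distortion script may differ).
out = (out.clamp(-1, 1) + 1) / 2
transforms.ToPILImage()(out.squeeze(0).cpu()).save("example_distorted.png")
```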
/generation/DeepAugment/CAE_Weights/model_final.state:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/CAE_Weights/model_final.state
--------------------------------------------------------------------------------
/generation/DeepAugment/EDSR_Model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel as P
7 | import torch.utils.model_zoo
8 |
9 | class Model(nn.Module):
10 | def __init__(self, args, ckp):
11 | super(Model, self).__init__()
12 | print('Making model...')
13 |
14 | self.scale = args.scale
15 | self.idx_scale = 0
16 | self.input_large = (args.model == 'VDSR')
17 | self.self_ensemble = args.self_ensemble
18 | self.chop = args.chop
19 | self.precision = args.precision
20 | self.cpu = args.cpu
21 | self.device = torch.device('cpu' if args.cpu else 'cuda')
22 | self.n_GPUs = args.n_GPUs
23 | self.save_models = args.save_models
24 |
25 | module = import_module('model.' + args.model.lower())
26 | self.model = module.make_model(args).to(self.device)
27 | if args.precision == 'half':
28 | self.model.half()
29 |
30 | self.load(
31 | ckp.get_path('model'),
32 | pre_train=args.pre_train,
33 | resume=args.resume,
34 | cpu=args.cpu
35 | )
36 | print(self.model, file=ckp.log_file)
37 |
38 | def forward(self, x, idx_scale):
39 | self.idx_scale = idx_scale
40 | if hasattr(self.model, 'set_scale'):
41 | self.model.set_scale(idx_scale)
42 |
43 | if self.training:
44 | if self.n_GPUs > 1:
45 | return P.data_parallel(self.model, x, range(self.n_GPUs))
46 | else:
47 | return self.model(x)
48 | else:
49 | if self.chop:
50 | forward_function = self.forward_chop
51 | else:
52 | forward_function = self.model.forward
53 |
54 | if self.self_ensemble:
55 | return self.forward_x8(x, forward_function=forward_function)
56 | else:
57 | return forward_function(x)
58 |
59 | def save(self, apath, epoch, is_best=False):
60 | save_dirs = [os.path.join(apath, 'model_latest.pt')]
61 |
62 | if is_best:
63 | save_dirs.append(os.path.join(apath, 'model_best.pt'))
64 | if self.save_models:
65 | save_dirs.append(
66 | os.path.join(apath, 'model_{}.pt'.format(epoch))
67 | )
68 |
69 | for s in save_dirs:
70 | torch.save(self.model.state_dict(), s)
71 |
72 | def load(self, apath, pre_train='', resume=-1, cpu=False):
73 | load_from = None
74 | kwargs = {}
75 | if cpu:
76 | kwargs = {'map_location': lambda storage, loc: storage}
77 |
78 | if resume == -1:
79 | load_from = torch.load(
80 | os.path.join(apath, 'model_latest.pt'),
81 | **kwargs
82 | )
83 | elif resume == 0:
84 | if pre_train == 'download':
85 | print('Download the model')
86 | dir_model = os.path.join('..', 'models')
87 | os.makedirs(dir_model, exist_ok=True)
88 | load_from = torch.utils.model_zoo.load_url(
89 | self.model.url,
90 | model_dir=dir_model,
91 | **kwargs
92 | )
93 | elif pre_train:
94 | print('Load the model from {}'.format(pre_train))
95 | load_from = torch.load(pre_train, **kwargs)
96 | else:
97 | load_from = torch.load(
98 | os.path.join(apath, 'model_{}.pt'.format(resume)),
99 | **kwargs
100 | )
101 |
102 | if load_from:
103 | self.model.load_state_dict(load_from, strict=False)
104 |
105 | def forward_chop(self, *args, shave=10, min_size=160000):
106 | scale = 1 if self.input_large else self.scale[self.idx_scale]
107 | n_GPUs = min(self.n_GPUs, 4)
108 | # height, width
109 | h, w = args[0].size()[-2:]
110 |
111 | top = slice(0, h//2 + shave)
112 | bottom = slice(h - h//2 - shave, h)
113 | left = slice(0, w//2 + shave)
114 | right = slice(w - w//2 - shave, w)
115 | x_chops = [torch.cat([
116 | a[..., top, left],
117 | a[..., top, right],
118 | a[..., bottom, left],
119 | a[..., bottom, right]
120 | ]) for a in args]
121 |
122 | y_chops = []
123 | if h * w < 4 * min_size:
124 | for i in range(0, 4, n_GPUs):
125 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
126 | y = P.data_parallel(self.model, *x, range(n_GPUs))
127 | if not isinstance(y, list): y = [y]
128 | if not y_chops:
129 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
130 | else:
131 | for y_chop, _y in zip(y_chops, y):
132 | y_chop.extend(_y.chunk(n_GPUs, dim=0))
133 | else:
134 | for p in zip(*x_chops):
135 | y = self.forward_chop(*p, shave=shave, min_size=min_size)
136 | if not isinstance(y, list): y = [y]
137 | if not y_chops:
138 | y_chops = [[_y] for _y in y]
139 | else:
140 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
141 |
142 | h *= scale
143 | w *= scale
144 | top = slice(0, h//2)
145 | bottom = slice(h - h//2, h)
146 | bottom_r = slice(h//2 - h, None)
147 | left = slice(0, w//2)
148 | right = slice(w - w//2, w)
149 | right_r = slice(w//2 - w, None)
150 |
151 | # batch size, number of color channels
152 | b, c = y_chops[0][0].size()[:-2]
153 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
154 | for y_chop, _y in zip(y_chops, y):
155 | _y[..., top, left] = y_chop[0][..., top, left]
156 | _y[..., top, right] = y_chop[1][..., top, right_r]
157 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
158 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
159 |
160 | if len(y) == 1: y = y[0]
161 |
162 | return y
163 |
164 | def forward_x8(self, *args, forward_function=None):
165 | def _transform(v, op):
166 | if self.precision != 'single': v = v.float()
167 |
168 | v2np = v.data.cpu().numpy()
169 | if op == 'v':
170 | tfnp = v2np[:, :, :, ::-1].copy()
171 | elif op == 'h':
172 | tfnp = v2np[:, :, ::-1, :].copy()
173 | elif op == 't':
174 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
175 |
176 | ret = torch.Tensor(tfnp).to(self.device)
177 | if self.precision == 'half': ret = ret.half()
178 |
179 | return ret
180 |
181 | list_x = []
182 | for a in args:
183 | x = [a]
184 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
185 |
186 | list_x.append(x)
187 |
188 | list_y = []
189 | for x in zip(*list_x):
190 | y = forward_function(*x)
191 | if not isinstance(y, list): y = [y]
192 | if not list_y:
193 | list_y = [[_y] for _y in y]
194 | else:
195 | for _list_y, _y in zip(list_y, y): _list_y.append(_y)
196 |
197 | for _list_y in list_y:
198 | for i in range(len(_list_y)):
199 | if i > 3:
200 | _list_y[i] = _transform(_list_y[i], 't')
201 | if i % 4 > 1:
202 | _list_y[i] = _transform(_list_y[i], 'h')
203 | if (i % 4) % 2 == 1:
204 | _list_y[i] = _transform(_list_y[i], 'v')
205 |
206 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
207 | if len(y) == 1: y = y[0]
208 |
209 | return y
210 |
--------------------------------------------------------------------------------
/generation/DeepAugment/EDSR_Model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
8 | return nn.Conv2d(
9 | in_channels, out_channels, kernel_size,
10 | padding=(kernel_size//2), bias=bias)
11 |
12 | class MeanShift(nn.Conv2d):
13 | def __init__(
14 | self, rgb_range,
15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
16 |
17 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
18 | std = torch.Tensor(rgb_std)
19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
21 | for p in self.parameters():
22 | p.requires_grad = False
23 |
24 | class BasicBlock(nn.Sequential):
25 | def __init__(
26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
27 | bn=True, act=nn.ReLU(True)):
28 |
29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
30 | if bn:
31 | m.append(nn.BatchNorm2d(out_channels))
32 | if act is not None:
33 | m.append(act)
34 |
35 | super(BasicBlock, self).__init__(*m)
36 |
37 | class ResBlock(nn.Module):
38 | def __init__(
39 | self, conv, n_feats, kernel_size,
40 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
41 |
42 | super(ResBlock, self).__init__()
43 | m = []
44 | for i in range(2):
45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
46 | if bn:
47 | m.append(nn.BatchNorm2d(n_feats))
48 | if i == 0:
49 | m.append(act)
50 |
51 | self.body = nn.Sequential(*m)
52 | self.res_scale = res_scale
53 |
54 | def forward(self, x):
55 | res = self.body(x).mul(self.res_scale)
56 | res += x
57 |
58 | return res
59 |
60 | class Upsampler(nn.Sequential):
61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
62 |
63 | m = []
64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
65 | for _ in range(int(math.log(scale, 2))):
66 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
67 | m.append(nn.PixelShuffle(2))
68 | if bn:
69 | m.append(nn.BatchNorm2d(n_feats))
70 | if act == 'relu':
71 | m.append(nn.ReLU(True))
72 | elif act == 'prelu':
73 | m.append(nn.PReLU(n_feats))
74 |
75 | elif scale == 3:
76 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
77 | m.append(nn.PixelShuffle(3))
78 | if bn:
79 | m.append(nn.BatchNorm2d(n_feats))
80 | if act == 'relu':
81 | m.append(nn.ReLU(True))
82 | elif act == 'prelu':
83 | m.append(nn.PReLU(n_feats))
84 | else:
85 | raise NotImplementedError
86 |
87 | super(Upsampler, self).__init__(*m)
88 |
89 |
--------------------------------------------------------------------------------
/generation/DeepAugment/EDSR_Model/edsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
5 | url = {
6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
12 | }
13 |
14 | def make_model(args, parent=False):
15 | return EDSR(args)
16 |
17 | class EDSR(nn.Module):
18 | def __init__(self, args, conv=common.default_conv):
19 | super(EDSR, self).__init__()
20 |
21 | n_resblocks = args.n_resblocks
22 | n_feats = args.n_feats
23 | kernel_size = 3
24 | scale = args.scale[0]
25 | act = nn.ReLU(True)
26 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
27 | if url_name in url:
28 | self.url = url[url_name]
29 | else:
30 | self.url = None
31 | self.sub_mean = common.MeanShift(args.rgb_range)
32 | self.add_mean = common.MeanShift(args.rgb_range, sign=1)
33 |
34 | # define head module
35 | m_head = [conv(args.n_colors, n_feats, kernel_size)]
36 |
37 | # define body module
38 | m_body = [
39 | common.ResBlock(
40 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
41 | ) for _ in range(n_resblocks)
42 | ]
43 | m_body.append(conv(n_feats, n_feats, kernel_size))
44 |
45 | # define tail module
46 | m_tail = [
47 | common.Upsampler(conv, scale, n_feats, act=False),
48 | conv(n_feats, args.n_colors, kernel_size)
49 | ]
50 |
51 | self.head = nn.Sequential(*m_head)
52 | self.body = nn.Sequential(*m_body)
53 | self.tail = nn.Sequential(*m_tail)
54 |
55 | def forward(self, x):
56 | x = self.sub_mean(x)
57 | x = self.head(x)
58 |
59 | res = self.body(x)
60 | res += x
61 |
62 | x = self.tail(res)
63 | x = self.add_mean(x)
64 |
65 | return x
66 |
67 | def load_state_dict(self, state_dict, strict=True):
68 | own_state = self.state_dict()
69 | for name, param in state_dict.items():
70 | if name in own_state:
71 | if isinstance(param, nn.Parameter):
72 | param = param.data
73 | try:
74 | own_state[name].copy_(param)
75 | except Exception:
76 | if name.find('tail') == -1:
77 | raise RuntimeError('While copying the parameter named {}, '
78 | 'whose dimensions in the model are {} and '
79 | 'whose dimensions in the checkpoint are {}.'
80 | .format(name, own_state[name].size(), param.size()))
81 | elif strict:
82 | if name.find('tail') == -1:
83 | raise KeyError('unexpected key "{}" in state_dict'
84 | .format(name))
85 |
86 |
--------------------------------------------------------------------------------
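
A minimal sketch of instantiating the EDSR model above and loading `EDSR_Weights/edsr_baseline_x4.pt`. The hyperparameters below are the usual EDSR-baseline x4 values and the `SimpleNamespace` stands in for the original argument parser; both are assumptions for illustration, and `EDSR_distort_imagenet.py` remains the authoritative driver.

```python
import torch
from types import SimpleNamespace

# edsr.py does `from model import common`, so this sketch assumes the EDSR_Model
# directory is importable under the package name `model` (how the distortion
# script wires this up is not shown in this dump).
from model import edsr

# Assumed EDSR-baseline x4 settings (the usual defaults for this checkpoint).
args = SimpleNamespace(
    n_resblocks=16, n_feats=64, scale=[4],
    rgb_range=255, n_colors=3, res_scale=1,
)

model = edsr.make_model(args)
state = torch.load("EDSR_Weights/edsr_baseline_x4.pt", map_location="cpu")
model.load_state_dict(state, strict=False)  # the custom loader above tolerates tail mismatches
model.eval()

# EDSR expects inputs scaled to [0, rgb_range] with shape (B, 3, H, W).
x = torch.rand(1, 3, 64, 64) * 255
with torch.no_grad():
    y = model(x)  # (1, 3, 256, 256) for the x4 model
```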
/generation/DeepAugment/EDSR_Weights/edsr_baseline_x4.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/EDSR_Weights/edsr_baseline_x4.pt
--------------------------------------------------------------------------------
/generation/DeepAugment/README.md:
--------------------------------------------------------------------------------
1 | # DeepAugment
2 |
3 |
4 |
5 | ## DeepAugment Files
6 |
7 | Here is an overview of the files needed to create ImageNet training data augmented with DeepAugment.
8 | Alternatively, you can download our [EDSR](https://drive.google.com/file/d/1Ij_D3LuHWI4_WOlsg6dJMAPKEu10g47_/view?usp=sharing) and [CAE](https://drive.google.com/file/d/1xN9Z7pZ2GNwRww7j8ClPnPNwOc12VsM5/view?usp=sharing) images directly.
9 |
10 | ## Create Datasets
11 | - `EDSR_distort_imagenet.py`: Creates a distorted version of ImageNet using EDSR (https://arxiv.org/abs/1707.02921)
12 | - `CAE_distort_imagenet.py`: Creates a distorted version of ImageNet using a CAE (https://arxiv.org/abs/1703.00395)
13 |
14 | The above scripts can be run in parallel to speed up dataset creation (multiple workers processing different classes). For example, to split dataset creation across 5 processes, run the following in parallel:
15 | ```bash
16 | CUDA_VISIBLE_DEVICES=0 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=0
17 | CUDA_VISIBLE_DEVICES=1 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=1
18 | CUDA_VISIBLE_DEVICES=2 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=2
19 | CUDA_VISIBLE_DEVICES=3 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=3
20 | CUDA_VISIBLE_DEVICES=4 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=4
21 | ```
22 | You will need to change the save path and original ImageNet train set path.
23 |
24 | ## DeepAugment with Noise2Net
25 |
26 | In addition to EDSR and CAE, the DeepAugment approach works with randomly sampled architectures. One such example, which we call [Noise2Net](https://openreview.net/pdf?id=o20_NVA92tK#page=20), can generate augmentations in memory and in parallel. See `train_noise2net.py` above.
27 |
28 | [ResNet-50 + DeepAugment (Noise2Net)](https://drive.google.com/file/d/1TlylFp-hJ5ENUNdjGBf-tLPoIIcJyPG4/view?usp=sharing)
29 |
30 | [ResNeXt-101 + AugMix + DeepAugment (Noise2Net, EDSR, CAE)](https://drive.google.com/file/d/1dJoYMma4uZwT-wLzk4Teh8dKqCVz7Tlx/view?usp=sharing)
31 |
--------------------------------------------------------------------------------
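
To make the Noise2Net idea from the README above concrete, here is a small self-contained sketch of a randomly initialised residual noise network used as an in-memory augmentation. It only illustrates the concept; the repository's actual implementation lives in `train_noise2net.py` and differs in architecture and details.

```python
import torch
import torch.nn as nn

class TinyNoise2Net(nn.Module):
    """Illustrative stand-in for Noise2Net: a few randomly initialised 3x3 conv
    blocks applied residually with a small mixing weight, so the output stays
    close to the input image but picks up structured distortions."""

    def __init__(self, channels=3, hidden=16, n_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels, hidden, 3, padding=1),
                nn.GELU(),
                nn.Conv2d(hidden, channels, 3, padding=1),
            )
            for _ in range(n_blocks)
        ])
        for p in self.parameters():
            p.requires_grad_(False)  # never trained; weights are resampled instead

    def reroll(self):
        # Resample all conv weights so every batch sees a different augmentation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.2)
                nn.init.zeros_(m.bias)

    def forward(self, x, eps=0.3):
        for block in self.blocks:
            x = x + eps * block(x)
        return x.clamp(0, 1)

# Usage inside a training loop (images assumed to be in [0, 1]):
aug = TinyNoise2Net()
aug.reroll()
batch = torch.rand(8, 3, 224, 224)
augmented = aug(batch)
```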
/generation/DeepAugment/deepaugment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/deepaugment.png
--------------------------------------------------------------------------------
/generation/DeepAugment/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/pytorch.sh
4 |
5 | python train.py --dist-url 'tcp://127.0.0.1:32767' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 /data/imagenet
6 |
--------------------------------------------------------------------------------
/generation/__pycache__/conditional_ldm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/conditional_ldm.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/__pycache__/data.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/data.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/__pycache__/generate_images.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/generate_images.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/__pycache__/generate_images_captions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/generate_images_captions.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/__pycache__/generate_images_i2i.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/generate_images_i2i.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/generation/conditional_ldm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 |
6 | from tqdm import tqdm
7 | from PIL import Image
8 | from data import CondLDM
9 | from einops import rearrange
10 | from omegaconf import OmegaConf
11 | from taming.models import vqgan
12 | from accelerate import Accelerator
13 | from ldm.util import instantiate_from_config
14 | from ldm.models.diffusion.ddim import DDIMSampler
15 |
16 | os.environ["NCCL_P2P_DISABLE"] = "1"
17 |
18 | parser = argparse.ArgumentParser()
19 |
20 | parser.add_argument("--batch_size", type = int, default = 24)
21 | parser.add_argument("--split", type = str, default = "test", help = "Path to eval test data")
22 | parser.add_argument("--config", type = str, default = None, help = "Path to config file")
23 | parser.add_argument("--checkpoint", type = str, default = None, help = "Path to checkpoint")
24 | parser.add_argument("--save_image_size", type = int, default = 64)
25 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path saved generated images")
26 |
27 | '''
28 | accelerate launch --num_cpu_threads_per_process 8 --main_process_port 9876 conditional_ldm.py --config cin256-v2.yaml --checkpoint /home/data/ckpts/hbansal/ldm/model.ckpt --save_image_gen /home/data/datasets/ImageNet100/CondGeneration/
29 | '''
30 | args = parser.parse_args()
31 |
32 | accelerator = Accelerator()
33 | os.makedirs(os.path.join(args.save_image_gen, args.split), exist_ok = True)
34 |
35 | def load_model_from_config(config, ckpt):
36 | print(f"Loading model from {ckpt}")
37 | pl_sd = torch.load(ckpt)#, map_location="cpu")
38 | sd = pl_sd["state_dict"]
39 | model = instantiate_from_config(config.model)
40 | m, u = model.load_state_dict(sd, strict=False)
41 | return model
42 |
43 |
44 | def get_model():
45 | config = OmegaConf.load(args.config)
46 | model = load_model_from_config(config, args.checkpoint)
47 | return model
48 |
49 | def generate_images(model, sampler, dataloader, args):
50 |
51 | model, sampler, dataloader = accelerator.prepare(model, sampler, dataloader)
52 | model = model.to(accelerator.device)
53 | model.eval()
54 |
55 | ## Hyperparameters
56 | ddim_steps = 20
57 | ddim_eta = 0.0
58 | scale = 3.0 # for unconditional guidance
59 |
60 | model = model.module
61 | with torch.no_grad():
62 | with model.ema_scope():
63 | for class_indices, class_labels, folder_names in tqdm(dataloader):
64 | for folder_name in folder_names:
65 | os.makedirs(os.path.join(args.save_image_gen, args.split, folder_name), exist_ok = True)
66 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, args.split, folder_names[x], str(class_indices[x].item()) + ".png")), range(len(class_indices))))
67 | if len(indices) == 0: continue
68 | class_indices = [class_indices[i] for i in indices]
69 | class_labels = [class_labels[i] for i in indices]
70 | folder_names = [folder_names[i] for i in indices]
71 | class_labels = torch.tensor(class_labels)
72 | uc = model.get_learned_conditioning(
73 | {model.cond_stage_key: torch.tensor(len(indices)*[1000]).to(model.device)}
74 | )
75 | c = model.get_learned_conditioning({model.cond_stage_key: class_labels.to(model.device)})
76 | print(len(indices))
77 | samples_ddim, _ = sampler.sample(S=ddim_steps,
78 | conditioning=c,
79 | batch_size=len(indices),
80 | shape=[3, args.save_image_size, args.save_image_size],
81 | verbose=False,
82 | unconditional_guidance_scale=scale,
83 | unconditional_conditioning=uc,
84 | eta=ddim_eta)
85 | x_samples_ddim = model.decode_first_stage(samples_ddim)
86 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
87 | samples = 255. * rearrange(x_samples_ddim, 'b c h w -> b h w c').cpu().numpy()
88 | for index in range(len(samples)):
89 | im = Image.fromarray(samples[index].astype(np.uint8))
90 | im.save(os.path.join(args.save_image_gen, args.split, folder_names[index], str(class_indices[index].item()) + ".png"))
91 | def main():
92 | model = get_model()
93 | sampler = DDIMSampler(model)
94 |
95 | dataset = CondLDM(num_images_per_class = 1300 if args.split == 'train' else 50)
96 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size)
97 |
98 | generate_images(model, sampler, dataloader, args)
99 |
100 | if __name__ == "__main__":
101 | main()
--------------------------------------------------------------------------------
/generation/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import glob
4 | import torch
5 | import torchvision
6 | import requests
7 | import random
8 | import pickle
9 | import numpy as np
10 | import pandas as pd
11 | from io import BytesIO
12 | from PIL import Image, ImageFile
13 | from torchvision import transforms
14 | from collections import defaultdict
15 | from torchvision.datasets import VisionDataset
16 | from torch.utils.data import Dataset, DataLoader
17 |
18 |
19 | ImageFile.LOAD_TRUNCATED_IMAGES = True
20 |
21 | def preprocess_image(image):
22 | w, h = image.size
23 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
24 | image = image.resize((512, 512), resample=PIL.Image.LANCZOS)
25 | image = np.array(image).astype(np.float32) / 255.0
26 | image = image[None].transpose(0, 3, 1, 2)
27 | image = torch.from_numpy(image)
28 | return 2.0 * image - 1.0
29 |
30 | class CondLDM(Dataset):
31 |
32 | def __init__(self, num_images_per_class = 1300):
33 |
34 | im100_folders = ['n02869837', 'n01749939', 'n02488291', 'n02107142', 'n13037406', 'n02091831', 'n04517823', 'n04589890', 'n03062245', 'n01773797', 'n01735189', 'n07831146', 'n07753275', 'n03085013', 'n04485082', 'n02105505', 'n01983481', 'n02788148', 'n03530642', 'n04435653', 'n02086910', 'n02859443', 'n13040303', 'n03594734', 'n02085620', 'n02099849', 'n01558993', 'n04493381', 'n02109047', 'n04111531', 'n02877765', 'n04429376', 'n02009229', 'n01978455', 'n02106550', 'n01820546', 'n01692333', 'n07714571', 'n02974003', 'n02114855', 'n03785016', 'n03764736', 'n03775546', 'n02087046', 'n07836838', 'n04099969', 'n04592741', 'n03891251', 'n02701002', 'n03379051', 'n02259212', 'n07715103', 'n03947888', 'n04026417', 'n02326432', 'n03637318', 'n01980166', 'n02113799', 'n02086240', 'n03903868', 'n02483362', 'n04127249', 'n02089973', 'n03017168', 'n02093428', 'n02804414', 'n02396427', 'n04418357', 'n02172182', 'n01729322', 'n02113978', 'n03787032', 'n02089867', 'n02119022', 'n03777754', 'n04238763', 'n02231487', 'n03032252', 'n02138441', 'n02104029', 'n03837869', 'n03494278', 'n04136333', 'n03794056', 'n03492542', 'n02018207', 'n04067472', 'n03930630', 'n03584829', 'n02123045', 'n04229816', 'n02100583', 'n03642806', 'n04336792', 'n03259280', 'n02116738', 'n02108089', 'n03424325', 'n01855672', 'n02090622']
35 | im100_classes = [452, 64, 374, 236, 993, 176, 882, 904, 503, 74, 57, 959, 953, 508, 872, 228, 122, 421, 599, 858, 157, 449, 994, 608, 151, 209, 15, 876, 246, 766, 455, 857, 131, 119, 234, 90, 45, 936, 479, 272, 665, 653, 659, 158, 960, 765, 908, 703, 407, 560, 317, 938, 724, 748, 331, 619, 120, 267, 155, 708, 368, 772, 167, 494, 180, 431, 342, 854, 305, 54, 268, 667, 166, 277, 662, 798, 313, 498, 299, 222, 682, 593, 775, 674, 592, 137, 758, 717, 606, 281, 796, 211, 620, 830, 544, 275, 242, 570, 99, 169]
36 |
37 | self.folders = im100_folders * num_images_per_class
38 | self.classes = im100_classes * num_images_per_class
39 |
40 | def __len__(self):
41 | return len(self.classes)
42 |
43 | def __getitem__(self, idx):
44 | return idx, self.classes[idx], self.folders[idx]
45 |
46 | class PromptDataset(Dataset):
47 |
48 | def __init__(self, root, split = 'train', dataset = 'imagenet', diversity = True, i2i = False, transform = None):
49 |
50 | self.root = root
51 | self.split = split
52 | config = eval(open(os.path.join(self.root, 'classes.py'), 'r').read())
53 | self.diversity = diversity
54 | self.templates = config["templates"] if self.diversity else lambda x: f'a photo of a {x}'
55 | self.folders = os.listdir(os.path.join(self.root, self.split))
56 | if dataset == 'imagenet':
57 | df = pd.read_csv(os.path.join(self.root, 'folder_to_class.csv'))
58 | df = df[['folder', 'class']]
59 | self.folder_to_class = df.set_index('folder').T.to_dict('list')
60 | elif dataset == 'cifar10':
61 | # folder names in cifar are class names
62 | self.folder_to_class = {k:k for k in self.folders}
63 |
64 | self.images = []
65 | self.classes = []
66 | for folder in self.folders:
67 | class_images = os.listdir(os.path.join(self.root, split, folder))
68 | class_images = list(map(lambda x: os.path.join(split, folder, x), class_images))
69 | self.images = self.images + class_images
70 | self.classes = self.classes + ([self.folder_to_class[folder]] * len(class_images))
71 |
72 | self.i2i = i2i
73 | self.transform = transform
74 |
75 | def __len__(self):
76 | return len(self.images)
77 |
78 | def __getitem__(self, idx):
79 | if self.diversity:
80 | index = random.randint(0, len(self.templates) - 1)
81 | caption = self.templates[index](self.classes[idx][0])
82 | else:
83 | caption = self.templates(self.classes[idx][0])
84 |
85 | image_location = self.images[idx]
86 |
87 | if self.i2i:
88 | image = Image.open(os.path.join(self.root, image_location)).convert("RGB")
89 | return self.transform(image), image_location, caption, self.classes[idx][0]
90 | else:
91 | return image_location, caption, self.classes[idx][0]
92 |
93 |
94 | class ImageDataset(Dataset):
95 | def __init__(self, root, transform, split = 'train'):
96 |
97 | self.root = root
98 | self.transform = transform
99 | config = eval(open(os.path.join(self.root, 'classes.py'), 'r').read())
100 | self.templates = config["templates"]
101 | self.folders = os.listdir(os.path.join(self.root, split))
102 | self.images = []
103 | for folder in self.folders:
104 | class_images = os.listdir(os.path.join(self.root, split, folder))
105 | class_images = list(map(lambda x: os.path.join(split, folder, x), class_images))
106 | self.images = self.images + class_images
107 |
108 | def __len__(self):
109 | return len(self.images)
110 |
111 | def __getitem__(self, idx):
112 | image = self.transform(Image.open(os.path.join(self.root, self.images[idx])).convert('RGB'))
113 | return self.images[idx], image
--------------------------------------------------------------------------------
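
`PromptDataset` and `ImageDataset` above both `eval()` a `classes.py` placed in the data root and index `config["templates"]` with a class name. The actual `generation/classes.py` is not reproduced in this dump, so the snippet below is only a hypothetical minimal file showing the structure those loaders assume: a single dictionary expression whose `templates` entry is a list of one-argument callables mapping a class name to a caption.

```python
# Hypothetical minimal classes.py compatible with PromptDataset / ImageDataset.
# The real generation/classes.py (copied into imagenet_dir per the README) will
# contain many more templates; only the overall structure matters here.
{
    "templates": [
        lambda c: f"a photo of a {c}",
        lambda c: f"a close-up photo of a {c}",
        lambda c: f"a photo of a {c} in the wild",
    ],
}
```

For the `imagenet` dataset option, `PromptDataset` additionally reads `folder_to_class.csv` (columns `folder` and `class`), which already ships in `generation/` and `utils/mappings/`.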
/generation/generate_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import torch
4 | import argparse
5 | from tqdm import tqdm
6 | from PIL import Image
7 | from .data import ImageDataset
8 | from torchvision import transforms
9 | from accelerate import Accelerator
10 | from diffusers import StableDiffusionImageVariationPipeline
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument("--batch_size", type = int, default = 2)
15 | parser.add_argument("--split", type = str, default = "train", help = "Dataset split to generate from (e.g. train)")
16 | parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to the root of the real image dataset")
17 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path under which to save the generated images")
18 |
19 | args = parser.parse_args()
20 |
21 | accelerator = Accelerator()
22 | os.makedirs(args.save_image_gen, exist_ok = True)
23 |
24 | def generate_images(pipe, dataloader, args):
25 | pipe, dataloader = accelerator.prepare(pipe, dataloader)
26 | pipe = pipe.to(accelerator.device)
27 | filename = os.path.join(args.save_image_gen, 'images_variation.csv')
28 | with torch.no_grad():
29 | for image_locations, original_images in tqdm(dataloader):
30 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, image_locations[x])), range(len(image_locations))))
31 | if len(indices) == 0:
32 | continue
33 | original_images = original_images[indices]
34 | image_locations = [image_locations[i] for i in indices]
35 | images = pipe(original_images, guidance_scale = 3).images
36 | for index in range(len(images)):
37 | os.makedirs(os.path.join(args.save_image_gen, os.path.dirname(image_locations[index])), exist_ok = True)
38 | images[index].save(os.path.join(args.save_image_gen, image_locations[index]))
39 |
40 | def main():
41 | model_name_path = "lambdalabs/sd-image-variations-diffusers"
42 | pipe = StableDiffusionImageVariationPipeline.from_pretrained(model_name_path, revision = "v2.0")
43 |
44 | tform = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Resize(
47 | (224, 224),
48 | interpolation=transforms.InterpolationMode.BICUBIC,
49 | antialias=False,
50 | ),
51 | transforms.Normalize(
52 | [0.48145466, 0.4578275, 0.40821073],
53 | [0.26862954, 0.26130258, 0.27577711]),
54 | ])
55 |
56 | dataset = ImageDataset(args.data_dir, tform, split = args.split)
57 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = False)
58 | generate_images(pipe, dataloader, args)
59 |
60 |
61 | if __name__ == "__main__":
62 | main()
--------------------------------------------------------------------------------
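A single-image sketch mirroring the pipeline call in main()/generate_images() above; the model id, preprocessing constants, and guidance scale are taken from the script, while the file paths and the CUDA device are placeholders:

import torch
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionImageVariationPipeline

pipe = StableDiffusionImageVariationPipeline.from_pretrained(
    "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
).to("cuda")

tform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC, antialias=False),
    transforms.Normalize([0.48145466, 0.4578275, 0.40821073],
                         [0.26862954, 0.26130258, 0.27577711]),
])

image = tform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to("cuda")
with torch.no_grad():
    variation = pipe(image, guidance_scale=3).images[0]
variation.save("example_variation.png")
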
/generation/generate_images_captions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import torch
4 | import argparse
5 | import shutil
6 | import pandas as pd
7 | from tqdm import tqdm
8 | from PIL import Image
9 | from .data import PromptDataset
10 | from torchvision import transforms
11 | from accelerate import Accelerator
12 | from diffusers import StableDiffusionPipeline
13 |
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument("--batch_size", type = int, default = 2)
17 | parser.add_argument("--split", type = str, default = "train", help = "Dataset split to generate from (e.g. train)")
18 | parser.add_argument("--dataset", type = str, default = "imagenet", help = "dataset name")
19 | parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to the root of the real image dataset")
20 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path under which to save the generated images")
21 | parser.add_argument('--save_real', action='store_true', help='also copy the corresponding real images to <save_image_gen>/real')
22 | parser.add_argument('--diversity', action='store_true', help='sample captions from the diverse templates in classes.py instead of a single fixed template')
23 |
24 | args = parser.parse_args()
25 | accelerator = Accelerator()
26 | os.makedirs(args.save_image_gen, exist_ok = True)
27 |
28 | filename = os.path.join(args.save_image_gen, 'train_captions.csv')
29 | def generate_images(pipe, dataloader, args):
30 | pipe, dataloader = accelerator.prepare(pipe, dataloader)
31 | pipe = pipe.to(accelerator.device)
32 | with torch.no_grad():
33 | for image_locations, captions, labels in tqdm(dataloader):
34 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, image_locations[x])), range(len(image_locations))))
35 | if len(indices) == 0:
36 | continue
37 | image_locations = [image_locations[i] for i in indices]
38 | captions = [captions[i] for i in indices]
39 | labels = [labels[i] for i in indices]
40 | images = pipe(captions).images
41 | for index in range(len(images)):
42 | os.makedirs(os.path.join(args.save_image_gen, os.path.dirname(image_locations[index])), exist_ok = True)
43 | path = os.path.join(args.save_image_gen, image_locations[index])
44 | images[index].save(path)
45 | with open(filename, 'a') as csvfile:
46 | csvwriter = csv.writer(csvfile)
47 | csvwriter.writerow([path, labels[index], captions[index]])
48 | if args.save_real:
49 | real_img_path = os.path.join(args.save_image_gen, 'real')
50 | os.makedirs(real_img_path, exist_ok = True)
51 | shutil.copy(os.path.join(args.data_dir, image_locations[index]), real_img_path)
52 |
53 |
54 | def main():
55 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
56 | dataset = PromptDataset(args.data_dir, split = args.split, dataset = args.dataset, diversity = args.diversity)
57 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = False)
58 | generate_images(pipe, dataloader, args)
59 |
60 | if __name__ == "__main__":
61 | main()
62 |
63 |
--------------------------------------------------------------------------------
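The CSV appended by generate_images() above has no header row; each row is the generated image path, the class label, and the caption, so it can be read back as follows (the path is a placeholder):

import pandas as pd

df = pd.read_csv("<save_image_gen>/train_captions.csv",
                 names=["generated_path", "label", "caption"])
print(df.head())
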
/generation/generate_images_i2i.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import torch
4 | import shutil
5 | import argparse
6 | import numpy as np
7 | import pandas as pd
8 | from tqdm import tqdm
9 | from PIL import Image
10 | from .data import PromptDataset
11 | from torchvision import transforms
12 | from accelerate import Accelerator
13 | from diffusers import StableDiffusionImg2ImgPipeline
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument("--batch_size", type = int, default = 2)
18 | parser.add_argument("--split", type = str, default = "train", help = "Dataset split to generate from (e.g. train)")
19 | parser.add_argument("--dataset", type = str, default = "imagenet", help = "dataset name")
20 | parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to the root of the real image dataset")
21 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path under which to save the generated images")
22 | parser.add_argument("--save_real", type = str, default = None, help = "If set, also copy the corresponding real images to <save_image_gen>/real")
23 | parser.add_argument('--diversity', action='store_true', help='sample captions from the diverse templates in classes.py instead of a single fixed template')
24 |
25 | args = parser.parse_args()
26 | accelerator = Accelerator()
27 | os.makedirs(args.save_image_gen, exist_ok = True)
28 |
29 | def preprocess(image):
30 | image = image.resize((512, 512), resample=Image.LANCZOS)
31 | image = np.array(image).astype(np.uint8)
32 | image = (image / 127.5 - 1.0).astype(np.float32)
33 | image = torch.from_numpy(image).permute(2, 0, 1)
34 | return image
35 |
36 | def generate_images(pipe, dataloader, args):
37 | pipe, dataloader = accelerator.prepare(pipe, dataloader)
38 | pipe = pipe.to(accelerator.device)
39 | filename = os.path.join(args.save_image_gen, 'i2i.csv')
40 | with torch.no_grad():
41 | for original_images, image_locations, captions, labels in tqdm(dataloader):
42 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, image_locations[x])), range(len(image_locations))))
43 | if len(indices) == 0:
44 | continue
45 | original_images, image_locations, captions, labels = map(lambda x: [x[i] for i in indices], (original_images, image_locations, captions, labels))
46 | original_images = torch.stack(original_images).to(accelerator.device)
47 | images = pipe(prompt = captions, image = original_images, strength = 1).images
48 | for index in range(len(images)):
49 | os.makedirs(os.path.join(args.save_image_gen, os.path.dirname(image_locations[index])), exist_ok = True)
50 | path = os.path.join(args.save_image_gen, image_locations[index])
51 | images[index].save(path)
52 | with open(filename, 'a') as csvfile:
53 | csvwriter = csv.writer(csvfile)
54 | csvwriter.writerow([path, labels[index], captions[index]])
55 | if args.save_real:
56 | real_img_path = os.path.join(args.save_image_gen, 'real')
57 | os.makedirs(real_img_path, exist_ok = True)
58 | shutil.copy(os.path.join(args.data_dir, image_locations[index]), real_img_path)
59 |
60 |
61 | def main():
62 | model_id_or_path = "runwayml/stable-diffusion-v1-5"
63 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
64 | dataset = PromptDataset(args.data_dir, split = args.split, dataset = args.dataset, diversity = args.diversity, i2i = True, transform = preprocess)
65 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = False)
66 | generate_images(pipe, dataloader, args)
67 |
68 | if __name__ == "__main__":
69 | main()
--------------------------------------------------------------------------------
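A quick self-contained check of the value mapping performed by preprocess() above (the logic is restated inline rather than imported, since the script parses its CLI arguments at import time): a uint8 pixel p is mapped to p / 127.5 - 1, i.e. into [-1, 1], and the result is a CHW tensor as the img2img pipeline expects.

import numpy as np
import torch
from PIL import Image

image = Image.fromarray(np.full((64, 64, 3), 255, dtype=np.uint8))   # all-white test image
image = image.resize((512, 512), resample=Image.LANCZOS)
x = np.array(image).astype(np.float32) / 127.5 - 1.0                 # 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0
t = torch.from_numpy(x).permute(2, 0, 1)                             # HWC -> CHW
assert t.shape == (3, 512, 512) and abs(float(t.max()) - 1.0) < 1e-6
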
/generation/ldm/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/data/__init__.py
--------------------------------------------------------------------------------
/generation/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
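Txt2ImgIterableBaseDataset above only fixes the bookkeeping (num_records, valid_ids, size); subclasses supply __iter__. A toy, hypothetical subclass to show the contract (assumes generation/ is on PYTHONPATH; the record format is illustrative only):

from ldm.data.base import Txt2ImgIterableBaseDataset

class ToyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    def __iter__(self):
        # yield whatever record format the downstream training code expects;
        # here just a caption per record for illustration
        for i in range(self.num_records):
            yield {"caption": f"sample {i}"}

ds = ToyTxt2ImgDataset(num_records=4)          # prints "ToyTxt2ImgDataset dataset contains 4 examples."
print(len(ds), [r["caption"] for r in ds])
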
/generation/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 |         h, w = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
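A usage sketch for the LSUN datasets above, assuming the file lists and image folders exist under data/lsun/ exactly as hard-coded in the subclasses, and that generation/ is on PYTHONPATH:

from ldm.data.lsun import LSUNChurchesTrain

ds = LSUNChurchesTrain(size=256)
example = ds[0]
# center-cropped, resized, randomly flipped image as float32 HWC in [-1, 1]
print(example["image"].shape, example["image"].min(), example["image"].max())
print(example["relative_file_path_"])
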
/generation/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
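These classes return learning-rate values to be applied as multipliers; a minimal sketch of wiring one into torch's LambdaLR, following the docstring's note to use a base lr of 1.0 so the schedule's output is used directly (the module, step counts, and rates are placeholders; assumes generation/ is on PYTHONPATH):

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(8, 8)                                   # stand-in module
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)       # base_lr of 1.0, per the docstring

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1_000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=10_000
)
scheduler = LambdaLR(optimizer, lr_lambda=schedule)

for step in range(3):
    optimizer.step()
    scheduler.step()                                            # lr follows the linear warmup, then cosine decay
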
/generation/ldm/models/__pycache__/autoencoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/__pycache__/autoencoder.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/generation/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
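A standalone illustration of the compute_top_k helper above on toy logits (values chosen so that top-1 and top-2 accuracy differ; reduction="none" would instead return the per-sample hit vector):

import torch

logits = torch.tensor([[0.1, 2.0, 0.3],    # argmax = 1
                       [1.5, 0.2, 0.1]])   # argmax = 0
labels = torch.tensor([2, 0])

_, top1 = torch.topk(logits, k=1, dim=1)
acc_at_1 = (top1 == labels[:, None]).float().sum(dim=-1).mean().item()   # 0.5: only the second row is correct
_, top2 = torch.topk(logits, k=2, dim=1)
acc_at_2 = (top2 == labels[:, None]).float().sum(dim=-1).mean().item()   # 1.0: label 2 is in the first row's top-2
print(acc_at_1, acc_at_2)
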
/generation/ldm/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class DDIMSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | self.register_buffer('betas', to_torch(self.model.betas))
32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34 |
35 | # calculations for diffusion q(x_t | x_{t-1}) and others
36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41 |
42 | # ddim sampling parameters
43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44 | ddim_timesteps=self.ddim_timesteps,
45 | eta=ddim_eta,verbose=verbose)
46 | self.register_buffer('ddim_sigmas', ddim_sigmas)
47 | self.register_buffer('ddim_alphas', ddim_alphas)
48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54 |
55 | @torch.no_grad()
56 | def sample(self,
57 | S,
58 | batch_size,
59 | shape,
60 | conditioning=None,
61 | callback=None,
62 | normals_sequence=None,
63 | img_callback=None,
64 | quantize_x0=False,
65 | eta=0.,
66 | mask=None,
67 | x0=None,
68 | temperature=1.,
69 | noise_dropout=0.,
70 | score_corrector=None,
71 | corrector_kwargs=None,
72 | verbose=True,
73 | x_T=None,
74 | log_every_t=100,
75 | unconditional_guidance_scale=1.,
76 | unconditional_conditioning=None,
77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
78 | **kwargs
79 | ):
80 | if conditioning is not None:
81 | if isinstance(conditioning, dict):
82 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
83 | if cbs != batch_size:
84 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
85 | else:
86 | if conditioning.shape[0] != batch_size:
87 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
88 |
89 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
90 | # sampling
91 | C, H, W = shape
92 | size = (batch_size, C, H, W)
93 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
94 |
95 | samples, intermediates = self.ddim_sampling(conditioning, size,
96 | callback=callback,
97 | img_callback=img_callback,
98 | quantize_denoised=quantize_x0,
99 | mask=mask, x0=x0,
100 | ddim_use_original_steps=False,
101 | noise_dropout=noise_dropout,
102 | temperature=temperature,
103 | score_corrector=score_corrector,
104 | corrector_kwargs=corrector_kwargs,
105 | x_T=x_T,
106 | log_every_t=log_every_t,
107 | unconditional_guidance_scale=unconditional_guidance_scale,
108 | unconditional_conditioning=unconditional_conditioning,
109 | )
110 | return samples, intermediates
111 |
112 | @torch.no_grad()
113 | def ddim_sampling(self, cond, shape,
114 | x_T=None, ddim_use_original_steps=False,
115 | callback=None, timesteps=None, quantize_denoised=False,
116 | mask=None, x0=None, img_callback=None, log_every_t=100,
117 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
118 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
119 | device = self.model.betas.device
120 | b = shape[0]
121 | if x_T is None:
122 | img = torch.randn(shape, device=device)
123 | else:
124 | img = x_T
125 |
126 | if timesteps is None:
127 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
128 | elif timesteps is not None and not ddim_use_original_steps:
129 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
130 | timesteps = self.ddim_timesteps[:subset_end]
131 |
132 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
133 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
134 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
135 | print(f"Running DDIM Sampling with {total_steps} timesteps")
136 |
137 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
138 |
139 | for i, step in enumerate(iterator):
140 | index = total_steps - i - 1
141 | ts = torch.full((b,), step, device=device, dtype=torch.long)
142 |
143 | if mask is not None:
144 | assert x0 is not None
145 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
146 | img = img_orig * mask + (1. - mask) * img
147 |
148 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
149 | quantize_denoised=quantize_denoised, temperature=temperature,
150 | noise_dropout=noise_dropout, score_corrector=score_corrector,
151 | corrector_kwargs=corrector_kwargs,
152 | unconditional_guidance_scale=unconditional_guidance_scale,
153 | unconditional_conditioning=unconditional_conditioning)
154 | img, pred_x0 = outs
155 | if callback: callback(i)
156 | if img_callback: img_callback(pred_x0, i)
157 |
158 | if index % log_every_t == 0 or index == total_steps - 1:
159 | intermediates['x_inter'].append(img)
160 | intermediates['pred_x0'].append(pred_x0)
161 |
162 | return img, intermediates
163 |
164 | @torch.no_grad()
165 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
166 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
167 | unconditional_guidance_scale=1., unconditional_conditioning=None):
168 | b, *_, device = *x.shape, x.device
169 |
170 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
171 | e_t = self.model.apply_model(x, t, c)
172 | else:
173 | x_in = torch.cat([x] * 2)
174 | t_in = torch.cat([t] * 2)
175 | c_in = torch.cat([unconditional_conditioning, c])
176 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
177 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
178 |
179 | if score_corrector is not None:
180 | assert self.model.parameterization == "eps"
181 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
182 |
183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187 | # select parameters corresponding to the currently considered timestep
188 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
189 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
190 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
191 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
192 |
193 | # current prediction for x_0
194 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
195 | if quantize_denoised:
196 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
197 | # direction pointing to x_t
198 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200 | if noise_dropout > 0.:
201 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203 | return x_prev, pred_x0
204 |
--------------------------------------------------------------------------------
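For reference, the update implemented in p_sample_ddim above (with temperature 1 and noise_dropout 0, and with \(a_t\), \(a_{\text{prev}}\), \(\sigma_t\) denoting alphas[index], alphas_prev[index], sigmas[index], i.e. cumulative-product alphas on the DDIM sub-schedule) is the standard DDIM step:

\[ \hat{x}_0 = \frac{x_t - \sqrt{1 - a_t}\,\epsilon_\theta(x_t, t, c)}{\sqrt{a_t}}, \qquad x_{\text{prev}} = \sqrt{a_{\text{prev}}}\,\hat{x}_0 + \sqrt{1 - a_{\text{prev}} - \sigma_t^2}\;\epsilon_\theta(x_t, t, c) + \sigma_t z, \quad z \sim \mathcal{N}(0, I), \]

with classifier-free guidance applied first as \(\epsilon_\theta = \epsilon_{\text{uncond}} + s\,(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})\) whenever unconditional_conditioning is given.
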
/generation/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class PLMSSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | if ddim_eta != 0:
26 | raise ValueError('ddim_eta must be 0 for PLMS')
27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29 | alphas_cumprod = self.model.alphas_cumprod
30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32 |
33 | self.register_buffer('betas', to_torch(self.model.betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43 |
44 | # ddim sampling parameters
45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46 | ddim_timesteps=self.ddim_timesteps,
47 | eta=ddim_eta,verbose=verbose)
48 | self.register_buffer('ddim_sigmas', ddim_sigmas)
49 | self.register_buffer('ddim_alphas', ddim_alphas)
50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56 |
57 | @torch.no_grad()
58 | def sample(self,
59 | S,
60 | batch_size,
61 | shape,
62 | conditioning=None,
63 | callback=None,
64 | normals_sequence=None,
65 | img_callback=None,
66 | quantize_x0=False,
67 | eta=0.,
68 | mask=None,
69 | x0=None,
70 | temperature=1.,
71 | noise_dropout=0.,
72 | score_corrector=None,
73 | corrector_kwargs=None,
74 | verbose=True,
75 | x_T=None,
76 | log_every_t=100,
77 | unconditional_guidance_scale=1.,
78 | unconditional_conditioning=None,
79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80 | **kwargs
81 | ):
82 | if conditioning is not None:
83 | if isinstance(conditioning, dict):
84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85 | if cbs != batch_size:
86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87 | else:
88 | if conditioning.shape[0] != batch_size:
89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90 |
91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92 | # sampling
93 | C, H, W = shape
94 | size = (batch_size, C, H, W)
95 | print(f'Data shape for PLMS sampling is {size}')
96 |
97 | samples, intermediates = self.plms_sampling(conditioning, size,
98 | callback=callback,
99 | img_callback=img_callback,
100 | quantize_denoised=quantize_x0,
101 | mask=mask, x0=x0,
102 | ddim_use_original_steps=False,
103 | noise_dropout=noise_dropout,
104 | temperature=temperature,
105 | score_corrector=score_corrector,
106 | corrector_kwargs=corrector_kwargs,
107 | x_T=x_T,
108 | log_every_t=log_every_t,
109 | unconditional_guidance_scale=unconditional_guidance_scale,
110 | unconditional_conditioning=unconditional_conditioning,
111 | )
112 | return samples, intermediates
113 |
114 | @torch.no_grad()
115 | def plms_sampling(self, cond, shape,
116 | x_T=None, ddim_use_original_steps=False,
117 | callback=None, timesteps=None, quantize_denoised=False,
118 | mask=None, x0=None, img_callback=None, log_every_t=100,
119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
121 | device = self.model.betas.device
122 | b = shape[0]
123 | if x_T is None:
124 | img = torch.randn(shape, device=device)
125 | else:
126 | img = x_T
127 |
128 | if timesteps is None:
129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130 | elif timesteps is not None and not ddim_use_original_steps:
131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132 | timesteps = self.ddim_timesteps[:subset_end]
133 |
134 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137 | print(f"Running PLMS Sampling with {total_steps} timesteps")
138 |
139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140 | old_eps = []
141 |
142 | for i, step in enumerate(iterator):
143 | index = total_steps - i - 1
144 | ts = torch.full((b,), step, device=device, dtype=torch.long)
145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146 |
147 | if mask is not None:
148 | assert x0 is not None
149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150 | img = img_orig * mask + (1. - mask) * img
151 |
152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153 | quantize_denoised=quantize_denoised, temperature=temperature,
154 | noise_dropout=noise_dropout, score_corrector=score_corrector,
155 | corrector_kwargs=corrector_kwargs,
156 | unconditional_guidance_scale=unconditional_guidance_scale,
157 | unconditional_conditioning=unconditional_conditioning,
158 | old_eps=old_eps, t_next=ts_next)
159 | img, pred_x0, e_t = outs
160 | old_eps.append(e_t)
161 | if len(old_eps) >= 4:
162 | old_eps.pop(0)
163 | if callback: callback(i)
164 | if img_callback: img_callback(pred_x0, i)
165 |
166 | if index % log_every_t == 0 or index == total_steps - 1:
167 | intermediates['x_inter'].append(img)
168 | intermediates['pred_x0'].append(pred_x0)
169 |
170 | return img, intermediates
171 |
172 | @torch.no_grad()
173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176 | b, *_, device = *x.shape, x.device
177 |
178 | def get_model_output(x, t):
179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180 | e_t = self.model.apply_model(x, t, c)
181 | else:
182 | x_in = torch.cat([x] * 2)
183 | t_in = torch.cat([t] * 2)
184 | c_in = torch.cat([unconditional_conditioning, c])
185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187 |
188 | if score_corrector is not None:
189 | assert self.model.parameterization == "eps"
190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191 |
192 | return e_t
193 |
194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198 |
199 | def get_x_prev_and_pred_x0(e_t, index):
200 | # select parameters corresponding to the currently considered timestep
201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205 |
206 | # current prediction for x_0
207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208 | if quantize_denoised:
209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210 | # direction pointing to x_t
211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213 | if noise_dropout > 0.:
214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216 | return x_prev, pred_x0
217 |
218 | e_t = get_model_output(x, t)
219 | if len(old_eps) == 0:
220 | # Pseudo Improved Euler (2nd order)
221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222 | e_t_next = get_model_output(x_prev, t_next)
223 | e_t_prime = (e_t + e_t_next) / 2
224 | elif len(old_eps) == 1:
225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
227 | elif len(old_eps) == 2:
228 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230 | elif len(old_eps) >= 3:
231 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233 |
234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235 |
236 | return x_prev, pred_x0, e_t
237 |
--------------------------------------------------------------------------------
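The multistep combination in p_sample_plms above replaces \(\epsilon_t\) with a pseudo linear multistep (Adams-Bashforth) estimate before applying the same pred_x0/x_prev update as in ddim.py:

\[ \epsilon_t' = \tfrac{1}{2}\,(3\epsilon_t - \epsilon_{t-1}), \qquad \epsilon_t' = \tfrac{1}{12}\,(23\epsilon_t - 16\epsilon_{t-1} + 5\epsilon_{t-2}), \qquad \epsilon_t' = \tfrac{1}{24}\,(55\epsilon_t - 59\epsilon_{t-1} + 37\epsilon_{t-2} - 9\epsilon_{t-3}) \]

for the 2nd-, 3rd-, and 4th-order cases respectively; on the very first step, where no history exists, it falls back to the pseudo improved Euler average \(\epsilon_t' = \tfrac{1}{2}(\epsilon_t + \epsilon_{t_{\text{next}}})\), using one extra model evaluation at the provisional x_prev.
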
/generation/ldm/modules/__pycache__/attention.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/__pycache__/attention.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/__pycache__/ema.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/__pycache__/ema.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/__pycache__/x_transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/__pycache__/x_transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
--------------------------------------------------------------------------------
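
A minimal shape-check sketch for the attention blocks above (not part of the repository). It assumes the repo root is on PYTHONPATH so that `ldm.modules.attention` resolves to the file shown here; the tensor sizes are arbitrary illustrations.

import torch
from ldm.modules.attention import CrossAttention, SpatialTransformer

x = torch.randn(2, 64, 16, 16)        # image-like features: (b, c, h, w)
ctx = torch.randn(2, 77, 512)         # e.g. a batch of text-token embeddings

# CrossAttention works on token sequences (b, n, d); without `context` it
# degenerates to self-attention (the context defaults to the queries).
attn = CrossAttention(query_dim=64, context_dim=512, heads=4, dim_head=16)
tokens = x.flatten(2).transpose(1, 2)           # (2, 256, 64)
print(attn(tokens, context=ctx).shape)          # torch.Size([2, 256, 64])

# SpatialTransformer projects (b, c, h, w) to tokens, runs the transformer
# blocks, and reshapes back, adding a residual connection to the input.
st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=512)
print(st(x, context=ctx).shape)                 # torch.Size([2, 64, 16, 16])
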
/generation/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/generation/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
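
A short usage sketch (not repository code) for the schedule and embedding helpers above; the step counts and dimensions are illustrative, and the repo root is assumed to be importable.

import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, make_ddim_timesteps, timestep_embedding,
)

betas = make_beta_schedule("linear", n_timestep=1000)          # numpy array of shape (1000,)
alphas_cumprod = (1.0 - torch.from_numpy(betas)).cumprod(dim=0)
print(betas[:3], float(alphas_cumprod[-1]))                    # cumulative signal level at t=999

# Subsample 50 DDIM steps out of 1000 DDPM steps (uniform spacing, +1 offset).
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
print(ddim_steps[:5])                                          # [ 1 21 41 61 81]

# Sinusoidal embeddings for a batch of three timesteps.
emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=128)
print(emb.shape)                                               # torch.Size([3, 128])
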
/generation/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/generation/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
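
An illustrative sketch (not from the repository): in the LDM autoencoder the encoder emits a [B, 2*C, H, W] tensor of means and log-variances, which `DiagonalGaussianDistribution` splits, samples from, and scores. The shapes below are arbitrary.

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl

params = torch.randn(2, 8, 4, 4)        # 4 latent channels: mean and logvar stacked on dim 1
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                  # reparameterized sample, shape (2, 4, 4, 4)
print(z.shape, posterior.kl().shape)    # KL to a standard normal, one value per batch element

# normal_kl broadcasts, so a batch can be compared against scalar parameters.
print(normal_kl(torch.zeros(2, 4), torch.zeros(2, 4), 0.0, 0.0).sum())   # tensor(0.)
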
/generation/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 |                 # remove as the '.'-character is not allowed in buffer names
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 |                     assert key not in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 |                 assert key not in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
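
A hedged usage sketch for `LitEma` above (the model, optimizer, and decay are placeholders, not this repository's training configuration): update the shadow weights after each optimizer step, swap them in for evaluation, then restore.

import torch
from ldm.modules.ema import LitEma

model = torch.nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                        # EMA update of the shadow buffers

ema.store(model.parameters())         # stash the raw training weights
ema.copy_to(model)                    # evaluate/checkpoint with EMA weights ...
ema.restore(model.parameters())       # ... then restore the training weights
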
/generation/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/generation/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc
--------------------------------------------------------------------------------
/generation/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | import kornia
7 |
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 | """Uses the BERT tokenizr model and add some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 |
138 | class FrozenCLIPTextEmbedder(nn.Module):
139 | """
140 | Uses the CLIP transformer encoder for text.
141 | """
142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
143 | super().__init__()
144 | self.model, _ = clip.load(version, jit=False, device="cpu")
145 | self.device = device
146 | self.max_length = max_length
147 | self.n_repeat = n_repeat
148 | self.normalize = normalize
149 |
150 | def freeze(self):
151 | self.model = self.model.eval()
152 | for param in self.parameters():
153 | param.requires_grad = False
154 |
155 | def forward(self, text):
156 | tokens = clip.tokenize(text).to(self.device)
157 | z = self.model.encode_text(tokens)
158 | if self.normalize:
159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
160 | return z
161 |
162 | def encode(self, text):
163 | z = self(text)
164 | if z.ndim==2:
165 | z = z[:, None, :]
166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
167 | return z
168 |
169 |
170 | class FrozenClipImageEmbedder(nn.Module):
171 | """
172 | Uses the CLIP image encoder.
173 | """
174 | def __init__(
175 | self,
176 | model,
177 | jit=False,
178 | device='cuda' if torch.cuda.is_available() else 'cpu',
179 | antialias=False,
180 | ):
181 | super().__init__()
182 | self.model, _ = clip.load(name=model, device=device, jit=jit)
183 |
184 | self.antialias = antialias
185 |
186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
188 |
189 | def preprocess(self, x):
190 | # normalize to [0,1]
191 | x = kornia.geometry.resize(x, (224, 224),
192 | interpolation='bicubic',align_corners=True,
193 | antialias=self.antialias)
194 | x = (x + 1.) / 2.
195 | # renormalize according to clip
196 | x = kornia.enhance.normalize(x, self.mean, self.std)
197 | return x
198 |
199 | def forward(self, x):
200 | # x is assumed to be in range [-1,1]
201 | return self.model.encode_image(self.preprocess(x))
202 |
203 |
--------------------------------------------------------------------------------
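
A minimal sketch (not repository code) exercising `SpatialRescaler` above; importing this module requires the repo's `clip` and `kornia` dependencies to be installed, and the CLIP/BERT embedders additionally need pretrained weights, so they are omitted here.

import torch
from ldm.modules.encoders.modules import SpatialRescaler

# Two bilinear 0.5x stages followed by a 1x1 conv that remaps 3 -> 8 channels.
rescaler = SpatialRescaler(n_stages=2, method='bilinear', multiplier=0.5,
                           in_channels=3, out_channels=8)
x = torch.randn(1, 3, 64, 64)
print(rescaler(x).shape)              # torch.Size([1, 8, 16, 16])
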
/generation/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/generation/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/generation/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/generation/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
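
A standalone illustration (not repository code) of the adaptive discriminator weight computed in `calculate_adaptive_weight` above: the generator loss is rescaled so that its gradient norm at the decoder's last layer matches that of the reconstruction (NLL) loss. The toy tensors below stand in for the real decoder.

import torch

last_layer = torch.nn.Parameter(torch.randn(8, 8))    # stands in for the decoder's last weight
fake = torch.randn(8, 8) @ last_layer                  # stands in for a reconstruction

nll_loss = (fake - 1.0).abs().mean()                   # reconstruction-style loss
g_loss = -fake.mean()                                  # non-saturating generator loss

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()    # same clamp as the class above
print(float(d_weight))
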
/generation/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 | from ldm.util import exists  # used by the codebook_loss check in forward()
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if not exists(codebook_loss):
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
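
A quick sketch (not from the repository) for two of the helpers above; it assumes the `taming-transformers` dependency is installed, since this module imports from it at load time. With uniformly used codes the perplexity approaches `n_embed`, and `adopt_weight` zeroes a loss term until `global_step` reaches the threshold.

import torch
from ldm.modules.losses.vqperceptual import measure_perplexity, adopt_weight

indices = torch.randint(0, 16, (4096,))               # fake codebook assignments
perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
print(float(perplexity), int(cluster_use))            # roughly 16.0, 16

print(adopt_weight(1.0, global_step=100, threshold=500),    # 0.0 (before the threshold)
      adopt_weight(1.0, global_step=1000, threshold=500))   # 1.0 (after the threshold)
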
/generation/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
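
A hedged sketch (not repository code) of the config-driven instantiation used throughout the LDM configs: `target` is a dotted import path and `params` are keyword arguments. The Conv2d example below is purely illustrative.

from ldm.util import instantiate_from_config, count_params

config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
layer = instantiate_from_config(config)   # equivalent to torch.nn.Conv2d(3, 8, 3)
count_params(layer, verbose=True)         # prints the parameter count in millions
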
/sd_finetune/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import glob
4 | import torch
5 | import torchvision
6 | import requests
7 | import random
8 | import pickle
9 | import numpy as np
10 | import pandas as pd
11 | from io import BytesIO
12 | from PIL import Image, ImageFile
13 | from torchvision import transforms
14 | from collections import defaultdict
15 | from torchvision.datasets import VisionDataset
16 | from torch.utils.data import Dataset, DataLoader
17 |
18 |
19 | ImageFile.LOAD_TRUNCATED_IMAGES = True
20 |
21 |
22 |
23 | class ImageCaptionDataset(Dataset):
24 |
25 | def __init__(self, filename, tokenizer, image_transform, caption_key = 'caption', image_key = 'image'):
26 |
27 | df = pd.read_csv(filename)
28 |
29 | self.transform = image_transform
30 | self.captions = df[caption_key].tolist()
31 | self.tokenizer = tokenizer
32 | self.images = df[image_key].tolist()
33 |
34 |         self.t_captions = tokenizer(self.captions, max_length = tokenizer.model_max_length, padding = 'max_length', truncation = True, return_tensors = 'pt')
35 |
36 | def __len__(self):
37 | return len(self.images)
38 |
39 | def __getitem__(self, index):
40 |
41 | item = {}
42 | item['image_location'] = self.images[index]
43 | item['caption'] = self.captions[index]
44 | item['input_ids'] = self.t_captions['input_ids'][index]
45 | item['pixel_values'] = self.transform(Image.open(self.images[index]).convert('RGB'))
46 |
47 | return item
--------------------------------------------------------------------------------
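
An illustrative sketch (not from the repository): `ImageCaptionDataset` expects a CSV with one image path and one caption per row, and tokenizes all captions once up front. The tokenizer, file name, and transform below are assumptions, not the authors' exact fine-tuning setup.

from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer

from data import ImageCaptionDataset     # this file; run from inside sd_finetune/

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
image_transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = ImageCaptionDataset("captions.csv", tokenizer, image_transform,
                              caption_key="caption", image_key="image")
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
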
/sd_finetune/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | def parse_args():
5 |
6 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
7 | parser.add_argument(
8 | "--pretrained_model_name_or_path",
9 | type=str,
10 | default="runwayml/stable-diffusion-v1-5",
11 | help="Path to pretrained model or model identifier from huggingface.co/models.",
12 | )
13 | parser.add_argument(
14 | "--revision",
15 | type=str,
16 | default=None,
17 | required=False,
18 | help="Revision of pretrained model identifier from huggingface.co/models.",
19 | )
20 | parser.add_argument(
21 | "--dataset_name",
22 | type=str,
23 | default=None,
24 | help=(
25 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
26 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
27 | " or to a folder containing files that 🤗 Datasets can understand."
28 | ),
29 | )
30 | parser.add_argument(
31 | "--dataset_config_name",
32 | type=str,
33 | default=None,
34 | help="The config of the Dataset, leave as None if there's only one config.",
35 | )
36 | parser.add_argument(
37 | "--train_data_file",
38 | type=str,
39 | default=None
40 | )
41 | parser.add_argument(
42 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
43 | )
44 | parser.add_argument(
45 | "--caption_column",
46 | type=str,
47 | default="text",
48 | help="The column of the dataset containing a caption or a list of captions.",
49 | )
50 | parser.add_argument(
51 | "--max_train_samples",
52 | type=int,
53 | default=None,
54 | help=(
55 | "For debugging purposes or quicker training, truncate the number of training examples to this "
56 | "value if set."
57 | ),
58 | )
59 | parser.add_argument(
60 | "--output_dir",
61 | type=str,
62 | default="sd-model-finetuned",
63 | help="The output directory where the model predictions and checkpoints will be written.",
64 | )
65 | parser.add_argument(
66 | "--cache_dir",
67 | type=str,
68 | default=None,
69 | help="The directory where the downloaded models and datasets will be stored.",
70 | )
71 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
72 | parser.add_argument(
73 | "--resolution",
74 | type=int,
75 | default=512,
76 | help=(
77 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
78 | " resolution"
79 | ),
80 | )
81 | parser.add_argument(
82 | "--center_crop",
83 | default=False,
84 | action="store_true",
85 | help=(
86 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
87 | " cropped. The images will be resized to the resolution first before cropping."
88 | ),
89 | )
90 | parser.add_argument(
91 | "--random_flip",
92 | action="store_true",
93 | help="whether to randomly flip images horizontally",
94 | )
95 | parser.add_argument(
96 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
97 | )
98 | parser.add_argument("--num_train_epochs", type=int, default=100)
99 | parser.add_argument(
100 | "--max_train_steps",
101 | type=int,
102 | default=None,
103 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
104 | )
105 | parser.add_argument(
106 | "--gradient_accumulation_steps",
107 | type=int,
108 | default=1,
109 | help="Number of updates steps to accumulate before performing a backward/update pass.",
110 | )
111 | parser.add_argument(
112 | "--gradient_checkpointing",
113 | action="store_true",
114 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
115 | )
116 | parser.add_argument(
117 | "--learning_rate",
118 | type=float,
119 | default=1e-4,
120 | help="Initial learning rate (after the potential warmup period) to use.",
121 | )
122 | parser.add_argument(
123 | "--scale_lr",
124 | action="store_true",
125 | default=False,
126 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
127 | )
128 | parser.add_argument(
129 | "--lr_scheduler",
130 | type=str,
131 | default="constant",
132 | help=(
133 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
134 | ' "constant", "constant_with_warmup"]'
135 | ),
136 | )
137 | parser.add_argument(
138 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
139 | )
140 | parser.add_argument(
141 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
142 | )
143 | parser.add_argument(
144 | "--allow_tf32",
145 | action="store_true",
146 | help=(
147 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
148 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
149 | ),
150 | )
151 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
152 | parser.add_argument(
153 | "--non_ema_revision",
154 | type=str,
155 | default=None,
156 | required=False,
157 | help=(
158 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
159 | " remote repository specified with --pretrained_model_name_or_path."
160 | ),
161 | )
162 | parser.add_argument(
163 | "--dataloader_num_workers",
164 | type=int,
165 | default=0,
166 | help=(
167 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
168 | ),
169 | )
170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
176 |
177 | parser.add_argument("--freeze_unet", action="store_true", help="Whether or not to freeze the unet.")
178 |
179 | parser.add_argument("--wandb", action="store_true", help="Whether or not to log using wandb")
180 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
181 | parser.add_argument(
182 | "--hub_model_id",
183 | type=str,
184 | default=None,
185 | help="The name of the repository to keep in sync with the local `output_dir`.",
186 | )
187 | parser.add_argument(
188 | "--logging_dir",
189 | type=str,
190 | default="logs",
191 | help=(
192 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
193 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
194 | ),
195 | )
196 | parser.add_argument(
197 | "--mixed_precision",
198 | type=str,
199 | default=None,
200 | choices=["no", "fp16", "bf16"],
201 | help=(
202 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
203 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
204 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
205 | ),
206 | )
207 | parser.add_argument(
208 | "--report_to",
209 | type=str,
210 | default=None,
211 | help=(
212 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
213 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
214 | ),
215 | )
216 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
217 | parser.add_argument(
218 | "--checkpointing_steps",
219 | type=int,
220 | default=500,
221 | help=(
222 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
223 | " training using `--resume_from_checkpoint`."
224 | ),
225 | )
226 | parser.add_argument(
227 | "--checkpoints_total_limit",
228 | type=int,
229 | default=None,
230 | help=(
231 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
232 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
233 | " for more docs"
234 | ),
235 | )
236 | parser.add_argument(
237 | "--resume_from_checkpoint",
238 | type=str,
239 | default=None,
240 | help=(
241 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
242 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
243 | ),
244 | )
245 | parser.add_argument(
246 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
247 | )
248 |
249 | args = parser.parse_args()
250 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
251 | if env_local_rank != -1 and env_local_rank != args.local_rank:
252 | args.local_rank = env_local_rank
253 |
254 | # # Sanity checks
255 | # if args.dataset_name is None and args.train_data_dir is None:
256 | # raise ValueError("Need either a dataset name or a training folder.")
257 |
258 | # default to using the same revision for the non-ema model if not specified
259 | if args.non_ema_revision is None:
260 | args.non_ema_revision = args.revision
261 |
262 | return args
--------------------------------------------------------------------------------
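Note on --scale_lr: the flag's help says the learning rate is scaled by the number of GPUs, the gradient accumulation steps, and the per-device batch size. Below is a minimal, hedged sketch of that scaling (not the repo's own training code); `num_processes` is a stand-in for the GPU/process count that the accelerate launcher would report.

# Hedged sketch: how --scale_lr is typically applied after parse_args().
# `num_processes` is an assumed stand-in for the process count from accelerate.
def effective_learning_rate(args, num_processes: int) -> float:
    lr = args.learning_rate
    if args.scale_lr:
        # Scale by accumulation steps, per-device batch size, and process count.
        lr *= args.gradient_accumulation_steps * args.train_batch_size * num_processes
    return lr

# Example with the defaults above (lr=1e-4, accumulation=1, batch=16) on 2 GPUs:
# 1e-4 * 1 * 16 * 2 = 3.2e-3.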
/train/resnet_configs/alexnet.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: alexnet
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 128
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 56
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
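The lr and resolution blocks in these configs follow the style of the FFCV-ImageNet recipes this repo borrows from: a triangular ("cyclic") learning-rate schedule that peaks at lr_peak_epoch, and a progressive ramp of the training resolution between start_ramp and end_ramp. The sketch below is a hedged reading of those fields (function names are illustrative; /train/train_imagenet.py is authoritative). Note that with epochs: 56 and start_ramp: 65 the ramp never activates, so under this reading training stays at min_res throughout.

import numpy as np

# Illustrative interpretation of the lr/resolution blocks in these YAML configs
# (FFCV-ImageNet style); the repo's train_imagenet.py may differ in detail.

def cyclic_lr(epoch, lr=0.5, lr_peak_epoch=2, epochs=56):
    # Triangular schedule: 0 -> lr over lr_peak_epoch epochs, then lr -> 0 at `epochs`.
    return float(np.interp(epoch, [0, lr_peak_epoch, epochs], [0.0, lr, 0.0]))

def train_resolution(epoch, min_res=160, max_res=192, start_ramp=65, end_ramp=76):
    # Train at min_res early, ramp linearly to max_res between start_ramp and end_ramp,
    # snapping the result to a multiple of 32.
    if epoch <= start_ramp:
        return min_res
    if epoch >= end_ramp:
        return max_res
    interp = np.interp(epoch, [start_ramp, end_ramp], [min_res, max_res])
    return int(np.round(interp / 32) * 32)

print(cyclic_lr(0), cyclic_lr(2), cyclic_lr(56))  # 0.0 0.5 0.0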
/train/resnet_configs/efficient.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: efficientnet_b0
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 128
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 56
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/resnet_configs/inception.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: inception
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 128
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 40
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/resnet_configs/mobilenet.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: mobilenet_v2
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 128
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 40
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/resnet_configs/resnext101.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: resnext101_32x8d
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 64
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 56
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/resnet_configs/resnext50.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 3
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: resnext50_32x4d
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 128
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 56
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/resnet_configs/rn18_88_epochs.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: resnet18
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 128
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 56
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/resnet_configs/vgg16.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | gpu: 0
3 | in_memory: 1
4 | num_workers: 12
5 | dist:
6 | world_size: 2
7 | logging:
8 | folder: /tmp/
9 | log_level: 0
10 | lr:
11 | lr: 0.5
12 | lr_peak_epoch: 2
13 | lr_schedule_type: cyclic
14 | model:
15 | arch: vgg16_bn
16 | resolution:
17 | end_ramp: 76
18 | max_res: 192
19 | min_res: 160
20 | start_ramp: 65
21 | training:
22 | batch_size: 64
23 | bn_wd: 0
24 | distributed: 0
25 | epochs: 40
26 | label_smoothing: 0.1
27 | momentum: 0.9
28 | optimizer: sgd
29 | weight_decay: 5e-5
30 | use_blurpool: 1
31 | validation:
32 | lr_tta: true
33 | resolution: 256
--------------------------------------------------------------------------------
/train/write_imagenet.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code is directly taken from FFCV-Imagenet https://github.com/libffcv/ffcv-imagenet
3 | '''
4 | from torch.utils.data import Subset, ConcatDataset
5 | from ffcv.writer import DatasetWriter
6 | from ffcv.fields import IntField, RGBImageField
7 | from torchvision.datasets import CIFAR10, ImageFolder
8 |
9 | from argparse import ArgumentParser
10 | from fastargs import Section, Param
11 | from fastargs.validation import And, OneOf
12 | from fastargs.decorators import param, section
13 | from fastargs import get_current_config
14 |
15 | # from customdataset import CustomizeDataset  # unused; module not included in this repo
16 |
17 | Section('cfg', 'arguments to give the writer').params(
18 | dataset=Param(And(str, OneOf(['cifar', 'imagenet', 'imagenet_aug'])), 'Which dataset to write', default='imagenet'),
19 | split=Param(And(str, OneOf(['train', 'val'])), 'Train or val set', required=True),
20 | data_dir=Param(str, 'Where to find the PyTorch dataset', required=True),
21 | write_path=Param(str, 'Where to write the new dataset', required=True),
22 | write_mode=Param(str, 'Mode: raw, smart, proportion or jpg', required=False, default='smart'),
23 | max_resolution=Param(int, 'Max image side length', required=True),
24 | num_workers=Param(int, 'Number of workers to use', default=16),
25 | chunk_size=Param(int, 'Chunk size for writing', default=100),
26 | jpeg_quality=Param(float, 'Quality of jpeg images', default=90),
27 | subset=Param(int, 'How many images to use (-1 for all)', default=-1),
28 | compress_probability=Param(float, 'compress probability', default=None)
29 | )
30 |
31 | @section('cfg')
32 | @param('dataset')
33 | @param('split')
34 | @param('data_dir')
35 | @param('write_path')
36 | @param('max_resolution')
37 | @param('num_workers')
38 | @param('chunk_size')
39 | @param('subset')
40 | @param('jpeg_quality')
41 | @param('write_mode')
42 | @param('compress_probability')
43 | def main(dataset, split, data_dir, write_path, max_resolution, num_workers,
44 | chunk_size, subset, jpeg_quality, write_mode,
45 | compress_probability):
46 | if dataset == 'cifar':
47 | my_dataset = CIFAR10(root=data_dir, train=(split == 'train'), download=True)
48 | elif dataset == 'imagenet':
49 | my_dataset = ImageFolder(root=data_dir)
50 | elif dataset == 'imagenet_aug':
51 | my_dataset_1 = ImageFolder(root=data_dir)
52 | data_dir_2 = "path/to/generated/imagenet/train/or/val"
53 | my_dataset_2 = ImageFolder(root=data_dir_2)
54 | my_dataset = ConcatDataset([my_dataset_1, my_dataset_2])
55 | else:
56 | raise ValueError('Unrecognized dataset', dataset)
57 |
58 | if subset > 0: my_dataset = Subset(my_dataset, range(subset))
59 | writer = DatasetWriter(write_path, {
60 | 'image': RGBImageField(write_mode=write_mode,
61 | max_resolution=max_resolution,
62 | compress_probability=compress_probability,
63 | jpeg_quality=jpeg_quality),
64 | 'label': IntField(),
65 | }, num_workers=num_workers)
66 |
67 | writer.from_indexed_dataset(my_dataset, chunksize=chunk_size)
68 |
69 | if __name__ == '__main__':
70 | config = get_current_config()
71 | parser = ArgumentParser()
72 | config.augment_argparse(parser)
73 | config.collect_argparse_args(parser)
74 | config.validate(mode='stderr')
75 | config.summary()
76 | main()
77 |
--------------------------------------------------------------------------------
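The .ffcv file written above is meant to be read back with FFCV during training. Below is a minimal, hedged loading sketch; the file path is a placeholder and the decoder/transform pipeline is illustrative, while the repo's /train/train_imagenet.py presumably defines its own pipeline (normalization, device placement, distributed ordering, etc.).

from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
from ffcv.transforms import ToTensor, Squeeze

# Placeholder path: the .ffcv produced by write_imagenet.sh ($WRITE_DIR/train_*.ffcv).
loader = Loader(
    '/path/to/train.ffcv',
    batch_size=128,
    num_workers=12,
    order=OrderOption.RANDOM,
    pipelines={
        'image': [RandomResizedCropRGBImageDecoder((176, 176)), ToTensor()],
        'label': [IntDecoder(), ToTensor(), Squeeze()],
    },
)

for images, labels in loader:
    pass  # a training step would consume the batch here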
/train/write_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This code is directly taken from FFCV-Imagenet https://github.com/libffcv/ffcv-imagenet
4 |
5 | write_dataset () {
6 | write_path=$WRITE_DIR/${1}_${2}_${3}_${4}.ffcv
7 | echo "Writing ImageNet ${1} dataset to ${write_path}"
8 | python write_imagenet.py \
9 | --cfg.dataset=imagenet \
10 | --cfg.split=${1} \
11 | --cfg.data_dir=$IMAGENET_DIR/${1} \
12 | --cfg.write_path=$write_path \
13 | --cfg.max_resolution=${2} \
14 | --cfg.write_mode=proportion \
15 | --cfg.compress_probability=${3} \
16 | --cfg.jpeg_quality=$4
17 | }
18 |
19 | write_dataset train $1 $2 $3
20 | # write_dataset val $1 $2 $3
--------------------------------------------------------------------------------
/utils/create_imagenet_subset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Aniruddha Saha
3 | Description: This script creates the ImageNet100 subset.
4 | Instead of using symlinks like in the original repo, this copies the images
5 | for ease of use in the Dataset class in this repo which loads images from lists.
6 | '''
7 | import os
8 | import re
9 | import errno
10 | import argparse
11 | import shutil
12 | from tqdm import tqdm
13 |
14 |
15 | def create_subset(class_list, full_imagenet_path, subset_imagenet_path, *,
16 | splits=('val',)):
17 | full_imagenet_path = os.path.abspath(full_imagenet_path)
18 | subset_imagenet_path = os.path.abspath(subset_imagenet_path)
19 | os.makedirs(subset_imagenet_path, exist_ok=True)
20 | for split in splits:
21 | os.makedirs(os.path.join(subset_imagenet_path, split), exist_ok=True)
22 | for c in tqdm(class_list):
23 | if re.match(r"n[0-9]{8}", c) is None:
24 | raise ValueError(
25 | f"Expected class names to be of the format nXXXXXXXX, where "
26 | f"each X represents a numerical number, e.g., n04589890, but "
27 | f"got {c}")
28 | for split in splits:
29 | try:
30 | shutil.copytree(
31 | os.path.join(full_imagenet_path, split, c),
32 | os.path.join(subset_imagenet_path, split, c)
33 | )
34 | except FileNotFoundError:
35 | print(f'Class {c} is not present')
36 | print(f'Finished creating ImageNet subset at {subset_imagenet_path}!')
37 |
38 |
39 | if __name__ == '__main__':
40 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Subset Creation')
41 | parser.add_argument('--full_imagenet_path', metavar='IMAGENET_DIR',
42 | help='path to the existing full ImageNet dataset')
43 | parser.add_argument('--subset_imagenet_path', metavar='SUBSET_DIR',
44 | help='path to create the ImageNet subset dataset')
45 | parser.add_argument('--subset', type=str,
46 | default=os.path.join(os.path.dirname(__file__), 'imagenet100_classes.txt'),
47 | help='file contains a list of subset classes')
48 | args = parser.parse_args()
49 |
50 | print(f'Using class names specified in {args.subset}.')
51 | with open(args.subset, 'r') as f:
52 | class_list = [l.strip() for l in f.readlines()]
53 |
54 | create_subset(class_list, args.full_imagenet_path, args.subset_imagenet_path)
--------------------------------------------------------------------------------
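A hedged example of driving create_subset directly instead of via the CLI; the paths below are placeholders, and note that the function copies only the 'val' split unless splits is passed explicitly.

# Assumes this file is importable as create_imagenet_subset; paths are placeholders.
from create_imagenet_subset import create_subset

with open('imagenet100_classes.txt') as f:        # one wordnet ID (nXXXXXXXX) per line
    class_list = [line.strip() for line in f if line.strip()]

create_subset(
    class_list,
    '/data/imagenet',         # existing full ImageNet root containing train/ and val/
    '/data/imagenet100',      # destination root for the subset
    splits=('train', 'val'),  # default is ('val',) only
)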
/utils/mappings/folder_to_objectnet_label.json:
--------------------------------------------------------------------------------
1 | {
2 | "squeegee": "Squeegee",
3 | "umbrella": "Umbrella",
4 | "eyeglasses": "Eyeglasses",
5 | "coaster": "Coaster",
6 | "winter_glove": "Winter glove",
7 | "pet_food_container": "Pet food container",
8 | "scissors": "Scissors",
9 | "computer_mouse": "Computer mouse",
10 | "still_camera": "Still Camera",
11 | "weight_scale": "Weight scale",
12 | "cutting_board": "Cutting board",
13 | "spatula": "Spatula",
14 | "plunger": "Plunger",
15 | "paper": "Paper",
16 | "paper_bag": "Paper bag",
17 | "microwave": "Microwave",
18 | "keyboard": "Keyboard",
19 | "standing_lamp": "Standing lamp",
20 | "chair": "Chair",
21 | "mixing_salad_bowl": "Mixing / Salad Bowl",
22 | "milk": "Milk",
23 | "blender": "Blender",
24 | "flashlight": "Flashlight",
25 | "floss_container": "Floss container",
26 | "rake": "Rake",
27 | "sandal": "Sandal",
28 | "book_closed": "Book (closed)",
29 | "extension_cable": "Extension cable",
30 | "drawer_open": "Drawer (open)",
31 | "tv": "TV",
32 | "backpack": "Backpack",
33 | "playing_cards": "Playing cards",
34 | "jar": "Jar",
35 | "frying_pan": "Frying pan",
36 | "bench": "Bench",
37 | "coffee_grinder": "Coffee grinder",
38 | "jam": "Jam",
39 | "tape": "Tape / duct tape",
40 | "dog_bed": "Dog bed",
41 | "throw_pillow": "Throw pillow",
42 | "orange": "Orange",
43 | "speaker": "Speaker",
44 | "paint_can": "Paint can",
45 | "hand_towel_or_rag": "Dishrag or hand towel",
46 | "hat": "Hat",
47 | "toaster": "Toaster",
48 | "match": "Match",
49 | "plate": "Plate",
50 | "ironing_board": "Ironing board",
51 | "alarm_clock": "Alarm clock",
52 | "sunglasses": "Sunglasses",
53 | "power_bar": "Power bar",
54 | "wine_glass": "Wine glass",
55 | "ladle": "Ladle",
56 | "ruler": "Ruler",
57 | "blanket": "Blanket",
58 | "contact_lens_case": "Contact lens case",
59 | "watch": "Watch",
60 | "trash_bin": "Trash bin",
61 | "vase": "Vase",
62 | "oven_mitts": "Oven mitts",
63 | "spoon": "Spoon",
64 | "thermos": "Thermos",
65 | "fan": "Fan",
66 | "egg_carton": "Egg carton",
67 | "lighter": "Lighter",
68 | "glue_container": "Glue container",
69 | "bottle_cap": "Bottle cap",
70 | "wallet": "Wallet",
71 | "coffee_table": "Coffee table",
72 | "nail_clippers": "Nail clippers",
73 | "cellphone_case": "Cellphone case",
74 | "canned_food": "Canned food",
75 | "letter_opener": "Letter opener",
76 | "stapler": "Stapler",
77 | "whisk": "Whisk",
78 | "tongs": "Tongs",
79 | "lettuce": "Lettuce",
80 | "shoelace": "Shoelace",
81 | "button": "Button",
82 | "hair_dryer": "Hair dryer",
83 | "bracelet": "Bracelet",
84 | "spray_bottle": "Spray bottle",
85 | "laptop_charger": "Laptop charger",
86 | "portable_heater": "Portable heater",
87 | "suit_jacket": "Suit jacket",
88 | "dress_shoe_men": "Dress shoe (men)",
89 | "rock": "Rock",
90 | "water_filter": "Water filter",
91 | "earbuds": "Earbuds",
92 | "biscuits": "Biscuits",
93 | "mouthwash": "Mouthwash",
94 | "slipper": "Slipper",
95 | "eraser_white_board": "Eraser (white board)",
96 | "bicycle": "Bicycle",
97 | "ziploc_bag": "Ziploc bag",
98 | "dish_soap": "Dish soap",
99 | "travel_case": "Travel case",
100 | "receipt": "Receipt",
101 | "boots": "Boots",
102 | "sponge": "Sponge",
103 | "full_sized_towel": "Bath towel",
104 | "flour_container": "Flour container",
105 | "pill_bottle": "Pill bottle",
106 | "candle": "Candle",
107 | "calendar": "Calendar",
108 | "tweezers": "Tweezers",
109 | "dvd_player": "DVD player",
110 | "plastic_wrap": "Plastic wrap",
111 | "ribbon": "Ribbon",
112 | "blouse": "Blouse",
113 | "walking_cane": "Walking cane",
114 | "leaf": "Leaf",
115 | "lipstick": "Lipstick",
116 | "hair_brush": "Hair brush",
117 | "night_light": "Night light",
118 | "kettle": "Kettle",
119 | "honey_container": "Honey container",
120 | "tarp": "Tarp",
121 | "ice": "Ice",
122 | "drinking_straw": "Drinking straw",
123 | "detergent": "Detergent",
124 | "mug": "Mug",
125 | "toilet_paper_roll": "Toilet paper roll",
126 | "wok": "Wok",
127 | "swimming_trunks": "Swimming trunks",
128 | "clothes_hamper": "Clothes hamper",
129 | "removable_blade": "Removable blade",
130 | "shampoo_bottle": "Shampoo bottle",
131 | "skirt": "Skirt",
132 | "loofah": "Loofah",
133 | "broom": "Broom",
134 | "photograph_printed": "Photograph (printed)",
135 | "multitool": "Multitool",
136 | "makeup": "Makeup",
137 | "nail_file": "Nail file",
138 | "brooch": "Brooch",
139 | "cork": "Cork",
140 | "coffee_french_press": "Coffee/French press",
141 | "toy": "Toy",
142 | "thermometer": "Thermometer",
143 | "hairclip": "Hairclip",
144 | "document_folder_closed": "Document folder (closed)",
145 | "pliers": "Pliers",
146 | "strainer": "Strainer",
147 | "comb": "Comb",
148 | "water_bottle": "Water bottle",
149 | "peeler": "Peeler",
150 | "monitor": "Monitor",
151 | "box": "Box",
152 | "pill_organizer": "Pill organizer",
153 | "stopper_sink_tub": "Stopper (sink/tub)",
154 | "walker": "Walker",
155 | "step_stool": "Step stool",
156 | "bills_money": "Bills (money)",
157 | "skateboard": "Skateboard",
158 | "running_shoe": "Running shoe",
159 | "coin_money": "Coin (money)",
160 | "magazine": "Magazine",
161 | "drying_rack_for_clothes": "Drying rack for clothes",
162 | "toothpaste": "Toothpaste",
163 | "paper_towel": "Paper towel",
164 | "remote_control": "Remote control",
165 | "sugar_container": "Sugar container",
166 | "dress_pants": "Dress pants",
167 | "scarf": "Scarf",
168 | "dress_shirt": "Dress shirt",
169 | "cheese": "Cheese",
170 | "can_opener": "Can opener",
171 | "shovel": "Shovel",
172 | "paintbrush": "Paintbrush",
173 | "tennis_racket": "Tennis racket",
174 | "battery": "Battery",
175 | "stuffed_animal": "Stuffed animal",
176 | "jeans": "Jeans",
177 | "tanktop": "Tanktop",
178 | "dust_pan": "Dust pan",
179 | "earring": "Earring",
180 | "tomato": "Tomato",
181 | "marker": "Marker",
182 | "makeup_brush": "Makeup brush",
183 | "ring": "Ring",
184 | "air_freshener": "Air freshener",
185 | "tablecloth": "Tablecloth",
186 | "teabag": "Teabag",
187 | "belt": "Belt",
188 | "razor": "Razor",
189 | "clothes_hanger": "Clothes hanger",
190 | "bookend": "Bookend",
191 | "sweater": "Sweater",
192 | "sock": "Sock",
193 | "usb_flash_drive": "Usb flash drive",
194 | "cellphone_charger": "Cellphone charger",
195 | "pepper_shaker": "Pepper shaker",
196 | "phone_landline": "Phone (landline)",
197 | "banana": "Banana",
198 | "printer": "Printer",
199 | "paperclip": "Paperclip",
200 | "fork": "Fork",
201 | "headphones_over_ear": "Headphones (over ear)",
202 | "cooking_oil_bottle": "Cooking oil bottle",
203 | "deodorant": "Deodorant",
204 | "usb_cable": "Usb cable",
205 | "shorts": "Shorts",
206 | "bread_loaf": "Bread loaf",
207 | "pillow": "Pillow",
208 | "drinking_cup": "Drinking Cup",
209 | "envelope": "Envelope",
210 | "mouse_pad": "Mouse pad",
211 | "chopstick": "Chopstick",
212 | "t-shirt": "T-shirt",
213 | "padlock": "Padlock",
214 | "ice_cube_tray": "Ice cube tray",
215 | "chess_piece": "Chess piece",
216 | "cereal": "Cereal",
217 | "hairtie": "Hairtie",
218 | "teapot": "Teapot",
219 | "board_game": "Board game",
220 | "butchers_knife": "Butcher's knife",
221 | "soup_bowl": "Soup Bowl",
222 | "beer_bottle": "Beer bottle",
223 | "nail_polish": "Nail polish",
224 | "hand_mirror": "Hand mirror",
225 | "combination_lock": "Combination lock",
226 | "nut_for_screw": "Nut for a screw",
227 | "nail_fastener": "Nail (fastener)",
228 | "figurine_or_statue": "Figurine or statue",
229 | "soap_bar": "Soap bar",
230 | "bucket": "Bucket",
231 | "binder_closed": "Binder (closed)",
232 | "video_camera": "Video Camera",
233 | "baseball_glove": "Baseball glove",
234 | "tape_measure": "Tape measure",
235 | "tissue": "Tissue",
236 | "coffee_beans": "Coffee beans",
237 | "scrub_brush": "Scrub brush",
238 | "drill": "Drill",
239 | "suitcase": "Suitcase",
240 | "newspaper": "Newspaper",
241 | "sleeping_bag": "Sleeping bag",
242 | "dress_shoe_women": "Dress shoe (women)",
243 | "trophy": "Trophy",
244 | "plastic_bag": "Plastic bag",
245 | "doormat": "Doormat",
246 | "webcam": "Webcam",
247 | "rolling_pin": "Rolling pin",
248 | "pencil": "Pencil",
249 | "table_knife": "Table knife",
250 | "bread_knife": "Bread knife",
251 | "toothbrush": "Toothbrush",
252 | "bathrobe": "Bathrobe",
253 | "paper_plates": "Paper plates",
254 | "placemat": "Placemat",
255 | "light_bulb": "Light bulb",
256 | "soap_dispenser": "Soap dispenser",
257 | "nightstand": "Nightstand",
258 | "pen": "Pen",
259 | "squeeze_bottle": "Squeeze bottle",
260 | "wheel": "Wheel",
261 | "dress": "Dress",
262 | "helmet": "Helmet",
263 | "lemon": "Lemon",
264 | "hammer": "Hammer",
265 | "lampshade": "Lampshade",
266 | "salt_shaker": "Salt shaker",
267 | "power_cable": "Power cable",
268 | "vacuum_cleaner": "Vacuum cleaner",
269 | "iron_for_clothes": "Iron (for clothes)",
270 | "laptop_open": "Laptop (open)",
271 | "poster": "Poster",
272 | "coffee_machine": "Coffee machine",
273 | "tie": "Tie",
274 | "cd_case": "CD case",
275 | "baseball_bat": "Baseball bat",
276 | "tablet_ipad": "Tablet / iPad",
277 | "bottle_opener": "Bottle opener",
278 | "briefcase": "Briefcase",
279 | "baking_sheet": "Baking sheet",
280 | "screw": "Screw",
281 | "pitcher": "Pitcher",
282 | "notepad": "Notepad",
283 | "tote_bag": "Tote bag",
284 | "raincoat": "Raincoat",
285 | "necklace": "Necklace",
286 | "band_aid": "Band Aid",
287 | "notebook": "Notebook",
288 | "measuring_cup": "Measuring cup",
289 | "weight_exercise": "Weight (exercise)",
290 | "handbag": "Handbag",
291 | "bike_pump": "Bike pump",
292 | "bottle_stopper": "Bottle stopper",
293 | "chocolate": "Chocolate",
294 | "safety_pin": "Safety pin",
295 | "plastic_cup": "Plastic cup",
296 | "butter": "Butter",
297 | "cellphone": "Cellphone",
298 | "drying_rack_for_dishes": "Drying rack for plates",
299 | "trash_bag": "Trash bag",
300 | "tray": "Tray",
301 | "wine_bottle": "Wine bottle",
302 | "whistle": "Whistle",
303 | "key_chain": "Key chain",
304 | "napkin": "Napkin",
305 | "desk_lamp": "Desk lamp",
306 | "first_aid_kit": "First aid kit",
307 | "bed_sheet": "Bed sheet",
308 | "beer_can": "Beer can",
309 | "wrench": "Wrench",
310 | "pop_can": "Pop can",
311 | "basket": "Basket",
312 | "leggings": "Leggings",
313 | "egg": "Egg",
314 | "sewing_kit": "Sewing kit"
315 | }
316 |
--------------------------------------------------------------------------------
/utils/mappings/imagenet100.txt:
--------------------------------------------------------------------------------
1 | n02869837
2 | n01749939
3 | n02488291
4 | n02107142
5 | n13037406
6 | n02091831
7 | n04517823
8 | n04589890
9 | n03062245
10 | n01773797
11 | n01735189
12 | n07831146
13 | n07753275
14 | n03085013
15 | n04485082
16 | n02105505
17 | n01983481
18 | n02788148
19 | n03530642
20 | n04435653
21 | n02086910
22 | n02859443
23 | n13040303
24 | n03594734
25 | n02085620
26 | n02099849
27 | n01558993
28 | n04493381
29 | n02109047
30 | n04111531
31 | n02877765
32 | n04429376
33 | n02009229
34 | n01978455
35 | n02106550
36 | n01820546
37 | n01692333
38 | n07714571
39 | n02974003
40 | n02114855
41 | n03785016
42 | n03764736
43 | n03775546
44 | n02087046
45 | n07836838
46 | n04099969
47 | n04592741
48 | n03891251
49 | n02701002
50 | n03379051
51 | n02259212
52 | n07715103
53 | n03947888
54 | n04026417
55 | n02326432
56 | n03637318
57 | n01980166
58 | n02113799
59 | n02086240
60 | n03903868
61 | n02483362
62 | n04127249
63 | n02089973
64 | n03017168
65 | n02093428
66 | n02804414
67 | n02396427
68 | n04418357
69 | n02172182
70 | n01729322
71 | n02113978
72 | n03787032
73 | n02089867
74 | n02119022
75 | n03777754
76 | n04238763
77 | n02231487
78 | n03032252
79 | n02138441
80 | n02104029
81 | n03837869
82 | n03494278
83 | n04136333
84 | n03794056
85 | n03492542
86 | n02018207
87 | n04067472
88 | n03930630
89 | n03584829
90 | n02123045
91 | n04229816
92 | n02100583
93 | n03642806
94 | n04336792
95 | n03259280
96 | n02116738
97 | n02108089
98 | n03424325
99 | n01855672
100 | n02090622
--------------------------------------------------------------------------------
/utils/mappings/imagenet100_to_labels.json:
--------------------------------------------------------------------------------
1 | {"15": "robin, American robin, Turdus migratorius", "45": "Gila monster, Heloderma suspectum", "54": "hognose snake, puff adder, sand viper", "57": "garter snake, grass snake", "64": "green mamba", "74": "garden spider, Aranea diademata", "90": "lorikeet", "99": "goose", "119": "rock crab, Cancer irroratus", "120": "fiddler crab", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "131": "little blue heron, Egretta caerulea", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "151": "Chihuahua", "155": "Shih-Tzu", "157": "papillon", "158": "toy terrier", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "169": "borzoi, Russian wolfhound", "176": "Saluki, gazelle hound", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "209": "Chesapeake Bay retriever", "211": "vizsla, Hungarian pointer", "222": "kuvasz", "228": "komondor", "234": "Rottweiler", "236": "Doberman, Doberman pinscher", "242": "boxer", "246": "Great Dane", "267": "standard poodle", "268": "Mexican hairless", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "277": "red fox, Vulpes vulpes", "281": "tabby, tabby cat", "299": "meerkat, mierkat", "305": "dung beetle", "313": "walking stick, walkingstick, stick insect", "317": "leafhopper", "331": "hare", "342": "wild boar, boar, Sus scrofa", "368": "gibbon, Hylobates lar", "374": "langur", "407": "ambulance", "421": "bannister, banister, balustrade, balusters, handrail", "431": "bassinet", "449": "boathouse", "452": "bonnet, poke bonnet", "455": "bottlecap", "479": "car wheel", "494": "chime, bell, gong", "498": "cinema, movie theater, movie theatre, movie house, picture palace", "503": "cocktail shaker", "508": "computer keyboard, keypad", "544": "Dutch oven", "560": "football helmet", "570": "gasmask, respirator, gas helmet", "592": "hard disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "599": "honeycomb", "606": "iron, smoothing iron", "608": "jean, blue jean, denim", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "653": "milk can", "659": "mixing bowl", "662": "modem", "665": "moped", "667": "mortarboard", "674": "mousetrap", "682": "obelisk", "703": "park bench", "708": "pedestal, plinth, footstall", "717": "pickup, pickup truck", "724": "pirate, pirate ship", "748": "purse", "758": "reel", "765": "rocking chair, rocker", "766": "rotisserie", "772": "safety pin", "775": "sarong", "796": "ski mask", "798": "slide rule, slipstick", "830": "stretcher", "854": "theater curtain, theatre curtain", "857": "throne", "858": "tile roof", "872": "tripod", "876": "tub, vat", "882": "vacuum, vacuum cleaner", "904": "window screen", "908": "wing", "936": "head cabbage", "938": "cauliflower", "953": "pineapple, ananas", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "993": "gyromitra", "994": "stinkhorn, carrion fungus"}
--------------------------------------------------------------------------------
/utils/mappings/imagenet_pytorch_id_to_objectnet_id.json:
--------------------------------------------------------------------------------
1 | {"409": 1, "530": 1, "414": 2, "954": 4, "419": 5, "790": 8, "434": 9, "440": 13, "703": 16, "671": 17, "444": 17, "446": 20, "455": 29, "930": 35, "462": 38, "463": 39, "499": 40, "473": 45, "470": 46, "487": 48, "423": 52, "559": 52, "765": 52, "588": 57, "550": 64, "507": 67, "673": 68, "846": 75, "533": 78, "539": 81, "630": 86, "740": 88, "968": 89, "729": 92, "549": 98, "545": 102, "567": 109, "578": 83, "589": 112, "587": 115, "560": 120, "518": 120, "606": 124, "608": 128, "508": 131, "618": 132, "619": 133, "620": 134, "951": 138, "623": 139, "626": 142, "629": 143, "644": 149, "647": 150, "651": 151, "659": 153, "664": 154, "504": 157, "677": 159, "679": 164, "950": 171, "695": 173, "696": 175, "700": 179, "418": 182, "749": 182, "563": 182, "720": 188, "721": 190, "725": 191, "728": 193, "923": 196, "731": 199, "737": 200, "811": 201, "742": 205, "761": 210, "769": 216, "770": 217, "772": 218, "773": 219, "774": 220, "783": 223, "792": 229, "601": 231, "655": 231, "689": 231, "797": 232, "804": 235, "806": 236, "809": 237, "813": 238, "632": 239, "732": 248, "759": 248, "828": 250, "850": 251, "834": 253, "837": 255, "841": 256, "842": 257, "610": 258, "851": 259, "849": 268, "752": 269, "457": 273, "906": 273, "859": 275, "999": 276, "412": 284, "868": 286, "879": 289, "882": 292, "883": 293, "893": 297, "531": 298, "898": 299, "543": 302, "778": 303, "479": 304, "694": 304, "902": 306, "907": 307, "658": 309, "909": 310}
--------------------------------------------------------------------------------
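This file maps PyTorch ImageNet-1k class indices to ObjectNet class ids for the overlapping classes. Below is a hedged sketch of one common use: translating an ImageNet classifier's top-1 predictions into ObjectNet ids for scoring. The handling of predictions outside the overlap (-1 below) is an assumption, not something this repo prescribes.

import json
import torch

# Load the mapping; keys are ImageNet-1k (PyTorch) indices, values are ObjectNet ids.
with open('utils/mappings/imagenet_pytorch_id_to_objectnet_id.json') as f:
    im_to_obj = {int(k): v for k, v in json.load(f).items()}

def objectnet_top1(logits: torch.Tensor) -> list:
    # logits: (batch, 1000) scores from an ImageNet-1k classifier.
    preds = logits.argmax(dim=1).tolist()
    # Map to ObjectNet ids; -1 marks classes with no ObjectNet counterpart.
    return [im_to_obj.get(p, -1) for p in preds]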
/utils/mappings/objectnet_im100_folder.json:
--------------------------------------------------------------------------------
1 | {"Bench": "n03891251", "Bottle cap": "n02877765", "Chair": "n04099969", "Helmet": "n03379051", "Iron (for clothes)": "n03584829", "Jeans": "n03594734", "Keyboard": "n03085013", "Lampshade": "n03637318", "Laptop (open)": "n03642806", "Mixing / Salad Bowl": "n03775546", "Safety pin": "n04127249", "Vacuum cleaner": "n04517823", "Wheel": "n02974003"}
--------------------------------------------------------------------------------
/utils/mappings/objectnet_im1k_folder.json:
--------------------------------------------------------------------------------
1 | {"Alarm clock": "n02708093", "Backpack": "n02769748", "Banana": "n07753592", "Band Aid": "n02786058", "Basket": "n04204238", "Bath towel": "n02808304", "Beer bottle": "n02823428", "Bench": "n03891251", "Bicycle": "n03792782", "Binder (closed)": "n02840245", "Bottle cap": "n02877765", "Bread loaf": "n07684084", "Broom": "n02906734", "Bucket": "n02909870", "Butcher's knife": "n03041632", "Can opener": "n02951585", "Candle": "n02948072", "Cellphone": "n02992529", "Chair": "n02791124", "Clothes hamper": "n03482405", "Coffee/French press": "n03297495", "Combination lock": "n03075370", "Computer mouse": "n03793489", "Desk lamp": "n04380533", "Dishrag or hand towel": "n03207743", "Doormat": "n03223299", "Dress shoe (men)": "n03680355", "Drill": "n03995372", "Drinking Cup": "n07930864", "Drying rack for plates": "n03961711", "Envelope": "n03291819", "Fan": "n03271574", "Frying pan": "n03400231", "Dress": "n03450230", "Hair dryer": "n03483316", "Hammer": "n03481172", "Helmet": "n03379051", "Iron (for clothes)": "n03584829", "Jeans": "n03594734", "Keyboard": "n03085013", "Ladle": "n03633091", "Lampshade": "n03637318", "Laptop (open)": "n03642806", "Lemon": "n07749582", "Letter opener": "n03658185", "Lighter": "n03666591", "Lipstick": "n03676483", "Match": "n03729826", "Measuring cup": "n03733805", "Microwave": "n03761084", "Mixing / Salad Bowl": "n03775546", "Monitor": "n03782006", "Mug": "n03063599", "Nail (fastener)": "n03804744", "Necklace": "n03814906", "Orange": "n07747607", "Padlock": "n03874599", "Paintbrush": "n03876231", "Paper towel": "n03887697", "Pen": "n02783161", "Pill bottle": "n03937543", "Pillow": "n03938244", "Pitcher": "n03950228", "Plastic bag": "n03958227", "Plate": "n07579787", "Plunger": "n03970156", "Pop can": "n03983396", "Portable heater": "n04265275", "Printer": "n04004767", "Remote control": "n04074963", "Ruler": "n04118776", "Running shoe": "n04120489", "Safety pin": "n04127249", "Salt shaker": "n04131690", "Sandal": "n04133789", "Screw": "n04153751", "Shovel": "n04208210", "Skirt": "n03534580", "Sleeping bag": "n04235860", "Soap dispenser": "n04254120", "Sock": "n04254777", "Soup Bowl": "n04263257", "Spatula": "n04270147", "Speaker": "n03691459", "Still Camera": "n03976467", "Strainer": "n04332243", "Stuffed animal": "n04399382", "Suit jacket": "n04350905", "Sunglasses": "n04356056", "Sweater": "n04370456", "Swimming trunks": "n04371430", "T-shirt": "n03595614", "TV": "n04404412", "Teapot": "n04398044", "Tennis racket": "n04039381", "Tie": "n02883205", "Toaster": "n04442312", "Toilet paper roll": "n15075141", "Trash bin": "n02747177", "Tray": "n04476259", "Umbrella": "n04507155", "Vacuum cleaner": "n04517823", "Vase": "n04522168", "Wallet": "n04548362", "Watch": "n03197337", "Water bottle": "n04557648", "Weight (exercise)": "n03255030", "Weight scale": "n04141975", "Wheel": "n02974003", "Whistle": "n04579432", "Wine bottle": "n04591713", "Winter glove": "n03775071", "Wok": "n04596742"}
--------------------------------------------------------------------------------
/utils/mappings/objectnet_to_im100.json:
--------------------------------------------------------------------------------
1 | {"Bench": "park bench", "Bottle cap": "bottlecap", "Chair": "rocking chair, rocker", "Helmet": "football helmet", "Iron (for clothes)": "iron, smoothing iron", "Jeans": "jean, blue jean, denim", "Keyboard": "computer keyboard, keypad", "Lampshade": "lampshade, lamp shade", "Laptop (open)": "laptop, laptop computer", "Mixing / Salad Bowl": "mixing bowl", "Safety pin": "safety pin", "Vacuum cleaner": "vacuum, vacuum cleaner", "Wheel": "car wheel"}
--------------------------------------------------------------------------------
/utils/mappings/objectnet_to_imagenet_1k.json:
--------------------------------------------------------------------------------
1 | {
2 | "Alarm clock": "analog clock; digital clock",
3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack",
4 | "Banana": "banana",
5 | "Band Aid": "Band Aid",
6 | "Basket": "shopping basket",
7 | "Bath towel": "bath towel",
8 | "Beer bottle": "beer bottle",
9 | "Bench": "park bench",
10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem",
11 | "Binder (closed)": "binder, ring-binder",
12 | "Bottle cap": "bottlecap",
13 | "Bread loaf": "French loaf",
14 | "Broom": "broom",
15 | "Bucket": "bucket, pail",
16 | "Butcher's knife": "cleaver, meat cleaver, chopper",
17 | "Can opener": "can opener, tin opener",
18 | "Candle": "candle, taper, wax light",
19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
20 | "Chair": "barber chair; folding chair; rocking chair, rocker",
21 | "Clothes hamper": "hamper",
22 | "Coffee/French press": "espresso maker",
23 | "Combination lock": "combination lock",
24 | "Computer mouse": "mouse, computer mouse",
25 | "Desk lamp": "table lamp",
26 | "Dishrag or hand towel": "dishrag, dishcloth",
27 | "Doormat": "doormat, welcome mat",
28 | "Dress shoe (men)": "Loafer",
29 | "Drill": "power drill",
30 | "Drinking Cup": "cup",
31 | "Drying rack for plates": "plate rack",
32 | "Envelope": "envelope",
33 | "Fan": "electric fan, blower",
34 | "Frying pan": "frying pan, frypan, skillet",
35 | "Dress": "gown",
36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
37 | "Hammer": "hammer",
38 | "Helmet": "football helmet; crash helmet",
39 | "Iron (for clothes)": "iron, smoothing iron",
40 | "Jeans": "jean, blue jean, denim",
41 | "Keyboard": "computer keyboard, keypad",
42 | "Ladle": "ladle",
43 | "Lampshade": "lampshade, lamp shade",
44 | "Laptop (open)": "laptop, laptop computer",
45 | "Lemon": "lemon",
46 | "Letter opener": "letter opener, paper knife, paperknife",
47 | "Lighter": "lighter, light, igniter, ignitor",
48 | "Lipstick": "lipstick, lip rouge",
49 | "Match": "matchstick",
50 | "Measuring cup": "measuring cup",
51 | "Microwave": "microwave, microwave oven",
52 | "Mixing / Salad Bowl": "mixing bowl",
53 | "Monitor": "monitor",
54 | "Mug": "coffee mug",
55 | "Nail (fastener)": "nail",
56 | "Necklace": "necklace",
57 | "Orange": "orange",
58 | "Padlock": "padlock",
59 | "Paintbrush": "paintbrush",
60 | "Paper towel": "paper towel",
61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen",
62 | "Pill bottle": "pill bottle",
63 | "Pillow": "pillow",
64 | "Pitcher": "pitcher, ewer",
65 | "Plastic bag": "plastic bag",
66 | "Plate": "plate",
67 | "Plunger": "plunger, plumber's helper",
68 | "Pop can": "pop bottle, soda bottle",
69 | "Portable heater": "space heater",
70 | "Printer": "printer",
71 | "Remote control": "remote control, remote",
72 | "Ruler": "rule, ruler",
73 | "Running shoe": "running shoe",
74 | "Safety pin": "safety pin",
75 | "Salt shaker": "saltshaker, salt shaker",
76 | "Sandal": "sandal",
77 | "Screw": "screw",
78 | "Shovel": "shovel",
79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt",
80 | "Sleeping bag": "sleeping bag",
81 | "Soap dispenser": "soap dispenser",
82 | "Sock": "sock",
83 | "Soup Bowl": "soup bowl",
84 | "Spatula": "spatula",
85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera",
87 | "Strainer": "strainer",
88 | "Stuffed animal": "teddy, teddy bear",
89 | "Suit jacket": "suit, suit of clothes",
90 | "Sunglasses": "sunglasses, dark glasses, shades",
91 | "Sweater": "sweatshirt",
92 | "Swimming trunks": "swimming trunks, bathing trunks",
93 | "T-shirt": "jersey, T-shirt, tee shirt",
94 | "TV": "television, television system",
95 | "Teapot": "teapot",
96 | "Tennis racket": "racket, racquet",
97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie",
98 | "Toaster": "toaster",
99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue",
100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
101 | "Tray": "tray",
102 | "Umbrella": "umbrella",
103 | "Vacuum cleaner": "vacuum, vacuum cleaner",
104 | "Vase": "vase",
105 | "Wallet": "wallet, billfold, notecase, pocketbook",
106 | "Watch": "digital watch",
107 | "Water bottle": "water bottle",
108 | "Weight (exercise)": "dumbbell",
109 | "Weight scale": "scale, weighing machine",
110 | "Wheel": "car wheel; paddlewheel, paddle wheel",
111 | "Whistle": "whistle",
112 | "Wine bottle": "wine bottle",
113 | "Winter glove": "mitten",
114 | "Wok": "wok"
115 | }
--------------------------------------------------------------------------------
/utils/rename_imagenet_v2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 | if __name__ == '__main__':
7 | parser = argparse.ArgumentParser(description='ImageNetV2 Rename')
8 | parser.add_argument('--imagenetv2_path', metavar='IMAGENET_DIR',
9 | help='path to the existing full ImageNetV2 dataset')
10 | parser.add_argument('--datafile', help='path to labels file containing class name and folder name mapping')
11 | args = parser.parse_args()
12 |
13 | folder_names = os.listdir(args.imagenetv2_path)
14 |
15 | df = pd.read_csv(args.datafile)
16 |
17 | classnames = [classname.split('/')[1] for classname in list(df['image'])]
18 | classnames = [i for n, i in enumerate(classnames) if i not in classnames[:n]]
19 |
20 | for i in tqdm(range(1000)):
21 | os.rename(os.path.join(args.imagenetv2_path, f'{i}'), os.path.join(args.imagenetv2_path, f'{classnames[i]}'))
--------------------------------------------------------------------------------
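The renaming above assumes that the i-th distinct class appearing in the datafile's 'image' column corresponds to ImageNetV2's numeric folder i. A toy, hedged illustration of that ordering assumption (the paths below are made up):

import pandas as pd

df = pd.DataFrame({'image': [
    'train/n01440764/a.JPEG',
    'train/n01440764/b.JPEG',
    'train/n01443537/c.JPEG',
]})
classnames = [p.split('/')[1] for p in df['image']]
classnames = [c for n, c in enumerate(classnames) if c not in classnames[:n]]
print(classnames)  # ['n01440764', 'n01443537'] -> folders '0' and '1' receive these names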
/utils/subset_objectnet_im.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import argparse
5 | from tqdm import tqdm
6 |
7 | def create_subset(args):
8 |
9 | dirname = 'Objectnet/val'
10 | os.makedirs(os.path.join(args.im100_root, dirname), exist_ok = True)
11 |
12 | with open(args.folder_to_label, 'r') as f:
13 | folder_to_label = json.load(f)
14 | label_to_folder = {v:k for k,v in folder_to_label.items()}
15 |
16 | with open(args.objectnet_to_im100_folder, 'r') as f:
17 | objectnet_to_im100_folder = json.load(f)
18 |
19 | for label in tqdm(objectnet_to_im100_folder):
20 | folder_name = label_to_folder[label]
21 | images = os.listdir(os.path.join(args.objectnet_root, folder_name))
22 | im100_folder_name = objectnet_to_im100_folder[label]
23 | os.makedirs(os.path.join(args.im100_root, dirname, im100_folder_name), exist_ok = True)
24 | print(im100_folder_name)
25 | for image in images:
26 | shutil.copy(os.path.join(args.objectnet_root, folder_name, image), os.path.join(args.im100_root, dirname, im100_folder_name))
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser(description='ObjectNet subset creation (ImageNet100 folder layout)')
30 | parser.add_argument('--objectnet_root', type=str, default=None)
31 | parser.add_argument('--im100_root', type=str, default=None)
32 | parser.add_argument('--folder_to_label', type=str, default=None)
33 | parser.add_argument('--objectnet_to_im100_folder', type=str, default=None)
34 | args = parser.parse_args()
35 |
36 | create_subset(args)
--------------------------------------------------------------------------------