├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── figure.png ├── generation ├── DeepAugment │ ├── CAE_Model │ │ ├── __pycache__ │ │ │ └── cae_32x32x32_zero_pad_bin.cpython-39.pyc │ │ └── cae_32x32x32_zero_pad_bin.py │ ├── CAE_Weights │ │ └── model_final.state │ ├── CAE_distort_imagenet.py │ ├── EDSR_Model │ │ ├── __init__.py │ │ ├── common.py │ │ └── edsr.py │ ├── EDSR_Weights │ │ └── edsr_baseline_x4.pt │ ├── EDSR_distort_imagenet.py │ ├── README.md │ ├── deepaugment.png │ ├── train.py │ ├── train.sh │ └── train_noise2net.py ├── __pycache__ │ ├── conditional_ldm.cpython-39.pyc │ ├── data.cpython-39.pyc │ ├── generate_images.cpython-39.pyc │ ├── generate_images_captions.cpython-39.pyc │ └── generate_images_i2i.cpython-39.pyc ├── cin256-v2.yaml ├── classes.py ├── conditional_ldm.py ├── data.py ├── folder_to_class.csv ├── generate_images.py ├── generate_images_captions.py ├── generate_images_i2i.py └── ldm │ ├── __pycache__ │ └── util.cpython-39.pyc │ ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py │ ├── lr_scheduler.py │ ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-39.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── ddim.cpython-39.pyc │ │ └── ddpm.cpython-39.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py │ ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-39.pyc │ │ ├── ema.cpython-39.pyc │ │ └── x_transformer.cpython-39.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── model.cpython-39.pyc │ │ │ ├── openaimodel.cpython-39.pyc │ │ │ └── util.cpython-39.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── distributions.cpython-39.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── modules.cpython-39.pyc │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py │ └── util.py ├── sd_finetune ├── data.py ├── finetune_sd.py └── parser.py ├── train ├── get_masks.py ├── resnet_configs │ ├── alexnet.yaml │ ├── efficient.yaml │ ├── inception.yaml │ ├── mobilenet.yaml │ ├── resnext101.yaml │ ├── resnext50.yaml │ ├── rn18_88_epochs.yaml │ └── vgg16.yaml ├── train_imagenet.py ├── write_imagenet.py └── write_imagenet.sh └── utils ├── create_imagenet_subset.py ├── mappings ├── classes.py ├── folder_to_class.csv ├── folder_to_objectnet_label.json ├── imagenet100.txt ├── imagenet100_to_labels.json ├── imagenet_pytorch_id_to_objectnet_id.json ├── imagenet_to_label_2012_v2 ├── imagenet_to_labels.json ├── objectnet_im100_folder.json ├── objectnet_im1k_folder.json ├── objectnet_to_im100.json ├── objectnet_to_imagenet_1k.json └── pytorch_to_imagenet_2012_id.json ├── rename_imagenet_v2.py └── subset_objectnet_im.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Build and Release Folders 2 | bin-debug/ 3 | bin-release/ 4 | [Oo]bj/ 5 | [Bb]in/ 6 | 7 | # Other files and folders 8 | .settings/ 9 | 10 | # Executables 11 | *.swf 12 | *.air 13 | *.ipa 14 | *.apk 15 | 16 | # Project files, i.e. 
`.project`, `.actionScriptProperties` and `.flexProperties` 17 | # should NOT be excluded as they contain compiler settings and other important 18 | # information for Eclipse / Flash Builder. 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hritikbansal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Robustness 2 | 3 | This repo contains the code for the experiments in the 'Leaving Reality to Imagination: Robust Classification via Generated Datasets' paper. Arxiv link: [https://arxiv.org/pdf/2302.02503.pdf](https://arxiv.org/pdf/2302.02503.pdf) 4 | 5 | [Colab](https://colab.research.google.com/drive/1I2IO8tD_l9JdCRJHOqlAP6ojMPq_BsoR?usp=sharing) 6 | 7 |
![figure](figure.png)
8 | 9 | Accepted as Oral Presentation at RTML ICLR 2023. 10 | Accepted at SPIGM ICML 2023. 11 | 12 | ## Link to Generated ImageNet-1K dataset 13 | 14 | You can download the `Base-Generated-ImageNet-1K` dataset from [here](https://drive.google.com/drive/folders/1-jLyiJ_S-VZMS5zQNR6e1xDOAkAVJxvs?usp=sharing). Even though we discuss three variants of generated data in the paper, we make generations using captions of the class labels public for novel usecases by the community. 15 | 16 | You can download the `Finetuned-Generated-ImageNet-1K` dataset from [here](https://drive.google.com/drive/folders/1xVBcCrae1JrmOCoHHNWCw5UN6XAOoupk?usp=sharing). 17 | 18 | Structure of the dataset looks like: 19 | 20 | ``` 21 | * train (1000 folders) 22 | * n01440764 (1300 images) 23 | * image1.jpeg 24 | * . 25 | * imageN.jpeg 26 | * . 27 | * . 28 | * val (1000 images) 29 | * n01440764 (50 images) 30 | * image1.jpeg 31 | * . 32 | * imageN.jpeg 33 | * . 34 | * . 35 | ``` 36 | 37 | ## Finetuning Stable Diffusion on the Real ImageNet-1K 38 | 39 | We provide the finetuning details and as well as the finetuned Stable Diffusion model at [https://huggingface.co/hbXNov/ucla-mint-finetune-sd-im1k](https://huggingface.co/hbXNov/ucla-mint-finetune-sd-im1k). 40 | 41 | Colab Notebook: [here](https://colab.research.google.com/drive/1I2IO8tD_l9JdCRJHOqlAP6ojMPq_BsoR?usp=sharing) 42 | 43 | The finetuning code could be found in this folder: [sd_finetune](sd_finetune). Most of this code is adopted from the `diffusers` library - [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image). We are really thankful to the authors!! 44 | 45 | 46 | *The rest of the README will focus on generating data from Stable Diffusion in a zero-shot manner, and training ImageNet classifiers efficiently using FFCV.* 47 | 48 | ## Data Generation Using Stable Diffusion 49 | 50 | [Stable Diffusion](https://github.com/CompVis/stable-diffusion) is a popular text-to-image generative model. Most of the code is adapted from the very popular [diffusers](https://github.com/huggingface/diffusers) library from HuggingFace. 51 | 52 | However, it might not be straightforward to generate images from Stable Diffusion on multiple GPUs. To that end, we use the [accelerate](https://huggingface.co/docs/accelerate/index) package from Huggingface. 53 | 54 | ### Requirements 55 | 56 | - Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons. 57 | - 64-bit Python 3.7+ installation. 58 | - We used 5 A6000s 24GB DRAM GPUs for generation. 59 | 60 | ### Setup 61 | 62 | ``` 63 | 1. git clone https://github.com/Hritikbansal/leaving_reality_to_imagination.git 64 | 2. cd leaving_reality_to_imagination 65 | 3. conda env create -f environment.yml 66 | 4. pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html (Replace based on your computer's hardware) 67 | 5. accelerate config 68 | - This machine 69 | - multi-GPU 70 | - (How many machines?) 1 71 | - (optimize with dynamo?) NO 72 | - (Deepspeed?) NO 73 | - (FullyShardedParallel?) NO 74 | - (MegatronLM) NO 75 | - (Num of GPUs) 5 76 | - (device ids) 0,1,2,3,4 77 | - (np/fp16/bp16) no 78 | ``` 79 | 80 | ### Files 81 | 82 | 1. [generate_images_captions](generation/generate_images_captions.py) generates the images conditioned on the diverse text prompts (Class Labels). 83 | 2. 
[generate_images](generation/generate_images.py) generates the images conditioned on the images (Real Images). 84 | 3. [generate_images_i2i](generation.generate_images_i2i.py) generates the images conditioned on the encoded images and text (Real Images and Class Labels). 85 | 4. [conditional ldm](generation/conditional_ldm.py) generates images from the class-conditional latent diffusion model. You can download the model ckpt from [Stable Diffusion](https://github.com/CompVis/stable-diffusion) repo. 86 | 87 | Move the [classes.py](generation/classes.py) and [folder_to_class.csv](generation/folder_to_class.csv) to the `imagenet_dir`. 88 | 89 | ### Commands 90 | 91 | ```python 92 | accelerate launch --num_cpu_threads_per_process 8 -m generation.generate_images_captions --batch_size 8 --data_dir --save_image_gen --diversity --split val 93 | ``` 94 | 95 | ```python 96 | accelerate launch --num_cpu_threads_per_process 8 -m generation.generate_images --batch_size 2 --eval_test_data_dir --save_image_gen --split val 97 | ``` 98 | 99 | ```python 100 | accelerate launch --num_cpu_threads_per_process 8 -m generation.generate_images_i2i --batch_size 12 --data_dir --save_image_gen --split val --diversity 101 | ``` 102 | 103 | ```python 104 | accelerate launch --num_cpu_threads_per_process 8 conditional_ldm.py --config cin256-v2.yaml --checkpoint --save_image_gen 105 | ``` 106 | 107 | 108 | ## Training ImageNet Models Using FFCV 109 | 110 | We suggest the users to create a separate FFCV conda environment for training ImageNet models. 111 | 112 | ### Preparing the Dataset 113 | Following the ImageNet training pipeline of [FFCV](https://github.com/libffcv/ffcv-imagenet) for ResNet50, generate the dataset with the following command (`IMAGENET_DIR` should point to a PyTorch style [ImageNet dataset](https://github.com/MadryLab/pytorch-imagenet-dataset)): 114 | 115 | ```bash 116 | # Required environmental variables for the script: 117 | cd train/ 118 | export IMAGENET_DIR=/path/to/pytorch/format/imagenet/directory/ 119 | export WRITE_DIR=/your/path/here/ 120 | 121 | # Serialize images with: 122 | # - 500px side length maximum 123 | # - 50% JPEG encoded, 90% raw pixel values 124 | # - quality=90 JPEGs 125 | ./write_imagenet.sh 500 0.50 90 126 | ``` 127 | Note that we prepare the dataset with the following FFCV configuration: 128 | * ResNet-50 training: 50% JPEG 500px side length (*train_500_0.50_90.ffcv*) 129 | * ResNet-50 evaluation: 0% JPEG 500px side length (*val_500_uncompressed.ffcv*) 130 | 131 | - We have made some custom edits to [write_imagenet.py](train/write_imagenet.py) to generate augmented imagenet data 132 | 133 | ### Training 134 | 135 | ```python 136 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python train_imagenet.py --config-file resnet_configs/resnext50.yaml --data.train_dataset= --data.val_dataset= --data.num_workers=8 --logging.folder= --model.num_classes=100 (if imagenet 100) --training.distributed=1 --dist.world_size=5 137 | ``` 138 | 139 | ### Evaluation 140 | 141 | ```python 142 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python train_imagenet.py --config-file resnet_configs/rn18_88_epochs.yaml --data.train_dataset= --data.val_dataset= --data.num_workers=8 --training.path= --model.num_classes=1000 --training.distributed=1 --dist.world_size=5 --training.eval_only=1 143 | ``` 144 | 145 | ### Note 146 | 147 | 1. Since ImageNet-R and ObjectNet do not share all the classes with ImageNet1K, we use an additional `validation.imr` or `validation.obj` flag while evaluating on these datasets. 148 | 2. 
[create_imagenet_subset](utils/create_imagenet_subset.py) is used to create the random subset containing 100 classes. [mappings](utils/mappings/) contains the relevant `imagenet100.txt` file. 149 | 150 | ## Natural Distribution Shift Datasets 151 | 152 | We use (a) ImageNet-Sketch, (b) ImageNet-R, (c) ImageNet-V2, and (d) ObjectNet in our work. The users can navigate to their respective sources to download the data. 153 | 154 | 1. [rename_imagenetv2](utils/rename_imagenet_v2.py) renames the imagenetv2 folders that are original named based on indices 0-1000 to original imagenet folder names n0XXXXX. 155 | 2. [subset_objectnet_im](utils/subset_objectnet_im.py) is used to create a subset of ObjectNet classes that overlap with ImageNet-100/1000. 156 | 157 | 158 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lrti 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - ca-certificates=2022.10.11=h06a4308_0 10 | - certifi=2022.9.24=py39h06a4308_0 11 | - cuda=11.6.2=0 12 | - cuda-cccl=11.6.55=hf6102b2_0 13 | - cuda-command-line-tools=11.6.2=0 14 | - cuda-compiler=11.6.2=0 15 | - cuda-cudart=11.6.55=he381448_0 16 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 17 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 18 | - cuda-cupti=11.6.124=h86345e5_0 19 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 20 | - cuda-driver-dev=11.6.55=0 21 | - cuda-gdb=12.0.90=0 22 | - cuda-libraries=11.6.2=0 23 | - cuda-libraries-dev=11.6.2=0 24 | - cuda-memcheck=11.8.86=0 25 | - cuda-nsight=12.0.78=0 26 | - cuda-nsight-compute=12.0.0=0 27 | - cuda-nvcc=11.6.124=hbba6d2d_0 28 | - cuda-nvdisasm=12.0.76=0 29 | - cuda-nvml-dev=11.6.55=haa9ef22_0 30 | - cuda-nvprof=12.0.90=0 31 | - cuda-nvprune=11.6.124=he22ec0a_0 32 | - cuda-nvrtc=11.6.124=h020bade_0 33 | - cuda-nvrtc-dev=11.6.124=h249d397_0 34 | - cuda-nvtx=11.6.124=h0630a44_0 35 | - cuda-nvvp=12.0.90=0 36 | - cuda-runtime=11.6.2=0 37 | - cuda-samples=11.6.101=h8efea70_0 38 | - cuda-sanitizer-api=12.0.90=0 39 | - cuda-toolkit=11.6.2=0 40 | - cuda-tools=11.6.2=0 41 | - cuda-visual-tools=11.6.2=0 42 | - gds-tools=1.5.0.59=0 43 | - ld_impl_linux-64=2.38=h1181459_1 44 | - libcublas=12.0.1.189=0 45 | - libcublas-dev=12.0.1.189=0 46 | - libcufft=11.0.0.21=0 47 | - libcufft-dev=11.0.0.21=0 48 | - libcufile=1.5.0.59=0 49 | - libcufile-dev=1.5.0.59=0 50 | - libcurand=10.3.1.50=0 51 | - libcurand-dev=10.3.1.50=0 52 | - libcusolver=11.4.2.57=0 53 | - libcusolver-dev=11.4.2.57=0 54 | - libcusparse=12.0.0.76=0 55 | - libcusparse-dev=12.0.0.76=0 56 | - libffi=3.4.2=h6a678d5_6 57 | - libgcc-ng=11.2.0=h1234567_1 58 | - libgomp=11.2.0=h1234567_1 59 | - libnpp=12.0.0.30=0 60 | - libnpp-dev=12.0.0.30=0 61 | - libnvjpeg=12.0.0.28=0 62 | - libnvjpeg-dev=12.0.0.28=0 63 | - libstdcxx-ng=11.2.0=h1234567_1 64 | - ncurses=6.3=h5eee18b_3 65 | - nsight-compute=2022.4.0.15=0 66 | - openssl=1.1.1s=h7f8727e_0 67 | - pip=22.3.1=py39h06a4308_0 68 | - python=3.9.15=h7a1cb2a_2 69 | - readline=8.2=h5eee18b_0 70 | - setuptools=65.5.0=py39h06a4308_0 71 | - sqlite=3.40.0=h5082296_0 72 | - tk=8.6.12=h1ccaba5_0 73 | - tzdata=2022g=h04d1e81_0 74 | - wheel=0.37.1=pyhd3eb1b0_0 75 | - xz=5.2.8=h5eee18b_0 76 | - zlib=1.2.13=h5eee18b_0 77 | - pip: 78 | - accelerate==0.15.0 79 | - beautifulsoup4==4.11.1 80 | - charset-normalizer==2.1.1 81 | - clip==0.2.0 82 | - contourpy==1.0.6 83 | - cycler==0.11.0 84 | - filelock==3.8.2 
85 | - fonttools==4.38.0 86 | - ftfy==6.1.1 87 | - gdown==4.6.0 88 | - huggingface-hub==0.11.1 89 | - idna==3.4 90 | - joblib==1.2.0 91 | - kiwisolver==1.4.4 92 | - littleutils==0.2.2 93 | - matplotlib==3.6.2 94 | - numpy==1.23.5 95 | - ogb==1.3.5 96 | - outdated==0.2.2 97 | - packaging==22.0 98 | - pandas==1.5.2 99 | - pillow==9.3.0 100 | - psutil==5.9.4 101 | - pyparsing==3.0.9 102 | - pysocks==1.7.1 103 | - python-dateutil==2.8.2 104 | - pytz==2022.6 105 | - pyyaml==6.0 106 | - regex==2022.10.31 107 | - requests==2.28.1 108 | - scikit-learn==1.2.0 109 | - scipy==1.9.3 110 | - seaborn==0.12.2 111 | - six==1.16.0 112 | - soupsieve==2.3.2.post1 113 | - threadpoolctl==3.1.0 114 | - timm==0.3.4 115 | - tokenizers==0.13.2 116 | - tqdm==4.64.1 117 | - transformers==4.25.1 118 | - typing-extensions==4.4.0 119 | - urllib3==1.26.13 120 | - wcwidth==0.2.5 121 | - wilds==2.0.0 122 | -------------------------------------------------------------------------------- /figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/figure.png -------------------------------------------------------------------------------- /generation/DeepAugment/CAE_Model/__pycache__/cae_32x32x32_zero_pad_bin.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/CAE_Model/__pycache__/cae_32x32x32_zero_pad_bin.cpython-39.pyc -------------------------------------------------------------------------------- /generation/DeepAugment/CAE_Model/cae_32x32x32_zero_pad_bin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import random 6 | 7 | 8 | class CAE(nn.Module): 9 | """ 10 | This AE module will be fed 3x128x128 patches from the original image 11 | Shapes are (batch_size, channels, height, width) 12 | Latent representation: 32x32x32 bits per patch => 240KB per image (for 720p) 13 | """ 14 | 15 | def __init__(self): 16 | super(CAE, self).__init__() 17 | 18 | self.encoded = None 19 | 20 | # ENCODER 21 | 22 | # 64x64x64 23 | self.e_conv_1 = nn.Sequential( 24 | nn.ZeroPad2d((1, 2, 1, 2)), 25 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(5, 5), stride=(2, 2)), 26 | nn.LeakyReLU() 27 | ) 28 | 29 | # 128x32x32 30 | self.e_conv_2 = nn.Sequential( 31 | nn.ZeroPad2d((1, 2, 1, 2)), 32 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(5, 5), stride=(2, 2)), 33 | nn.LeakyReLU() 34 | ) 35 | 36 | # 128x32x32 37 | self.e_block_1 = nn.Sequential( 38 | nn.ZeroPad2d((1, 1, 1, 1)), 39 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 40 | nn.LeakyReLU(), 41 | 42 | nn.ZeroPad2d((1, 1, 1, 1)), 43 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 44 | ) 45 | 46 | # 128x32x32 47 | self.e_block_2 = nn.Sequential( 48 | nn.ZeroPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 50 | nn.LeakyReLU(), 51 | 52 | nn.ZeroPad2d((1, 1, 1, 1)), 53 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 54 | ) 55 | 56 | # 128x32x32 57 | self.e_block_3 = nn.Sequential( 58 | nn.ZeroPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(in_channels=128, out_channels=128, 
kernel_size=(3, 3), stride=(1, 1)), 60 | nn.LeakyReLU(), 61 | 62 | nn.ZeroPad2d((1, 1, 1, 1)), 63 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 64 | ) 65 | 66 | # 32x32x32 67 | self.e_conv_3 = nn.Sequential( 68 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)), 69 | nn.Tanh() 70 | ) 71 | 72 | # DECODER 73 | 74 | # a 75 | self.d_up_conv_1 = nn.Sequential( 76 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(1, 1)), 77 | nn.LeakyReLU(), 78 | 79 | nn.ZeroPad2d((1, 1, 1, 1)), 80 | nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=(2, 2), stride=(2, 2)) 81 | ) 82 | 83 | # 128x64x64 84 | self.d_block_1 = nn.Sequential( 85 | nn.ZeroPad2d((1, 1, 1, 1)), 86 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 87 | nn.LeakyReLU(), 88 | 89 | nn.ZeroPad2d((1, 1, 1, 1)), 90 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 91 | ) 92 | 93 | # 128x64x64 94 | self.d_block_2 = nn.Sequential( 95 | nn.ZeroPad2d((1, 1, 1, 1)), 96 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 97 | nn.LeakyReLU(), 98 | 99 | nn.ZeroPad2d((1, 1, 1, 1)), 100 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 101 | ) 102 | 103 | # 128x64x64 104 | self.d_block_3 = nn.Sequential( 105 | nn.ZeroPad2d((1, 1, 1, 1)), 106 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 107 | nn.LeakyReLU(), 108 | 109 | nn.ZeroPad2d((1, 1, 1, 1)), 110 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1)), 111 | ) 112 | 113 | # 256x128x128 114 | self.d_up_conv_2 = nn.Sequential( 115 | nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(3, 3), stride=(1, 1)), 116 | nn.LeakyReLU(), 117 | 118 | nn.ZeroPad2d((1, 1, 1, 1)), 119 | nn.ConvTranspose2d(in_channels=32, out_channels=256, kernel_size=(2, 2), stride=(2, 2)) 120 | ) 121 | 122 | # 3x128x128 123 | self.d_up_conv_3 = nn.Sequential( 124 | nn.Conv2d(in_channels=256, out_channels=16, kernel_size=(3, 3), stride=(1, 1)), 125 | nn.LeakyReLU(), 126 | 127 | nn.ReflectionPad2d((2, 2, 2, 2)), 128 | nn.Conv2d(in_channels=16, out_channels=3, kernel_size=(3, 3), stride=(1, 1)), 129 | nn.Tanh() 130 | ) 131 | 132 | def forward(self, x): 133 | ec1 = self.e_conv_1(x) 134 | ec2 = self.e_conv_2(ec1) 135 | eblock1 = self.e_block_1(ec2) + ec2 136 | eblock2 = self.e_block_2(eblock1) + eblock1 137 | eblock3 = self.e_block_3(eblock2) + eblock2 138 | ec3 = self.e_conv_3(eblock3) # in [-1, 1] from tanh activation 139 | option = np.random.choice(range(9)) 140 | if option == 1: 141 | # set some weights to zero 142 | H = ec3.size()[2] 143 | W = ec3.size()[3] 144 | mask = (torch.cuda.FloatTensor(H, W).uniform_() > 0.2).float().cuda() 145 | ec3 = ec3 * mask 146 | del mask 147 | elif option == 2: 148 | # negare some of the weights 149 | H = ec3.size()[2] 150 | W = ec3.size()[3] 151 | mask = (((torch.cuda.FloatTensor(H, W).uniform_() > 0.1).float() * 2) - 1).cuda() 152 | ec3 = ec3 * mask 153 | del mask 154 | elif option == 3: 155 | num_channels = 10 156 | perm = np.array(list(np.random.permutation(num_channels)) + list(range(num_channels, ec3.size()[1]))) 157 | ec3 = ec3[:, perm, :, :] 158 | elif option == 4: 159 | num_channels = ec3.shape[1] 160 | num_channels_transform = 5 161 | 162 | _k = random.randint(1,3) 163 | _dims = [0, 1, 2] 164 | random.shuffle(_dims) 165 | _dims = _dims[:2] 166 | 167 | for i in 
range(num_channels_transform): 168 | filter_select = random.choice(list(range(num_channels))) 169 | ec3[:,filter_select] = torch.flip(ec3[:,filter_select], dims=_dims) 170 | elif option == 5: 171 | num_channels = ec3.shape[1] 172 | num_channels_transform = num_channels 173 | 174 | _k = random.randint(1,3) 175 | _dims = [0, 1, 2] 176 | random.shuffle(_dims) 177 | _dims = _dims[:2] 178 | 179 | for i in range(num_channels_transform): 180 | if i == num_channels_transform / 2: 181 | _dims = [_dims[1], _dims[0]] 182 | ec3[:,i] = torch.flip(ec3[:,i], dims=_dims) 183 | elif option == 6: 184 | with torch.no_grad(): 185 | c, h, w = ec3.shape[1], ec3.shape[2], ec3.shape[3] 186 | z = torch.zeros(c, c, 3, 3).cuda() 187 | for j in range(z.size(0)): 188 | shift_x, shift_y = 1, 1# np.random.randint(3, size=(2,)) 189 | z[j,j,shift_x,shift_y] = 1 # np.random.choice([1.,-1.]) 190 | 191 | # Without this line, z would be the identity convolution 192 | z = z + ((torch.rand_like(z) - 0.5) * 0.2) 193 | ec3 = F.conv2d(ec3, z, padding=1) 194 | del z 195 | elif option == 7: 196 | with torch.no_grad(): 197 | c, h, w = ec3.shape[1], ec3.shape[2], ec3.shape[3] 198 | z = torch.zeros(c, c, 3, 3).cuda() 199 | for j in range(z.size(0)): 200 | shift_x, shift_y = 1, 1# np.random.randint(3, size=(2,)) 201 | z[j,j,shift_x,shift_y] = 1 # np.random.choice([1.,-1.]) 202 | 203 | if random.random() < 0.5: 204 | rand_layer = random.randint(0, c - 1) 205 | z[j, rand_layer, random.randint(-1, 1), random.randint(-1, 1)] = 1 206 | 207 | ec3 = F.conv2d(ec3, z, padding=1) 208 | del z 209 | elif option == 8: 210 | with torch.no_grad(): 211 | c, h, w = ec3.shape[1], ec3.shape[2], ec3.shape[3] 212 | z = torch.zeros(c, c, 3, 3).cuda() 213 | shift_x, shift_y = np.random.randint(3, size=(2,)) 214 | for j in range(z.size(0)): 215 | if random.random() < 0.2: 216 | shift_x, shift_y = np.random.randint(3, size=(2,)) 217 | 218 | z[j,j,shift_x,shift_y] = 1 # np.random.choice([1.,-1.]) 219 | 220 | # Without this line, z would be the identity convolution 221 | # z = z + ((torch.rand_like(z) - 0.5) * 0.2) 222 | ec3 = F.conv2d(ec3, z, padding=1) 223 | del z 224 | 225 | # stochastic binarization 226 | with torch.no_grad(): 227 | rand = torch.rand(ec3.shape).cuda() 228 | prob = (1 + ec3) / 2 229 | eps = torch.zeros(ec3.shape).cuda() 230 | eps[rand <= prob] = (1 - ec3)[rand <= prob] 231 | eps[rand > prob] = (-ec3 - 1)[rand > prob] 232 | 233 | # encoded tensor 234 | self.encoded = 0.5 * (ec3 + eps + 1) # (-1|1) -> (0|1) 235 | if option == 0: 236 | self.encoded = self.encoded *\ 237 | (3 + 2 * np.float32(np.random.uniform()) * (2*torch.rand_like(self.encoded-1))) 238 | return self.decode(self.encoded) 239 | 240 | def decode(self, encoded): 241 | y = encoded * 2.0 - 1 # (0|1) -> (-1|1) 242 | 243 | uc1 = self.d_up_conv_1(y) 244 | dblock1 = self.d_block_1(uc1) + uc1 245 | dblock2 = self.d_block_2(dblock1) + dblock1 246 | dblock3 = self.d_block_3(dblock2) + dblock2 247 | uc2 = self.d_up_conv_2(dblock3) 248 | dec = self.d_up_conv_3(uc2) 249 | 250 | return dec 251 | -------------------------------------------------------------------------------- /generation/DeepAugment/CAE_Weights/model_final.state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/CAE_Weights/model_final.state -------------------------------------------------------------------------------- /generation/DeepAugment/EDSR_Model/__init__.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = args.scale 15 | self.idx_scale = 0 16 | self.input_large = (args.model == 'VDSR') 17 | self.self_ensemble = args.self_ensemble 18 | self.chop = args.chop 19 | self.precision = args.precision 20 | self.cpu = args.cpu 21 | self.device = torch.device('cpu' if args.cpu else 'cuda') 22 | self.n_GPUs = args.n_GPUs 23 | self.save_models = args.save_models 24 | 25 | module = import_module('model.' + args.model.lower()) 26 | self.model = module.make_model(args).to(self.device) 27 | if args.precision == 'half': 28 | self.model.half() 29 | 30 | self.load( 31 | ckp.get_path('model'), 32 | pre_train=args.pre_train, 33 | resume=args.resume, 34 | cpu=args.cpu 35 | ) 36 | print(self.model, file=ckp.log_file) 37 | 38 | def forward(self, x, idx_scale): 39 | self.idx_scale = idx_scale 40 | if hasattr(self.model, 'set_scale'): 41 | self.model.set_scale(idx_scale) 42 | 43 | if self.training: 44 | if self.n_GPUs > 1: 45 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 46 | else: 47 | return self.model(x) 48 | else: 49 | if self.chop: 50 | forward_function = self.forward_chop 51 | else: 52 | forward_function = self.model.forward 53 | 54 | if self.self_ensemble: 55 | return self.forward_x8(x, forward_function=forward_function) 56 | else: 57 | return forward_function(x) 58 | 59 | def save(self, apath, epoch, is_best=False): 60 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 61 | 62 | if is_best: 63 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 64 | if self.save_models: 65 | save_dirs.append( 66 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 67 | ) 68 | 69 | for s in save_dirs: 70 | torch.save(self.model.state_dict(), s) 71 | 72 | def load(self, apath, pre_train='', resume=-1, cpu=False): 73 | load_from = None 74 | kwargs = {} 75 | if cpu: 76 | kwargs = {'map_location': lambda storage, loc: storage} 77 | 78 | if resume == -1: 79 | load_from = torch.load( 80 | os.path.join(apath, 'model_latest.pt'), 81 | **kwargs 82 | ) 83 | elif resume == 0: 84 | if pre_train == 'download': 85 | print('Download the model') 86 | dir_model = os.path.join('..', 'models') 87 | os.makedirs(dir_model, exist_ok=True) 88 | load_from = torch.utils.model_zoo.load_url( 89 | self.model.url, 90 | model_dir=dir_model, 91 | **kwargs 92 | ) 93 | elif pre_train: 94 | print('Load the model from {}'.format(pre_train)) 95 | load_from = torch.load(pre_train, **kwargs) 96 | else: 97 | load_from = torch.load( 98 | os.path.join(apath, 'model_{}.pt'.format(resume)), 99 | **kwargs 100 | ) 101 | 102 | if load_from: 103 | self.model.load_state_dict(load_from, strict=False) 104 | 105 | def forward_chop(self, *args, shave=10, min_size=160000): 106 | scale = 1 if self.input_large else self.scale[self.idx_scale] 107 | n_GPUs = min(self.n_GPUs, 4) 108 | # height, width 109 | h, w = args[0].size()[-2:] 110 | 111 | top = slice(0, h//2 + shave) 112 | bottom = slice(h - h//2 - shave, h) 113 | left = slice(0, w//2 + shave) 114 | right = slice(w - w//2 - shave, w) 115 | x_chops = [torch.cat([ 116 | a[..., top, left], 117 | a[..., top, right], 118 | a[..., bottom, left], 119 | a[..., bottom, right] 120 | ]) for a in args] 121 | 122 | y_chops 
= [] 123 | if h * w < 4 * min_size: 124 | for i in range(0, 4, n_GPUs): 125 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 126 | y = P.data_parallel(self.model, *x, range(n_GPUs)) 127 | if not isinstance(y, list): y = [y] 128 | if not y_chops: 129 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 130 | else: 131 | for y_chop, _y in zip(y_chops, y): 132 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 133 | else: 134 | for p in zip(*x_chops): 135 | y = self.forward_chop(*p, shave=shave, min_size=min_size) 136 | if not isinstance(y, list): y = [y] 137 | if not y_chops: 138 | y_chops = [[_y] for _y in y] 139 | else: 140 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y) 141 | 142 | h *= scale 143 | w *= scale 144 | top = slice(0, h//2) 145 | bottom = slice(h - h//2, h) 146 | bottom_r = slice(h//2 - h, None) 147 | left = slice(0, w//2) 148 | right = slice(w - w//2, w) 149 | right_r = slice(w//2 - w, None) 150 | 151 | # batch size, number of color channels 152 | b, c = y_chops[0][0].size()[:-2] 153 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 154 | for y_chop, _y in zip(y_chops, y): 155 | _y[..., top, left] = y_chop[0][..., top, left] 156 | _y[..., top, right] = y_chop[1][..., top, right_r] 157 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 158 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 159 | 160 | if len(y) == 1: y = y[0] 161 | 162 | return y 163 | 164 | def forward_x8(self, *args, forward_function=None): 165 | def _transform(v, op): 166 | if self.precision != 'single': v = v.float() 167 | 168 | v2np = v.data.cpu().numpy() 169 | if op == 'v': 170 | tfnp = v2np[:, :, :, ::-1].copy() 171 | elif op == 'h': 172 | tfnp = v2np[:, :, ::-1, :].copy() 173 | elif op == 't': 174 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 175 | 176 | ret = torch.Tensor(tfnp).to(self.device) 177 | if self.precision == 'half': ret = ret.half() 178 | 179 | return ret 180 | 181 | list_x = [] 182 | for a in args: 183 | x = [a] 184 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 185 | 186 | list_x.append(x) 187 | 188 | list_y = [] 189 | for x in zip(*list_x): 190 | y = forward_function(*x) 191 | if not isinstance(y, list): y = [y] 192 | if not list_y: 193 | list_y = [[_y] for _y in y] 194 | else: 195 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 196 | 197 | for _list_y in list_y: 198 | for i in range(len(_list_y)): 199 | if i > 3: 200 | _list_y[i] = _transform(_list_y[i], 't') 201 | if i % 4 > 1: 202 | _list_y[i] = _transform(_list_y[i], 'h') 203 | if (i % 4) % 2 == 1: 204 | _list_y[i] = _transform(_list_y[i], 'v') 205 | 206 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 207 | if len(y) == 1: y = y[0] 208 | 209 | return y 210 | -------------------------------------------------------------------------------- /generation/DeepAugment/EDSR_Model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2), bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / 
std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /generation/DeepAugment/EDSR_Model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 12 | } 13 | 14 | def make_model(args, parent=False): 15 | return EDSR(args) 16 | 17 | class EDSR(nn.Module): 18 | def __init__(self, args, conv=common.default_conv): 19 | super(EDSR, self).__init__() 20 | 21 | n_resblocks = args.n_resblocks 22 | n_feats = args.n_feats 23 | kernel_size = 3 24 | scale = args.scale[0] 25 | act = nn.ReLU(True) 26 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 27 | if url_name in url: 28 | self.url = url[url_name] 29 | else: 30 | self.url = None 31 | self.sub_mean = common.MeanShift(args.rgb_range) 32 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 33 | 34 | # define head module 35 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 36 | 37 | # define body module 38 | m_body = [ 39 | common.ResBlock( 40 | conv, n_feats, 
kernel_size, act=act, res_scale=args.res_scale 41 | ) for _ in range(n_resblocks) 42 | ] 43 | m_body.append(conv(n_feats, n_feats, kernel_size)) 44 | 45 | # define tail module 46 | m_tail = [ 47 | common.Upsampler(conv, scale, n_feats, act=False), 48 | conv(n_feats, args.n_colors, kernel_size) 49 | ] 50 | 51 | self.head = nn.Sequential(*m_head) 52 | self.body = nn.Sequential(*m_body) 53 | self.tail = nn.Sequential(*m_tail) 54 | 55 | def forward(self, x): 56 | x = self.sub_mean(x) 57 | x = self.head(x) 58 | 59 | res = self.body(x) 60 | res += x 61 | 62 | x = self.tail(res) 63 | x = self.add_mean(x) 64 | 65 | return x 66 | 67 | def load_state_dict(self, state_dict, strict=True): 68 | own_state = self.state_dict() 69 | for name, param in state_dict.items(): 70 | if name in own_state: 71 | if isinstance(param, nn.Parameter): 72 | param = param.data 73 | try: 74 | own_state[name].copy_(param) 75 | except Exception: 76 | if name.find('tail') == -1: 77 | raise RuntimeError('While copying the parameter named {}, ' 78 | 'whose dimensions in the model are {} and ' 79 | 'whose dimensions in the checkpoint are {}.' 80 | .format(name, own_state[name].size(), param.size())) 81 | elif strict: 82 | if name.find('tail') == -1: 83 | raise KeyError('unexpected key "{}" in state_dict' 84 | .format(name)) 85 | 86 | -------------------------------------------------------------------------------- /generation/DeepAugment/EDSR_Weights/edsr_baseline_x4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/EDSR_Weights/edsr_baseline_x4.pt -------------------------------------------------------------------------------- /generation/DeepAugment/README.md: -------------------------------------------------------------------------------- 1 | # DeepAugment 2 | 3 | 4 | 5 | ## DeepAugment Files 6 | 7 | Here is an overview of the files needed to create ImageNet training data augmented with DeepAugment. 8 | Alternatively, you can download our [EDSR](https://drive.google.com/file/d/1Ij_D3LuHWI4_WOlsg6dJMAPKEu10g47_/view?usp=sharing) and [CAE](https://drive.google.com/file/d/1xN9Z7pZ2GNwRww7j8ClPnPNwOc12VsM5/view?usp=sharing) images directly. 9 | 10 | ## Create Datasets 11 | - `EDSR_distort_imagenet.py`: Creates a distorted version of ImageNet using EDSR (https://arxiv.org/abs/1707.02921) 12 | - `CAE_distort_imagenet.py`: Creates a distorted version of ImageNet using a CAE (https://arxiv.org/abs/1703.00395) 13 | 14 | The above scripts can be run in parallel to speed up dataset creation (multiple workers processing different classes). For example, to split dataset creation across 5 processes, run the following in parallel: 15 | ```bash 16 | CUDA_VISIBLE_DEVICES=0 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=0 17 | CUDA_VISIBLE_DEVICES=1 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=1 18 | CUDA_VISIBLE_DEVICES=2 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=2 19 | CUDA_VISIBLE_DEVICES=3 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=3 20 | CUDA_VISIBLE_DEVICES=4 python3 EDSR_distort_imagenet.py --total-workers=5 --worker-number=4 21 | ``` 22 | You will need to change the save path and original ImageNet train set path. 23 | 24 | ## DeepAugment with Noise2Net 25 | 26 | In addition to EDSR and CAE, the DeepAugment approach works with randomly sampled architectures. 
We call an example [Noise2Net](https://openreview.net/pdf?id=o20_NVA92tK#page=20), which can generate augmentations in memory and in parallel. See the code above. 27 | 28 | [ResNet-50 + DeepAugment (Noise2Net)](https://drive.google.com/file/d/1TlylFp-hJ5ENUNdjGBf-tLPoIIcJyPG4/view?usp=sharing) 29 | 30 | [ResNeXt-101 + AugMix + DeepAugment (Noise2Net, EDSR, CAE)](https://drive.google.com/file/d/1dJoYMma4uZwT-wLzk4Teh8dKqCVz7Tlx/view?usp=sharing) 31 | -------------------------------------------------------------------------------- /generation/DeepAugment/deepaugment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/DeepAugment/deepaugment.png -------------------------------------------------------------------------------- /generation/DeepAugment/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/pytorch.sh 4 | 5 | python train.py --dist-url 'tcp://127.0.0.1:32767' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 /data/imagenet 6 | -------------------------------------------------------------------------------- /generation/__pycache__/conditional_ldm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/conditional_ldm.cpython-39.pyc -------------------------------------------------------------------------------- /generation/__pycache__/data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/data.cpython-39.pyc -------------------------------------------------------------------------------- /generation/__pycache__/generate_images.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/generate_images.cpython-39.pyc -------------------------------------------------------------------------------- /generation/__pycache__/generate_images_captions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/generate_images_captions.cpython-39.pyc -------------------------------------------------------------------------------- /generation/__pycache__/generate_images_i2i.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/__pycache__/generate_images_i2i.cpython-39.pyc -------------------------------------------------------------------------------- /generation/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 
11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /generation/conditional_ldm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | from PIL import Image 8 | from data import CondLDM 9 | from einops import rearrange 10 | from omegaconf import OmegaConf 11 | from taming.models import vqgan 12 | from accelerate import Accelerator 13 | from ldm.util import instantiate_from_config 14 | from ldm.models.diffusion.ddim import DDIMSampler 15 | 16 | os.environ["NCCL_P2P_DISABLE"] = "1" 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument("--batch_size", type = int, default = 24) 21 | parser.add_argument("--split", type = str, default = "test", help = "Path to eval test data") 22 | parser.add_argument("--config", type = str, default = None, help = "Path to config file") 23 | parser.add_argument("--checkpoint", type = str, default = None, help = "Path to checkpoint") 24 | parser.add_argument("--save_image_size", type = int, default = 64) 25 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path saved generated images") 26 | 27 | ''' 28 | accelerate launch --num_cpu_threads_per_process 8 --main_process_port 9876 conditional_ldm.py --config cin256-v2.yaml --checkpoint /home/data/ckpts/hbansal/ldm/model.ckpt --save_image_gen /home/data/datasets/ImageNet100/CondGeneration/ 29 | ''' 30 | args = parser.parse_args() 31 | 32 | accelerator = Accelerator() 33 | os.makedirs(os.path.join(args.save_image_gen, args.split), exist_ok = True) 34 | 35 | def load_model_from_config(config, ckpt): 36 | print(f"Loading model from {ckpt}") 37 | pl_sd = torch.load(ckpt)#, map_location="cpu") 38 | sd = pl_sd["state_dict"] 39 | model = instantiate_from_config(config.model) 40 | m, u = model.load_state_dict(sd, strict=False) 41 | return model 42 | 43 | 44 | def get_model(): 45 | config = OmegaConf.load(args.config) 46 | model = load_model_from_config(config, args.checkpoint) 47 | return model 48 | 49 | def generate_images(model, sampler, dataloader, args): 50 | 51 | model, sampler, dataloader = accelerator.prepare(model, sampler, dataloader) 52 | model = model.to(accelerator.device) 53 | model.eval() 54 | 55 | ## Hyperparameters 56 | ddim_steps = 20 57 | 
ddim_eta = 0.0 58 | scale = 3.0 # for unconditional guidance 59 | 60 | model = model.module 61 | with torch.no_grad(): 62 | with model.ema_scope(): 63 | for class_indices, class_labels, folder_names in tqdm(dataloader): 64 | for folder_name in folder_names: 65 | os.makedirs(os.path.join(args.save_image_gen, args.split, folder_name), exist_ok = True) 66 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, args.split, folder_names[x], str(class_indices[x].item()) + ".png")), range(len(class_indices)))) 67 | if len(indices) == 0: continue 68 | class_indices = [class_indices[i] for i in indices] 69 | class_labels = [class_labels[i] for i in indices] 70 | folder_names = [folder_names[i] for i in indices] 71 | class_labels = torch.tensor(class_labels) 72 | uc = model.get_learned_conditioning( 73 | {model.cond_stage_key: torch.tensor(len(indices)*[1000]).to(model.device)} 74 | ) 75 | c = model.get_learned_conditioning({model.cond_stage_key: class_labels.to(model.device)}) 76 | print(len(indices)) 77 | samples_ddim, _ = sampler.sample(S=ddim_steps, 78 | conditioning=c, 79 | batch_size=len(indices), 80 | shape=[3, args.save_image_size, args.save_image_size], 81 | verbose=False, 82 | unconditional_guidance_scale=scale, 83 | unconditional_conditioning=uc, 84 | eta=ddim_eta) 85 | x_samples_ddim = model.decode_first_stage(samples_ddim) 86 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 87 | samples = 255. * rearrange(x_samples_ddim, 'b c h w -> b h w c').cpu().numpy() 88 | for index in range(len(samples)): 89 | im = Image.fromarray(samples[index].astype(np.uint8)) 90 | im.save(os.path.join(args.save_image_gen, args.split, folder_names[index], str(class_indices[index].item()) + ".png")) 91 | def main(): 92 | model = get_model() 93 | sampler = DDIMSampler(model) 94 | 95 | dataset = CondLDM(num_images_per_class = 1300 if args.split == 'train' else 50) 96 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size) 97 | 98 | generate_images(model, sampler, dataloader, args) 99 | 100 | if __name__ == "__main__": 101 | main() -------------------------------------------------------------------------------- /generation/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import glob 4 | import torch 5 | import torchvision 6 | import requests 7 | import random 8 | import pickle 9 | import numpy as np 10 | import pandas as pd 11 | from io import BytesIO 12 | from PIL import Image, ImageFile 13 | from torchvision import transforms 14 | from collections import defaultdict 15 | from torchvision.datasets import VisionDataset 16 | from torch.utils.data import Dataset, DataLoader 17 | 18 | 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | 21 | def preprocess_image(image): 22 | w, h = image.size 23 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 24 | image = image.resize((512, 512), resample=PIL.Image.LANCZOS) 25 | image = np.array(image).astype(np.float32) / 255.0 26 | image = image[None].transpose(0, 3, 1, 2) 27 | image = torch.from_numpy(image) 28 | return 2.0 * image - 1.0 29 | 30 | class CondLDM(Dataset): 31 | 32 | def __init__(self, num_images_per_class = 1300): 33 | 34 | im100_folders = ['n02869837', 'n01749939', 'n02488291', 'n02107142', 'n13037406', 'n02091831', 'n04517823', 'n04589890', 'n03062245', 'n01773797', 'n01735189', 'n07831146', 'n07753275', 'n03085013', 'n04485082', 'n02105505', 'n01983481', 'n02788148', 'n03530642', 
'n04435653', 'n02086910', 'n02859443', 'n13040303', 'n03594734', 'n02085620', 'n02099849', 'n01558993', 'n04493381', 'n02109047', 'n04111531', 'n02877765', 'n04429376', 'n02009229', 'n01978455', 'n02106550', 'n01820546', 'n01692333', 'n07714571', 'n02974003', 'n02114855', 'n03785016', 'n03764736', 'n03775546', 'n02087046', 'n07836838', 'n04099969', 'n04592741', 'n03891251', 'n02701002', 'n03379051', 'n02259212', 'n07715103', 'n03947888', 'n04026417', 'n02326432', 'n03637318', 'n01980166', 'n02113799', 'n02086240', 'n03903868', 'n02483362', 'n04127249', 'n02089973', 'n03017168', 'n02093428', 'n02804414', 'n02396427', 'n04418357', 'n02172182', 'n01729322', 'n02113978', 'n03787032', 'n02089867', 'n02119022', 'n03777754', 'n04238763', 'n02231487', 'n03032252', 'n02138441', 'n02104029', 'n03837869', 'n03494278', 'n04136333', 'n03794056', 'n03492542', 'n02018207', 'n04067472', 'n03930630', 'n03584829', 'n02123045', 'n04229816', 'n02100583', 'n03642806', 'n04336792', 'n03259280', 'n02116738', 'n02108089', 'n03424325', 'n01855672', 'n02090622'] 35 | im100_classes = [452, 64, 374, 236, 993, 176, 882, 904, 503, 74, 57, 959, 953, 508, 872, 228, 122, 421, 599, 858, 157, 449, 994, 608, 151, 209, 15, 876, 246, 766, 455, 857, 131, 119, 234, 90, 45, 936, 479, 272, 665, 653, 659, 158, 960, 765, 908, 703, 407, 560, 317, 938, 724, 748, 331, 619, 120, 267, 155, 708, 368, 772, 167, 494, 180, 431, 342, 854, 305, 54, 268, 667, 166, 277, 662, 798, 313, 498, 299, 222, 682, 593, 775, 674, 592, 137, 758, 717, 606, 281, 796, 211, 620, 830, 544, 275, 242, 570, 99, 169] 36 | 37 | self.folders = im100_folders * num_images_per_class 38 | self.classes = im100_classes * num_images_per_class 39 | 40 | def __len__(self): 41 | return len(self.classes) 42 | 43 | def __getitem__(self, idx): 44 | return idx, self.classes[idx], self.folders[idx] 45 | 46 | class PromptDataset(Dataset): 47 | 48 | def __init__(self, root, split = 'train', dataset = 'imagenet', diversity = True, i2i = False, transform = None): 49 | 50 | self.root = root 51 | self.split = split 52 | config = eval(open(os.path.join(self.root, 'classes.py'), 'r').read()) 53 | self.diversity = diversity 54 | self.templates = config["templates"] if self.diversity else lambda x: f'a photo of a {x}' 55 | self.folders = os.listdir(os.path.join(self.root, self.split)) 56 | if dataset == 'imagenet': 57 | df = pd.read_csv(os.path.join(self.root, 'folder_to_class.csv')) 58 | df = df[['folder', 'class']] 59 | self.folder_to_class = df.set_index('folder').T.to_dict('list') 60 | elif dataset == 'cifar10': 61 | # folder names in cifar are class names 62 | self.folder_to_class = {k:k for k in self.folders} 63 | 64 | self.images = [] 65 | self.classes = [] 66 | for folder in self.folders: 67 | class_images = os.listdir(os.path.join(self.root, split, folder)) 68 | class_images = list(map(lambda x: os.path.join(split, folder, x), class_images)) 69 | self.images = self.images + class_images 70 | self.classes = self.classes + ([self.folder_to_class[folder]] * len(class_images)) 71 | 72 | self.i2i = i2i 73 | self.transform = transform 74 | 75 | def __len__(self): 76 | return len(self.images) 77 | 78 | def __getitem__(self, idx): 79 | if self.diversity: 80 | index = random.randint(0, len(self.templates) - 1) 81 | caption = self.templates[index](self.classes[idx][0]) 82 | else: 83 | caption = self.templates(self.classes[idx][0]) 84 | 85 | image_location = self.images[idx] 86 | 87 | if self.i2i: 88 | image = Image.open(os.path.join(self.root, image_location)).convert("RGB") 89 | return 
self.transform(image), image_location, caption, self.classes[idx][0] 90 | else: 91 | return image_location, caption, self.classes[idx][0] 92 | 93 | 94 | class ImageDataset(Dataset): 95 | def __init__(self, root, transform, split = 'train'): 96 | 97 | self.root = root 98 | self.transform = transform 99 | config = eval(open(os.path.join(self.root, 'classes.py'), 'r').read()) 100 | self.templates = config["templates"] 101 | self.folders = os.listdir(os.path.join(self.root, split)) 102 | self.images = [] 103 | for folder in self.folders: 104 | class_images = os.listdir(os.path.join(self.root, split, folder)) 105 | class_images = list(map(lambda x: os.path.join(split, folder, x), class_images)) 106 | self.images = self.images + class_images 107 | 108 | def __len__(self): 109 | return len(self.images) 110 | 111 | def __getitem__(self, idx): 112 | image = self.transform(Image.open(os.path.join(self.root, self.images[idx])).convert('RGB')) 113 | return self.images[idx], image -------------------------------------------------------------------------------- /generation/generate_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import torch 4 | import argparse 5 | from tqdm import tqdm 6 | from PIL import Image 7 | from .data import ImageDataset 8 | from torchvision import transforms 9 | from accelerate import Accelerator 10 | from diffusers import StableDiffusionImageVariationPipeline 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument("--batch_size", type = int, default = 2) 15 | parser.add_argument("--split", type = str, default = "train", help = "Path to eval test data") 16 | parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to eval test data") 17 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path saved generated images") 18 | 19 | args = parser.parse_args() 20 | 21 | accelerator = Accelerator() 22 | os.makedirs(args.save_image_gen, exist_ok = True) 23 | 24 | def generate_images(pipe, dataloader, args): 25 | pipe, dataloader = accelerator.prepare(pipe, dataloader) 26 | pipe = pipe.to(accelerator.device) 27 | filename = os.path.join(args.save_image_gen, 'images_variation.csv') 28 | with torch.no_grad(): 29 | for image_locations, original_images in tqdm(dataloader): 30 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, image_locations[x])), range(len(image_locations)))) 31 | if len(indices) == 0: 32 | continue 33 | original_images = original_images[indices] 34 | image_locations = [image_locations[i] for i in indices] 35 | images = pipe(original_images, guidance_scale = 3).images 36 | for index in range(len(images)): 37 | os.makedirs(os.path.join(args.save_image_gen, os.path.dirname(image_locations[index])), exist_ok = True) 38 | images[index].save(os.path.join(args.save_image_gen, image_locations[index])) 39 | 40 | def main(): 41 | model_name_path = "lambdalabs/sd-image-variations-diffusers" 42 | pipe = StableDiffusionImageVariationPipeline.from_pretrained(model_name_path, revision = "v2.0") 43 | 44 | tform = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Resize( 47 | (224, 224), 48 | interpolation=transforms.InterpolationMode.BICUBIC, 49 | antialias=False, 50 | ), 51 | transforms.Normalize( 52 | [0.48145466, 0.4578275, 0.40821073], 53 | [0.26862954, 0.26130258, 0.27577711]), 54 | ]) 55 | 56 | dataset = ImageDataset(args.data_dir, tform, split = args.split) 57 | 
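    # shuffle=False keeps the image order deterministic; since generate_images() skips files
    # that already exist under save_image_gen, an interrupted run can simply be restarted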
dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = False) 58 | generate_images(pipe, dataloader, args) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() -------------------------------------------------------------------------------- /generation/generate_images_captions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import torch 4 | import argparse 5 | import shutil 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from PIL import Image 9 | from .data import PromptDataset 10 | from torchvision import transforms 11 | from accelerate import Accelerator 12 | from diffusers import StableDiffusionPipeline 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument("--batch_size", type = int, default = 2) 17 | parser.add_argument("--split", type = str, default = "train", help = "Path to eval test data") 18 | parser.add_argument("--dataset", type = str, default = "imagenet", help = "dataset name") 19 | parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to eval test data") 20 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path saved generated images") 21 | parser.add_argument('--save_real', action='store_true', help='save real or not') 22 | parser.add_argument('--diversity', action='store_true', help='diverse captions or not') 23 | 24 | args = parser.parse_args() 25 | accelerator = Accelerator() 26 | os.makedirs(args.save_image_gen, exist_ok = True) 27 | 28 | filename = os.path.join(args.save_image_gen, 'train_captions.csv') 29 | def generate_images(pipe, dataloader, args): 30 | pipe, dataloader = accelerator.prepare(pipe, dataloader) 31 | pipe = pipe.to(accelerator.device) 32 | with torch.no_grad(): 33 | for image_locations, captions, labels in tqdm(dataloader): 34 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, image_locations[x])), range(len(image_locations)))) 35 | if len(indices) == 0: 36 | continue 37 | image_locations = [image_locations[i] for i in indices] 38 | captions = [captions[i] for i in indices] 39 | labels = [labels[i] for i in indices] 40 | images = pipe(captions).images 41 | for index in range(len(images)): 42 | os.makedirs(os.path.join(args.save_image_gen, os.path.dirname(image_locations[index])), exist_ok = True) 43 | path = os.path.join(args.save_image_gen, image_locations[index]) 44 | images[index].save(path) 45 | with open(filename, 'a') as csvfile: 46 | csvwriter = csv.writer(csvfile) 47 | csvwriter.writerow([path, labels[index], captions[index]]) 48 | if args.save_real: 49 | real_img_path = os.path.join(args.save_image_gen, 'real') 50 | os.makedirs(real_img_path, exist_ok = True) 51 | shutil.copy(os.path.join(args.data_dir, image_locations[index]), real_img_path) 52 | 53 | 54 | def main(): 55 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) 56 | dataset = PromptDataset(args.data_dir, split = args.split, dataset = args.dataset, diversity = args.diversity) 57 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = False) 58 | generate_images(pipe, dataloader, args) 59 | 60 | if __name__ == "__main__": 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /generation/generate_images_i2i.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import csv 3 | import torch 4 | import shutil 5 | import argparse 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | from PIL import Image 10 | from .data import PromptDataset 11 | from torchvision import transforms 12 | from accelerate import Accelerator 13 | from diffusers import StableDiffusionImg2ImgPipeline 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("--batch_size", type = int, default = 2) 18 | parser.add_argument("--split", type = str, default = "train", help = "Path to eval test data") 19 | parser.add_argument("--dataset", type = str, default = "imagenet", help = "dataset name") 20 | parser.add_argument("--data_dir", type = str, default = "/home/data/ImageNet1K/validation", help = "Path to eval test data") 21 | parser.add_argument("--save_image_gen", type = str, default = None, help = "Path saved generated images") 22 | parser.add_argument("--save_real", type = str, default = None, help = "save real data") 23 | parser.add_argument('--diversity', action='store_true', help='diverse captions or not') 24 | 25 | args = parser.parse_args() 26 | accelerator = Accelerator() 27 | os.makedirs(args.save_image_gen, exist_ok = True) 28 | 29 | def preprocess(image): 30 | image = image.resize((512, 512), resample=Image.LANCZOS) 31 | image = np.array(image).astype(np.uint8) 32 | image = (image / 127.5 - 1.0).astype(np.float32) 33 | image = torch.from_numpy(image).permute(2, 0, 1) 34 | return image 35 | 36 | def generate_images(pipe, dataloader, args): 37 | pipe, dataloader = accelerator.prepare(pipe, dataloader) 38 | pipe = pipe.to(accelerator.device) 39 | filename = os.path.join(args.save_image_gen, 'i2i.csv') 40 | with torch.no_grad(): 41 | for original_images, image_locations, captions, labels in tqdm(dataloader): 42 | indices = list(filter(lambda x: not os.path.exists(os.path.join(args.save_image_gen, image_locations[x])), range(len(image_locations)))) 43 | if len(indices) == 0: 44 | continue 45 | original_images, image_locations, captions, labels = map(lambda x: [x[i] for i in indices], (original_images, image_locations, captions, labels)) 46 | original_images = torch.stack(original_images).to(accelerator.device) 47 | images = pipe(prompt = captions, image = original_images, strength = 1).images 48 | for index in range(len(images)): 49 | os.makedirs(os.path.join(args.save_image_gen, os.path.dirname(image_locations[index])), exist_ok = True) 50 | path = os.path.join(args.save_image_gen, image_locations[index]) 51 | images[index].save(path) 52 | with open(filename, 'a') as csvfile: 53 | csvwriter = csv.writer(csvfile) 54 | csvwriter.writerow([path, labels[index], captions[index]]) 55 | if args.save_real: 56 | real_img_path = os.path.join(args.save_image_gen, 'real') 57 | os.makedirs(real_img_path, exist_ok = True) 58 | shutil.copy(os.path.join(args.data_dir, image_locations[index]), real_img_path) 59 | 60 | 61 | def main(): 62 | model_id_or_path = "runwayml/stable-diffusion-v1-5" 63 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) 64 | dataset = PromptDataset(args.data_dir, split = args.split, dataset = args.dataset, diversity = args.diversity, i2i = True, transform = preprocess) 65 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = False) 66 | generate_images(pipe, dataloader, args) 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- 
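Taken together, the three generation scripts above follow one pattern: build a dataset over the real image folder tree, skip any image whose generated counterpart already exists, run a Stable Diffusion pipeline batch by batch, and mirror the input folder structure under --save_image_gen (the caption and image-to-image scripts also append (path, label, caption) rows to a CSV). They are written to be launched with Accelerate over a whole split; the sketch below is a minimal single-batch version of the image-to-image path, intended only to show the data flow, not the distributed loop. It assumes the repository root is on PYTHONPATH so that generation.data is importable, that the data root contains classes.py and folder_to_class.csv next to the split folders (PromptDataset reads both), and that a CUDA GPU is available for the fp16 weights; the data root, batch size, and output filename below are placeholders.

import torch
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from diffusers import StableDiffusionImg2ImgPipeline
from generation.data import PromptDataset  # the data.py shown above

def preprocess(image):
    # same preprocessing as generate_images_i2i.py: 512x512 RGB scaled to [-1, 1]
    image = image.resize((512, 512), resample=Image.LANCZOS)
    image = np.array(image).astype(np.uint8)
    image = (image / 127.5 - 1.0).astype(np.float32)
    return torch.from_numpy(image).permute(2, 0, 1)

# placeholder data root; expects <root>/classes.py, <root>/folder_to_class.csv
# and a <root>/<split>/<class folder>/ image layout
dataset = PromptDataset("/home/data/ImageNet1K", split="train", dataset="imagenet",
                        diversity=True, i2i=True, transform=preprocess)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# one batch: (preprocessed images, relative paths, captions, class labels)
images, locations, captions, labels = next(iter(loader))
with torch.no_grad():
    generated = pipe(prompt=list(captions), image=images.to("cuda"), strength=1).images
generated[0].save("i2i_example.png")  # placeholder output path

A full run would instead loop over the whole loader and save each generated image back to its relative path under --save_image_gen, as generate_images_i2i.py does above.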
/generation/ldm/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/data/__init__.py -------------------------------------------------------------------------------- /generation/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /generation/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class 
LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /generation/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /generation/ldm/models/__pycache__/autoencoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/__pycache__/autoencoder.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | 
model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | 
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util 
import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class DDIMSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | self.register_buffer('betas', to_torch(self.model.betas)) 32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 34 | 35 | # calculations for diffusion q(x_t | x_{t-1}) and others 36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 41 | 42 | # ddim sampling parameters 43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 44 | ddim_timesteps=self.ddim_timesteps, 45 | eta=ddim_eta,verbose=verbose) 46 | self.register_buffer('ddim_sigmas', ddim_sigmas) 47 | self.register_buffer('ddim_alphas', ddim_alphas) 48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 54 | 55 | @torch.no_grad() 56 | def sample(self, 57 | S, 58 | batch_size, 59 | shape, 60 | conditioning=None, 61 | callback=None, 62 | normals_sequence=None, 63 | img_callback=None, 64 | quantize_x0=False, 65 | eta=0., 66 | mask=None, 67 | x0=None, 68 | temperature=1., 69 | noise_dropout=0., 70 | score_corrector=None, 71 | corrector_kwargs=None, 72 | verbose=True, 73 | x_T=None, 74 | log_every_t=100, 75 | unconditional_guidance_scale=1., 76 | unconditional_conditioning=None, 77 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
78 | **kwargs 79 | ): 80 | if conditioning is not None: 81 | if isinstance(conditioning, dict): 82 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 83 | if cbs != batch_size: 84 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 85 | else: 86 | if conditioning.shape[0] != batch_size: 87 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 88 | 89 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 90 | # sampling 91 | C, H, W = shape 92 | size = (batch_size, C, H, W) 93 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 94 | 95 | samples, intermediates = self.ddim_sampling(conditioning, size, 96 | callback=callback, 97 | img_callback=img_callback, 98 | quantize_denoised=quantize_x0, 99 | mask=mask, x0=x0, 100 | ddim_use_original_steps=False, 101 | noise_dropout=noise_dropout, 102 | temperature=temperature, 103 | score_corrector=score_corrector, 104 | corrector_kwargs=corrector_kwargs, 105 | x_T=x_T, 106 | log_every_t=log_every_t, 107 | unconditional_guidance_scale=unconditional_guidance_scale, 108 | unconditional_conditioning=unconditional_conditioning, 109 | ) 110 | return samples, intermediates 111 | 112 | @torch.no_grad() 113 | def ddim_sampling(self, cond, shape, 114 | x_T=None, ddim_use_original_steps=False, 115 | callback=None, timesteps=None, quantize_denoised=False, 116 | mask=None, x0=None, img_callback=None, log_every_t=100, 117 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 118 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 119 | device = self.model.betas.device 120 | b = shape[0] 121 | if x_T is None: 122 | img = torch.randn(shape, device=device) 123 | else: 124 | img = x_T 125 | 126 | if timesteps is None: 127 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 128 | elif timesteps is not None and not ddim_use_original_steps: 129 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 130 | timesteps = self.ddim_timesteps[:subset_end] 131 | 132 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 133 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 134 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 135 | print(f"Running DDIM Sampling with {total_steps} timesteps") 136 | 137 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 138 | 139 | for i, step in enumerate(iterator): 140 | index = total_steps - i - 1 141 | ts = torch.full((b,), step, device=device, dtype=torch.long) 142 | 143 | if mask is not None: 144 | assert x0 is not None 145 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 146 | img = img_orig * mask + (1. 
- mask) * img 147 | 148 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 149 | quantize_denoised=quantize_denoised, temperature=temperature, 150 | noise_dropout=noise_dropout, score_corrector=score_corrector, 151 | corrector_kwargs=corrector_kwargs, 152 | unconditional_guidance_scale=unconditional_guidance_scale, 153 | unconditional_conditioning=unconditional_conditioning) 154 | img, pred_x0 = outs 155 | if callback: callback(i) 156 | if img_callback: img_callback(pred_x0, i) 157 | 158 | if index % log_every_t == 0 or index == total_steps - 1: 159 | intermediates['x_inter'].append(img) 160 | intermediates['pred_x0'].append(pred_x0) 161 | 162 | return img, intermediates 163 | 164 | @torch.no_grad() 165 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 166 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 167 | unconditional_guidance_scale=1., unconditional_conditioning=None): 168 | b, *_, device = *x.shape, x.device 169 | 170 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 171 | e_t = self.model.apply_model(x, t, c) 172 | else: 173 | x_in = torch.cat([x] * 2) 174 | t_in = torch.cat([t] * 2) 175 | c_in = torch.cat([unconditional_conditioning, c]) 176 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 177 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 178 | 179 | if score_corrector is not None: 180 | assert self.model.parameterization == "eps" 181 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 182 | 183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 186 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 187 | # select parameters corresponding to the currently considered timestep 188 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 189 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 190 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 191 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 192 | 193 | # current prediction for x_0 194 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 195 | if quantize_denoised: 196 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 197 | # direction pointing to x_t 198 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 199 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 200 | if noise_dropout > 0.: 201 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 202 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 203 | return x_prev, pred_x0 204 | -------------------------------------------------------------------------------- /generation/ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class PLMSSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | if ddim_eta != 0: 26 | raise ValueError('ddim_eta must be 0 for PLMS') 27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 29 | alphas_cumprod = self.model.alphas_cumprod 30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 32 | 33 | self.register_buffer('betas', to_torch(self.model.betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 43 | 44 | # ddim sampling parameters 45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 46 | ddim_timesteps=self.ddim_timesteps, 47 | eta=ddim_eta,verbose=verbose) 48 | self.register_buffer('ddim_sigmas', ddim_sigmas) 49 | self.register_buffer('ddim_alphas', ddim_alphas) 50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 56 | 57 | @torch.no_grad() 58 | def sample(self, 59 | S, 60 | batch_size, 61 | shape, 62 | conditioning=None, 63 | callback=None, 64 | normals_sequence=None, 65 | img_callback=None, 66 | quantize_x0=False, 67 | eta=0., 68 | mask=None, 69 | x0=None, 70 | temperature=1., 71 | noise_dropout=0., 72 | score_corrector=None, 73 | corrector_kwargs=None, 74 | verbose=True, 75 | x_T=None, 76 | log_every_t=100, 77 | unconditional_guidance_scale=1., 78 | unconditional_conditioning=None, 79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 80 | **kwargs 81 | ): 82 | if conditioning is not None: 83 | if isinstance(conditioning, dict): 84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 85 | if cbs != batch_size: 86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 87 | else: 88 | if conditioning.shape[0] != batch_size: 89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 90 | 91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 92 | # sampling 93 | C, H, W = shape 94 | size = (batch_size, C, H, W) 95 | print(f'Data shape for PLMS sampling is {size}') 96 | 97 | samples, intermediates = self.plms_sampling(conditioning, size, 98 | callback=callback, 99 | img_callback=img_callback, 100 | quantize_denoised=quantize_x0, 101 | mask=mask, x0=x0, 102 | ddim_use_original_steps=False, 103 | noise_dropout=noise_dropout, 104 | temperature=temperature, 105 | score_corrector=score_corrector, 106 | corrector_kwargs=corrector_kwargs, 107 | x_T=x_T, 108 | log_every_t=log_every_t, 109 | unconditional_guidance_scale=unconditional_guidance_scale, 110 | unconditional_conditioning=unconditional_conditioning, 111 | ) 112 | return samples, intermediates 113 | 114 | @torch.no_grad() 115 | def plms_sampling(self, cond, shape, 116 | x_T=None, ddim_use_original_steps=False, 117 | callback=None, timesteps=None, quantize_denoised=False, 118 | mask=None, x0=None, img_callback=None, log_every_t=100, 119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 120 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 121 | device = self.model.betas.device 122 | b = shape[0] 123 | if x_T is None: 124 | img = torch.randn(shape, device=device) 125 | else: 126 | img = x_T 127 | 128 | if timesteps is None: 129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 130 | elif timesteps is not None and not ddim_use_original_steps: 131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 132 | timesteps = self.ddim_timesteps[:subset_end] 133 | 134 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 137 | print(f"Running PLMS Sampling with {total_steps} timesteps") 138 | 139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 140 | old_eps = [] 141 | 142 | for i, step in enumerate(iterator): 143 | index = total_steps - i - 1 144 | ts = torch.full((b,), step, device=device, 
dtype=torch.long) 145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 146 | 147 | if mask is not None: 148 | assert x0 is not None 149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 150 | img = img_orig * mask + (1. - mask) * img 151 | 152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 153 | quantize_denoised=quantize_denoised, temperature=temperature, 154 | noise_dropout=noise_dropout, score_corrector=score_corrector, 155 | corrector_kwargs=corrector_kwargs, 156 | unconditional_guidance_scale=unconditional_guidance_scale, 157 | unconditional_conditioning=unconditional_conditioning, 158 | old_eps=old_eps, t_next=ts_next) 159 | img, pred_x0, e_t = outs 160 | old_eps.append(e_t) 161 | if len(old_eps) >= 4: 162 | old_eps.pop(0) 163 | if callback: callback(i) 164 | if img_callback: img_callback(pred_x0, i) 165 | 166 | if index % log_every_t == 0 or index == total_steps - 1: 167 | intermediates['x_inter'].append(img) 168 | intermediates['pred_x0'].append(pred_x0) 169 | 170 | return img, intermediates 171 | 172 | @torch.no_grad() 173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): 176 | b, *_, device = *x.shape, x.device 177 | 178 | def get_model_output(x, t): 179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 180 | e_t = self.model.apply_model(x, t, c) 181 | else: 182 | x_in = torch.cat([x] * 2) 183 | t_in = torch.cat([t] * 2) 184 | c_in = torch.cat([unconditional_conditioning, c]) 185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 187 | 188 | if score_corrector is not None: 189 | assert self.model.parameterization == "eps" 190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 191 | 192 | return e_t 193 | 194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 198 | 199 | def get_x_prev_and_pred_x0(e_t, index): 200 | # select parameters corresponding to the currently considered timestep 201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 205 | 206 | # current prediction for x_0 207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 208 | if quantize_denoised: 209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 210 | # direction pointing to x_t 211 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 213 | if noise_dropout > 0.: 214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 216 | return x_prev, pred_x0 217 | 218 | e_t = get_model_output(x, t) 219 | if len(old_eps) == 0: 220 | # Pseudo Improved Euler (2nd order) 221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 222 | e_t_next = get_model_output(x_prev, t_next) 223 | e_t_prime = (e_t + e_t_next) / 2 224 | elif len(old_eps) == 1: 225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 227 | elif len(old_eps) == 2: 228 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 230 | elif len(old_eps) >= 3: 231 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 233 | 234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 235 | 236 | return x_prev, pred_x0, e_t 237 | -------------------------------------------------------------------------------- /generation/ldm/modules/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/__pycache__/x_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/__pycache__/x_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class 
FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | 
nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /generation/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /generation/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # 
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 
159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /generation/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /generation/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * 
torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /generation/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | 
m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /generation/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /generation/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /generation/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? 
--> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to requirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizer model and adds some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier
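# note: the interpolator set up below is torch.nn.functional.interpolate with its
# mode fixed to `method`; forward() applies it `n_stages` times, scaling the spatial
# dims by `multiplier` at each stage (overall factor multiplier ** n_stages), and a
# 1x1 conv remaps channels afterwards only when `out_channels` is given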
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | 138 | class FrozenCLIPTextEmbedder(nn.Module): 139 | """ 140 | Uses the CLIP transformer encoder for text. 141 | """ 142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 143 | super().__init__() 144 | self.model, _ = clip.load(version, jit=False, device="cpu") 145 | self.device = device 146 | self.max_length = max_length 147 | self.n_repeat = n_repeat 148 | self.normalize = normalize 149 | 150 | def freeze(self): 151 | self.model = self.model.eval() 152 | for param in self.parameters(): 153 | param.requires_grad = False 154 | 155 | def forward(self, text): 156 | tokens = clip.tokenize(text).to(self.device) 157 | z = self.model.encode_text(tokens) 158 | if self.normalize: 159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 160 | return z 161 | 162 | def encode(self, text): 163 | z = self(text) 164 | if z.ndim==2: 165 | z = z[:, None, :] 166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 167 | return z 168 | 169 | 170 | class FrozenClipImageEmbedder(nn.Module): 171 | """ 172 | Uses the CLIP image encoder. 173 | """ 174 | def __init__( 175 | self, 176 | model, 177 | jit=False, 178 | device='cuda' if torch.cuda.is_available() else 'cpu', 179 | antialias=False, 180 | ): 181 | super().__init__() 182 | self.model, _ = clip.load(name=model, device=device, jit=jit) 183 | 184 | self.antialias = antialias 185 | 186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 188 | 189 | def preprocess(self, x): 190 | # normalize to [0,1] 191 | x = kornia.geometry.resize(x, (224, 224), 192 | interpolation='bicubic',align_corners=True, 193 | antialias=self.antialias) 194 | x = (x + 1.) / 2. 
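# inputs are expected in [-1, 1] (see the comment in forward()); the line above
# shifts them to [0, 1], and the kornia call below then standardizes with the
# CLIP mean/std buffers registered in __init__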
195 | # renormalize according to clip 196 | x = kornia.enhance.normalize(x, self.mean, self.std) 197 | return x 198 | 199 | def forward(self, x): 200 | # x is assumed to be in range [-1,1] 201 | return self.model.encode_image(self.preprocess(x)) 202 | 203 | -------------------------------------------------------------------------------- /generation/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /generation/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hritikbansal/generative-robustness/165efa28a3eae608366ecf47bb1a1c3932d5c5c0/generation/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /generation/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /generation/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, 
cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /generation/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def 
hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * 
self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /generation/ldm/util.py: -------------------------------------------------------------------------------- 
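# Usage sketch for instantiate_from_config / get_obj_from_str (defined in this file):
# the "target" key is a dotted import path resolved via importlib, and any "params"
# are forwarded as keyword arguments to the resolved class or callable. The config
# dict below is a minimal hypothetical example, not one of the repo's configs:
#
#   cfg = {"target": "torch.nn.Identity", "params": {}}
#   layer = instantiate_from_config(cfg)   # -> torch.nn.Identity()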
1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 
114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. 
[{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /sd_finetune/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import glob 4 | import torch 5 | import torchvision 6 | import requests 7 | import random 8 | import pickle 9 | import numpy as np 10 | import pandas as pd 11 | from io import BytesIO 12 | from PIL import Image, ImageFile 13 | from torchvision import transforms 14 | from collections import defaultdict 15 | from torchvision.datasets import VisionDataset 16 | from torch.utils.data import Dataset, DataLoader 17 | 18 | 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | 21 | 22 | 23 | class ImageCaptionDataset(Dataset): 24 | 25 | def __init__(self, filename, tokenizer, image_transform, caption_key = 'caption', image_key = 'image'): 26 | 27 | df = pd.read_csv(filename) 28 | 29 | self.transform = image_transform 30 | self.captions = df[caption_key].tolist() 31 | self.tokenizer = tokenizer 32 | self.images = df[image_key].tolist() 33 | 34 | self.t_captions= tokenizer(self.captions, max_length = tokenizer.model_max_length, padding = 'max_length', truncation = True, return_tensors = 'pt') 35 | 36 | def __len__(self): 37 | return len(self.images) 38 | 39 | def __getitem__(self, index): 40 | 41 | item = {} 42 | item['image_location'] = self.images[index] 43 | item['caption'] = self.captions[index] 44 | item['input_ids'] = self.t_captions['input_ids'][index] 45 | item['pixel_values'] = self.transform(Image.open(self.images[index]).convert('RGB')) 46 | 47 | return item -------------------------------------------------------------------------------- /sd_finetune/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def parse_args(): 5 | 6 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 7 | parser.add_argument( 8 | "--pretrained_model_name_or_path", 9 | type=str, 10 | default="runwayml/stable-diffusion-v1-5", 11 | help="Path to pretrained model or model identifier from huggingface.co/models.", 12 | ) 13 | parser.add_argument( 14 | "--revision", 15 | type=str, 16 | default=None, 17 | required=False, 18 | help="Revision of pretrained model identifier from huggingface.co/models.", 19 | ) 20 | parser.add_argument( 21 | "--dataset_name", 22 | type=str, 23 | default=None, 24 | help=( 25 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 26 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 27 | " or to a folder containing files that 🤗 Datasets can understand." 
28 | ), 29 | ) 30 | parser.add_argument( 31 | "--dataset_config_name", 32 | type=str, 33 | default=None, 34 | help="The config of the Dataset, leave as None if there's only one config.", 35 | ) 36 | parser.add_argument( 37 | "--train_data_file", 38 | type=str, 39 | default=None 40 | ) 41 | parser.add_argument( 42 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 43 | ) 44 | parser.add_argument( 45 | "--caption_column", 46 | type=str, 47 | default="text", 48 | help="The column of the dataset containing a caption or a list of captions.", 49 | ) 50 | parser.add_argument( 51 | "--max_train_samples", 52 | type=int, 53 | default=None, 54 | help=( 55 | "For debugging purposes or quicker training, truncate the number of training examples to this " 56 | "value if set." 57 | ), 58 | ) 59 | parser.add_argument( 60 | "--output_dir", 61 | type=str, 62 | default="sd-model-finetuned", 63 | help="The output directory where the model predictions and checkpoints will be written.", 64 | ) 65 | parser.add_argument( 66 | "--cache_dir", 67 | type=str, 68 | default=None, 69 | help="The directory where the downloaded models and datasets will be stored.", 70 | ) 71 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 72 | parser.add_argument( 73 | "--resolution", 74 | type=int, 75 | default=512, 76 | help=( 77 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 78 | " resolution" 79 | ), 80 | ) 81 | parser.add_argument( 82 | "--center_crop", 83 | default=False, 84 | action="store_true", 85 | help=( 86 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 87 | " cropped. The images will be resized to the resolution first before cropping." 88 | ), 89 | ) 90 | parser.add_argument( 91 | "--random_flip", 92 | action="store_true", 93 | help="whether to randomly flip images horizontally", 94 | ) 95 | parser.add_argument( 96 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 97 | ) 98 | parser.add_argument("--num_train_epochs", type=int, default=100) 99 | parser.add_argument( 100 | "--max_train_steps", 101 | type=int, 102 | default=None, 103 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 104 | ) 105 | parser.add_argument( 106 | "--gradient_accumulation_steps", 107 | type=int, 108 | default=1, 109 | help="Number of updates steps to accumulate before performing a backward/update pass.", 110 | ) 111 | parser.add_argument( 112 | "--gradient_checkpointing", 113 | action="store_true", 114 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 115 | ) 116 | parser.add_argument( 117 | "--learning_rate", 118 | type=float, 119 | default=1e-4, 120 | help="Initial learning rate (after the potential warmup period) to use.", 121 | ) 122 | parser.add_argument( 123 | "--scale_lr", 124 | action="store_true", 125 | default=False, 126 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 127 | ) 128 | parser.add_argument( 129 | "--lr_scheduler", 130 | type=str, 131 | default="constant", 132 | help=( 133 | 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 134 | ' "constant", "constant_with_warmup"]' 135 | ), 136 | ) 137 | parser.add_argument( 138 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 139 | ) 140 | parser.add_argument( 141 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 142 | ) 143 | parser.add_argument( 144 | "--allow_tf32", 145 | action="store_true", 146 | help=( 147 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 148 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 149 | ), 150 | ) 151 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 152 | parser.add_argument( 153 | "--non_ema_revision", 154 | type=str, 155 | default=None, 156 | required=False, 157 | help=( 158 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" 159 | " remote repository specified with --pretrained_model_name_or_path." 160 | ), 161 | ) 162 | parser.add_argument( 163 | "--dataloader_num_workers", 164 | type=int, 165 | default=0, 166 | help=( 167 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 168 | ), 169 | ) 170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 176 | 177 | parser.add_argument("--freeze_unet", action="store_true", help="Whether or not to freeze the unet.") 178 | 179 | parser.add_argument("--wandb", action="store_true", help="Whether or not to log using wandb") 180 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 181 | parser.add_argument( 182 | "--hub_model_id", 183 | type=str, 184 | default=None, 185 | help="The name of the repository to keep in sync with the local `output_dir`.", 186 | ) 187 | parser.add_argument( 188 | "--logging_dir", 189 | type=str, 190 | default="logs", 191 | help=( 192 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 193 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 194 | ), 195 | ) 196 | parser.add_argument( 197 | "--mixed_precision", 198 | type=str, 199 | default=None, 200 | choices=["no", "fp16", "bf16"], 201 | help=( 202 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 203 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 204 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--report_to", 209 | type=str, 210 | default=None, 211 | help=( 212 | 'The integration to report the results and logs to. 
Supported platforms are `"tensorboard"`' 213 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 214 | ), 215 | ) 216 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 217 | parser.add_argument( 218 | "--checkpointing_steps", 219 | type=int, 220 | default=500, 221 | help=( 222 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 223 | " training using `--resume_from_checkpoint`." 224 | ), 225 | ) 226 | parser.add_argument( 227 | "--checkpoints_total_limit", 228 | type=int, 229 | default=None, 230 | help=( 231 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 232 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 233 | " for more docs" 234 | ), 235 | ) 236 | parser.add_argument( 237 | "--resume_from_checkpoint", 238 | type=str, 239 | default=None, 240 | help=( 241 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 242 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 243 | ), 244 | ) 245 | parser.add_argument( 246 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 247 | ) 248 | 249 | args = parser.parse_args() 250 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 251 | if env_local_rank != -1 and env_local_rank != args.local_rank: 252 | args.local_rank = env_local_rank 253 | 254 | # # Sanity checks 255 | # if args.dataset_name is None and args.train_data_dir is None: 256 | # raise ValueError("Need either a dataset name or a training folder.") 257 | 258 | # default to using the same revision for the non-ema model if not specified 259 | if args.non_ema_revision is None: 260 | args.non_ema_revision = args.revision 261 | 262 | return args -------------------------------------------------------------------------------- /train/resnet_configs/alexnet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: alexnet 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 128 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 56 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/efficient.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: efficientnet_b0 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 128 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 56 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 
| resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/inception.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: inception 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 128 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 40 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/mobilenet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: mobilenet_v2 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 128 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 40 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/resnext101.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: resnext101_32x8d 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 64 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 56 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/resnext50.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 3 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: resnext50_32x4d 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 128 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 56 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/rn18_88_epochs.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 
13 | lr_schedule_type: cyclic 14 | model: 15 | arch: resnet18 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 128 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 56 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/resnet_configs/vgg16.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | gpu: 0 3 | in_memory: 1 4 | num_workers: 12 5 | dist: 6 | world_size: 2 7 | logging: 8 | folder: /tmp/ 9 | log_level: 0 10 | lr: 11 | lr: 0.5 12 | lr_peak_epoch: 2 13 | lr_schedule_type: cyclic 14 | model: 15 | arch: vgg16_bn 16 | resolution: 17 | end_ramp: 76 18 | max_res: 192 19 | min_res: 160 20 | start_ramp: 65 21 | training: 22 | batch_size: 64 23 | bn_wd: 0 24 | distributed: 0 25 | epochs: 40 26 | label_smoothing: 0.1 27 | momentum: 0.9 28 | optimizer: sgd 29 | weight_decay: 5e-5 30 | use_blurpool: 1 31 | validation: 32 | lr_tta: true 33 | resolution: 256 -------------------------------------------------------------------------------- /train/write_imagenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is directly taken from FFCV-Imagenet https://github.com/libffcv/ffcv-imagenet 3 | ''' 4 | from torch.utils.data import Subset, ConcatDataset 5 | from ffcv.writer import DatasetWriter 6 | from ffcv.fields import IntField, RGBImageField 7 | from torchvision.datasets import CIFAR10, ImageFolder 8 | 9 | from argparse import ArgumentParser 10 | from fastargs import Section, Param 11 | from fastargs.validation import And, OneOf 12 | from fastargs.decorators import param, section 13 | from fastargs import get_current_config 14 | 15 | from customdataset import CustomizeDataset 16 | 17 | Section('cfg', 'arguments to give the writer').params( 18 | dataset=Param(And(str, OneOf(['cifar', 'imagenet'])), 'Which dataset to write', default='imagenet'), 19 | split=Param(And(str, OneOf(['train', 'val'])), 'Train or val set', required=True), 20 | data_dir=Param(str, 'Where to find the PyTorch dataset', required=True), 21 | write_path=Param(str, 'Where to write the new dataset', required=True), 22 | write_mode=Param(str, 'Mode: raw, smart or jpg', required=False, default='smart'), 23 | max_resolution=Param(int, 'Max image side length', required=True), 24 | num_workers=Param(int, 'Number of workers to use', default=16), 25 | chunk_size=Param(int, 'Chunk size for writing', default=100), 26 | jpeg_quality=Param(float, 'Quality of jpeg images', default=90), 27 | subset=Param(int, 'How many images to use (-1 for all)', default=-1), 28 | compress_probability=Param(float, 'compress probability', default=None) 29 | ) 30 | 31 | @section('cfg') 32 | @param('dataset') 33 | @param('split') 34 | @param('data_dir') 35 | @param('write_path') 36 | @param('max_resolution') 37 | @param('num_workers') 38 | @param('chunk_size') 39 | @param('subset') 40 | @param('jpeg_quality') 41 | @param('write_mode') 42 | @param('compress_probability') 43 | def main(dataset, split, data_dir, write_path, max_resolution, num_workers, 44 | chunk_size, subset, jpeg_quality, write_mode, 45 | compress_probability): 46 | if dataset == 'cifar': 47 | my_dataset = CIFAR10(root=data_dir, train=(split == 'train'), download=True) 48 | elif dataset == 'imagenet': 49 | my_dataset = 
ImageFolder(root=data_dir) 50 | elif dataset == 'imagenet_aug': 51 | my_dataset_1 = ImageFolder(root=data_dir) 52 | data_dir_2 = "path/to/generated/imagenet/train/or/val" 53 | my_dataset_2 = ImageFolder(root=data_dir_2) 54 | my_dataset = ConcatDataset([my_dataset_1, my_dataset_2]) 55 | else: 56 | raise ValueError('Unrecognized dataset', dataset) 57 | 58 | if subset > 0: my_dataset = Subset(my_dataset, range(subset)) 59 | writer = DatasetWriter(write_path, { 60 | 'image': RGBImageField(write_mode=write_mode, 61 | max_resolution=max_resolution, 62 | compress_probability=compress_probability, 63 | jpeg_quality=jpeg_quality), 64 | 'label': IntField(), 65 | }, num_workers=num_workers) 66 | 67 | writer.from_indexed_dataset(my_dataset, chunksize=chunk_size) 68 | 69 | if __name__ == '__main__': 70 | config = get_current_config() 71 | parser = ArgumentParser() 72 | config.augment_argparse(parser) 73 | config.collect_argparse_args(parser) 74 | config.validate(mode='stderr') 75 | config.summary() 76 | main() 77 | -------------------------------------------------------------------------------- /train/write_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #This code is directly taken from FFCV-Imagenet https://github.com/libffcv/ffcv-imagenet 4 | 5 | write_dataset () { 6 | write_path=$WRITE_DIR/${1}_${2}_${3}_${4}.ffcv 7 | echo "Writing ImageNet ${1} dataset to ${write_path}" 8 | python write_imagenet.py \ 9 | --cfg.dataset=imagenet \ 10 | --cfg.split=${1} \ 11 | --cfg.data_dir=$IMAGENET_DIR/${1} \ 12 | --cfg.write_path=$write_path \ 13 | --cfg.max_resolution=${2} \ 14 | --cfg.write_mode=proportion \ 15 | --cfg.compress_probability=${3} \ 16 | --cfg.jpeg_quality=$4 17 | } 18 | 19 | write_dataset train $1 $2 $3 20 | # write_dataset val $1 $2 $3 -------------------------------------------------------------------------------- /utils/create_imagenet_subset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Aniruddha Saha 3 | Description: This scripts creates the Imagenet100 subset. 4 | Instead of using symlinks like in the original repo, this copies the images 5 | for ease of use in the Dataset class in this repo which loads images from lists. 
6 | ''' 7 | import os 8 | import re 9 | import errno 10 | import argparse 11 | import shutil 12 | from tqdm import tqdm 13 | 14 | 15 | def create_subset(class_list, full_imagenet_path, subset_imagenet_path, *, 16 | splits=('val',)): 17 | full_imagenet_path = os.path.abspath(full_imagenet_path) 18 | subset_imagenet_path = os.path.abspath(subset_imagenet_path) 19 | os.makedirs(subset_imagenet_path, exist_ok=True) 20 | for split in splits: 21 | os.makedirs(os.path.join(subset_imagenet_path, split), exist_ok=True) 22 | for c in tqdm(class_list): 23 | if re.match(r"n[0-9]{8}", c) is None: 24 | raise ValueError( 25 | f"Expected class names to be of the format nXXXXXXXX, where " 26 | f"each X represents a numerical number, e.g., n04589890, but " 27 | f"got {c}") 28 | for split in splits: 29 | try: 30 | shutil.copytree( 31 | os.path.join(full_imagenet_path, split, c), 32 | os.path.join(subset_imagenet_path, split, c) 33 | ) 34 | except: 35 | print(f'Class {c} is not present') 36 | print(f'Finished creating ImageNet subset at {subset_imagenet_path}!') 37 | 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Subset Creation') 41 | parser.add_argument('--full_imagenet_path', metavar='IMAGENET_DIR', 42 | help='path to the existing full ImageNet dataset') 43 | parser.add_argument('--subset_imagenet_path', metavar='SUBSET_DIR', 44 | help='path to create the ImageNet subset dataset') 45 | parser.add_argument('--subset', type=str, 46 | default=os.path.join(os.path.dirname(__file__), 'imagenet100_classes.txt'), 47 | help='file contains a list of subset classes') 48 | args = parser.parse_args() 49 | 50 | print(f'Using class names specified in {args.subset}.') 51 | with open(args.subset, 'r') as f: 52 | class_list = [l.strip() for l in f.readlines()] 53 | 54 | create_subset(class_list, args.full_imagenet_path, args.subset_imagenet_path) -------------------------------------------------------------------------------- /utils/mappings/folder_to_objectnet_label.json: -------------------------------------------------------------------------------- 1 | { 2 | "squeegee": "Squeegee", 3 | "umbrella": "Umbrella", 4 | "eyeglasses": "Eyeglasses", 5 | "coaster": "Coaster", 6 | "winter_glove": "Winter glove", 7 | "pet_food_container": "Pet food container", 8 | "scissors": "Scissors", 9 | "computer_mouse": "Computer mouse", 10 | "still_camera": "Still Camera", 11 | "weight_scale": "Weight scale", 12 | "cutting_board": "Cutting board", 13 | "spatula": "Spatula", 14 | "plunger": "Plunger", 15 | "paper": "Paper", 16 | "paper_bag": "Paper bag", 17 | "microwave": "Microwave", 18 | "keyboard": "Keyboard", 19 | "standing_lamp": "Standing lamp", 20 | "chair": "Chair", 21 | "mixing_salad_bowl": "Mixing / Salad Bowl", 22 | "milk": "Milk", 23 | "blender": "Blender", 24 | "flashlight": "Flashlight", 25 | "floss_container": "Floss container", 26 | "rake": "Rake", 27 | "sandal": "Sandal", 28 | "book_closed": "Book (closed)", 29 | "extension_cable": "Extension cable", 30 | "drawer_open": "Drawer (open)", 31 | "tv": "TV", 32 | "backpack": "Backpack", 33 | "playing_cards": "Playing cards", 34 | "jar": "Jar", 35 | "frying_pan": "Frying pan", 36 | "bench": "Bench", 37 | "coffee_grinder": "Coffee grinder", 38 | "jam": "Jam", 39 | "tape": "Tape / duct tape", 40 | "dog_bed": "Dog bed", 41 | "throw_pillow": "Throw pillow", 42 | "orange": "Orange", 43 | "speaker": "Speaker", 44 | "paint_can": "Paint can", 45 | "hand_towel_or_rag": "Dishrag or hand towel", 46 | "hat": "Hat", 47 | "toaster": 
"Toaster", 48 | "match": "Match", 49 | "plate": "Plate", 50 | "ironing_board": "Ironing board", 51 | "alarm_clock": "Alarm clock", 52 | "sunglasses": "Sunglasses", 53 | "power_bar": "Power bar", 54 | "wine_glass": "Wine glass", 55 | "ladle": "Ladle", 56 | "ruler": "Ruler", 57 | "blanket": "Blanket", 58 | "contact_lens_case": "Contact lens case", 59 | "watch": "Watch", 60 | "trash_bin": "Trash bin", 61 | "vase": "Vase", 62 | "oven_mitts": "Oven mitts", 63 | "spoon": "Spoon", 64 | "thermos": "Thermos", 65 | "fan": "Fan", 66 | "egg_carton": "Egg carton", 67 | "lighter": "Lighter", 68 | "glue_container": "Glue container", 69 | "bottle_cap": "Bottle cap", 70 | "wallet": "Wallet", 71 | "coffee_table": "Coffee table", 72 | "nail_clippers": "Nail clippers", 73 | "cellphone_case": "Cellphone case", 74 | "canned_food": "Canned food", 75 | "letter_opener": "Letter opener", 76 | "stapler": "Stapler", 77 | "whisk": "Whisk", 78 | "tongs": "Tongs", 79 | "lettuce": "Lettuce", 80 | "shoelace": "Shoelace", 81 | "button": "Button", 82 | "hair_dryer": "Hair dryer", 83 | "bracelet": "Bracelet", 84 | "spray_bottle": "Spray bottle", 85 | "laptop_charger": "Laptop charger", 86 | "portable_heater": "Portable heater", 87 | "suit_jacket": "Suit jacket", 88 | "dress_shoe_men": "Dress shoe (men)", 89 | "rock": "Rock", 90 | "water_filter": "Water filter", 91 | "earbuds": "Earbuds", 92 | "biscuits": "Biscuits", 93 | "mouthwash": "Mouthwash", 94 | "slipper": "Slipper", 95 | "eraser_white_board": "Eraser (white board)", 96 | "bicycle": "Bicycle", 97 | "ziploc_bag": "Ziploc bag", 98 | "dish_soap": "Dish soap", 99 | "travel_case": "Travel case", 100 | "receipt": "Receipt", 101 | "boots": "Boots", 102 | "sponge": "Sponge", 103 | "full_sized_towel": "Bath towel", 104 | "flour_container": "Flour container", 105 | "pill_bottle": "Pill bottle", 106 | "candle": "Candle", 107 | "calendar": "Calendar", 108 | "tweezers": "Tweezers", 109 | "dvd_player": "DVD player", 110 | "plastic_wrap": "Plastic wrap", 111 | "ribbon": "Ribbon", 112 | "blouse": "Blouse", 113 | "walking_cane": "Walking cane", 114 | "leaf": "Leaf", 115 | "lipstick": "Lipstick", 116 | "hair_brush": "Hair brush", 117 | "night_light": "Night light", 118 | "kettle": "Kettle", 119 | "honey_container": "Honey container", 120 | "tarp": "Tarp", 121 | "ice": "Ice", 122 | "drinking_straw": "Drinking straw", 123 | "detergent": "Detergent", 124 | "mug": "Mug", 125 | "toilet_paper_roll": "Toilet paper roll", 126 | "wok": "Wok", 127 | "swimming_trunks": "Swimming trunks", 128 | "clothes_hamper": "Clothes hamper", 129 | "removable_blade": "Removable blade", 130 | "shampoo_bottle": "Shampoo bottle", 131 | "skirt": "Skirt", 132 | "loofah": "Loofah", 133 | "broom": "Broom", 134 | "photograph_printed": "Photograph (printed)", 135 | "multitool": "Multitool", 136 | "makeup": "Makeup", 137 | "nail_file": "Nail file", 138 | "brooch": "Brooch", 139 | "cork": "Cork", 140 | "coffee_french_press": "Coffee/French press", 141 | "toy": "Toy", 142 | "thermometer": "Thermometer", 143 | "hairclip": "Hairclip", 144 | "document_folder_closed": "Document folder (closed)", 145 | "pliers": "Pliers", 146 | "strainer": "Strainer", 147 | "comb": "Comb", 148 | "water_bottle": "Water bottle", 149 | "peeler": "Peeler", 150 | "monitor": "Monitor", 151 | "box": "Box", 152 | "pill_organizer": "Pill organizer", 153 | "stopper_sink_tub": "Stopper (sink/tub)", 154 | "walker": "Walker", 155 | "step_stool": "Step stool", 156 | "bills_money": "Bills (money)", 157 | "skateboard": "Skateboard", 158 | "running_shoe": 
"Running shoe", 159 | "coin_money": "Coin (money)", 160 | "magazine": "Magazine", 161 | "drying_rack_for_clothes": "Drying rack for clothes", 162 | "toothpaste": "Toothpaste", 163 | "paper_towel": "Paper towel", 164 | "remote_control": "Remote control", 165 | "sugar_container": "Sugar container", 166 | "dress_pants": "Dress pants", 167 | "scarf": "Scarf", 168 | "dress_shirt": "Dress shirt", 169 | "cheese": "Cheese", 170 | "can_opener": "Can opener", 171 | "shovel": "Shovel", 172 | "paintbrush": "Paintbrush", 173 | "tennis_racket": "Tennis racket", 174 | "battery": "Battery", 175 | "stuffed_animal": "Stuffed animal", 176 | "jeans": "Jeans", 177 | "tanktop": "Tanktop", 178 | "dust_pan": "Dust pan", 179 | "earring": "Earring", 180 | "tomato": "Tomato", 181 | "marker": "Marker", 182 | "makeup_brush": "Makeup brush", 183 | "ring": "Ring", 184 | "air_freshener": "Air freshener", 185 | "tablecloth": "Tablecloth", 186 | "teabag": "Teabag", 187 | "belt": "Belt", 188 | "razor": "Razor", 189 | "clothes_hanger": "Clothes hanger", 190 | "bookend": "Bookend", 191 | "sweater": "Sweater", 192 | "sock": "Sock", 193 | "usb_flash_drive": "Usb flash drive", 194 | "cellphone_charger": "Cellphone charger", 195 | "pepper_shaker": "Pepper shaker", 196 | "phone_landline": "Phone (landline)", 197 | "banana": "Banana", 198 | "printer": "Printer", 199 | "paperclip": "Paperclip", 200 | "fork": "Fork", 201 | "headphones_over_ear": "Headphones (over ear)", 202 | "cooking_oil_bottle": "Cooking oil bottle", 203 | "deodorant": "Deodorant", 204 | "usb_cable": "Usb cable", 205 | "shorts": "Shorts", 206 | "bread_loaf": "Bread loaf", 207 | "pillow": "Pillow", 208 | "drinking_cup": "Drinking Cup", 209 | "envelope": "Envelope", 210 | "mouse_pad": "Mouse pad", 211 | "chopstick": "Chopstick", 212 | "t-shirt": "T-shirt", 213 | "padlock": "Padlock", 214 | "ice_cube_tray": "Ice cube tray", 215 | "chess_piece": "Chess piece", 216 | "cereal": "Cereal", 217 | "hairtie": "Hairtie", 218 | "teapot": "Teapot", 219 | "board_game": "Board game", 220 | "butchers_knife": "Butcher's knife", 221 | "soup_bowl": "Soup Bowl", 222 | "beer_bottle": "Beer bottle", 223 | "nail_polish": "Nail polish", 224 | "hand_mirror": "Hand mirror", 225 | "combination_lock": "Combination lock", 226 | "nut_for_screw": "Nut for a screw", 227 | "nail_fastener": "Nail (fastener)", 228 | "figurine_or_statue": "Figurine or statue", 229 | "soap_bar": "Soap bar", 230 | "bucket": "Bucket", 231 | "binder_closed": "Binder (closed)", 232 | "video_camera": "Video Camera", 233 | "baseball_glove": "Baseball glove", 234 | "tape_measure": "Tape measure", 235 | "tissue": "Tissue", 236 | "coffee_beans": "Coffee beans", 237 | "scrub_brush": "Scrub brush", 238 | "drill": "Drill", 239 | "suitcase": "Suitcase", 240 | "newspaper": "Newspaper", 241 | "sleeping_bag": "Sleeping bag", 242 | "dress_shoe_women": "Dress shoe (women)", 243 | "trophy": "Trophy", 244 | "plastic_bag": "Plastic bag", 245 | "doormat": "Doormat", 246 | "webcam": "Webcam", 247 | "rolling_pin": "Rolling pin", 248 | "pencil": "Pencil", 249 | "table_knife": "Table knife", 250 | "bread_knife": "Bread knife", 251 | "toothbrush": "Toothbrush", 252 | "bathrobe": "Bathrobe", 253 | "paper_plates": "Paper plates", 254 | "placemat": "Placemat", 255 | "light_bulb": "Light bulb", 256 | "soap_dispenser": "Soap dispenser", 257 | "nightstand": "Nightstand", 258 | "pen": "Pen", 259 | "squeeze_bottle": "Squeeze bottle", 260 | "wheel": "Wheel", 261 | "dress": "Dress", 262 | "helmet": "Helmet", 263 | "lemon": "Lemon", 264 | "hammer": 
"Hammer", 265 | "lampshade": "Lampshade", 266 | "salt_shaker": "Salt shaker", 267 | "power_cable": "Power cable", 268 | "vacuum_cleaner": "Vacuum cleaner", 269 | "iron_for_clothes": "Iron (for clothes)", 270 | "laptop_open": "Laptop (open)", 271 | "poster": "Poster", 272 | "coffee_machine": "Coffee machine", 273 | "tie": "Tie", 274 | "cd_case": "CD case", 275 | "baseball_bat": "Baseball bat", 276 | "tablet_ipad": "Tablet / iPad", 277 | "bottle_opener": "Bottle opener", 278 | "briefcase": "Briefcase", 279 | "baking_sheet": "Baking sheet", 280 | "screw": "Screw", 281 | "pitcher": "Pitcher", 282 | "notepad": "Notepad", 283 | "tote_bag": "Tote bag", 284 | "raincoat": "Raincoat", 285 | "necklace": "Necklace", 286 | "band_aid": "Band Aid", 287 | "notebook": "Notebook", 288 | "measuring_cup": "Measuring cup", 289 | "weight_exercise": "Weight (exercise)", 290 | "handbag": "Handbag", 291 | "bike_pump": "Bike pump", 292 | "bottle_stopper": "Bottle stopper", 293 | "chocolate": "Chocolate", 294 | "safety_pin": "Safety pin", 295 | "plastic_cup": "Plastic cup", 296 | "butter": "Butter", 297 | "cellphone": "Cellphone", 298 | "drying_rack_for_dishes": "Drying rack for plates", 299 | "trash_bag": "Trash bag", 300 | "tray": "Tray", 301 | "wine_bottle": "Wine bottle", 302 | "whistle": "Whistle", 303 | "key_chain": "Key chain", 304 | "napkin": "Napkin", 305 | "desk_lamp": "Desk lamp", 306 | "first_aid_kit": "First aid kit", 307 | "bed_sheet": "Bed sheet", 308 | "beer_can": "Beer can", 309 | "wrench": "Wrench", 310 | "pop_can": "Pop can", 311 | "basket": "Basket", 312 | "leggings": "Leggings", 313 | "egg": "Egg", 314 | "sewing_kit": "Sewing kit" 315 | } 316 | -------------------------------------------------------------------------------- /utils/mappings/imagenet100.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 -------------------------------------------------------------------------------- /utils/mappings/imagenet100_to_labels.json: 
-------------------------------------------------------------------------------- 1 | {"15": "robin, American robin, Turdus migratorius", "45": "Gila monster, Heloderma suspectum", "54": "hognose snake, puff adder, sand viper", "57": "garter snake, grass snake", "64": "green mamba", "74": "garden spider, Aranea diademata", "90": "lorikeet", "99": "goose", "119": "rock crab, Cancer irroratus", "120": "fiddler crab", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "131": "little blue heron, Egretta caerulea", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "151": "Chihuahua", "155": "Shih-Tzu", "157": "papillon", "158": "toy terrier", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "169": "borzoi, Russian wolfhound", "176": "Saluki, gazelle hound", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "209": "Chesapeake Bay retriever", "211": "vizsla, Hungarian pointer", "222": "kuvasz", "228": "komondor", "234": "Rottweiler", "236": "Doberman, Doberman pinscher", "242": "boxer", "246": "Great Dane", "267": "standard poodle", "268": "Mexican hairless", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "277": "red fox, Vulpes vulpes", "281": "tabby, tabby cat", "299": "meerkat, mierkat", "305": "dung beetle", "313": "walking stick, walkingstick, stick insect", "317": "leafhopper", "331": "hare", "342": "wild boar, boar, Sus scrofa", "368": "gibbon, Hylobates lar", "374": "langur", "407": "ambulance", "421": "bannister, banister, balustrade, balusters, handrail", "431": "bassinet", "449": "boathouse", "452": "bonnet, poke bonnet", "455": "bottlecap", "479": "car wheel", "494": "chime, bell, gong", "498": "cinema, movie theater, movie theatre, movie house, picture palace", "503": "cocktail shaker", "508": "computer keyboard, keypad", "544": "Dutch oven", "560": "football helmet", "570": "gasmask, respirator, gas helmet", "592": "hard disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "599": "honeycomb", "606": "iron, smoothing iron", "608": "jean, blue jean, denim", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "653": "milk can", "659": "mixing bowl", "662": "modem", "665": "moped", "667": "mortarboard", "674": "mousetrap", "682": "obelisk", "703": "park bench", "708": "pedestal, plinth, footstall", "717": "pickup, pickup truck", "724": "pirate, pirate ship", "748": "purse", "758": "reel", "765": "rocking chair, rocker", "766": "rotisserie", "772": "safety pin", "775": "sarong", "796": "ski mask", "798": "slide rule, slipstick", "830": "stretcher", "854": "theater curtain, theatre curtain", "857": "throne", "858": "tile roof", "872": "tripod", "876": "tub, vat", "882": "vacuum, vacuum cleaner", "904": "window screen", "908": "wing", "936": "head cabbage", "938": "cauliflower", "953": "pineapple, ananas", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "993": "gyromitra", "994": "stinkhorn, carrion fungus"} -------------------------------------------------------------------------------- /utils/mappings/imagenet_pytorch_id_to_objectnet_id.json: -------------------------------------------------------------------------------- 1 | {"409": 1, "530": 1, "414": 2, "954": 4, "419": 5, "790": 8, "434": 9, "440": 13, "703": 16, "671": 17, "444": 17, "446": 20, "455": 29, "930": 35, "462": 38, "463": 39, "499": 40, "473": 45, "470": 
46, "487": 48, "423": 52, "559": 52, "765": 52, "588": 57, "550": 64, "507": 67, "673": 68, "846": 75, "533": 78, "539": 81, "630": 86, "740": 88, "968": 89, "729": 92, "549": 98, "545": 102, "567": 109, "578": 83, "589": 112, "587": 115, "560": 120, "518": 120, "606": 124, "608": 128, "508": 131, "618": 132, "619": 133, "620": 134, "951": 138, "623": 139, "626": 142, "629": 143, "644": 149, "647": 150, "651": 151, "659": 153, "664": 154, "504": 157, "677": 159, "679": 164, "950": 171, "695": 173, "696": 175, "700": 179, "418": 182, "749": 182, "563": 182, "720": 188, "721": 190, "725": 191, "728": 193, "923": 196, "731": 199, "737": 200, "811": 201, "742": 205, "761": 210, "769": 216, "770": 217, "772": 218, "773": 219, "774": 220, "783": 223, "792": 229, "601": 231, "655": 231, "689": 231, "797": 232, "804": 235, "806": 236, "809": 237, "813": 238, "632": 239, "732": 248, "759": 248, "828": 250, "850": 251, "834": 253, "837": 255, "841": 256, "842": 257, "610": 258, "851": 259, "849": 268, "752": 269, "457": 273, "906": 273, "859": 275, "999": 276, "412": 284, "868": 286, "879": 289, "882": 292, "883": 293, "893": 297, "531": 298, "898": 299, "543": 302, "778": 303, "479": 304, "694": 304, "902": 306, "907": 307, "658": 309, "909": 310} -------------------------------------------------------------------------------- /utils/mappings/objectnet_im100_folder.json: -------------------------------------------------------------------------------- 1 | {"Bench": "n03891251", "Bottle cap": "n02877765", "Chair": "n04099969", "Helmet": "n03379051", "Iron (for clothes)": "n03584829", "Jeans": "n03594734", "Keyboard": "n03085013", "Lampshade": "n03637318", "Laptop (open)": "n03642806", "Mixing / Salad Bowl": "n03775546", "Safety pin": "n04127249", "Vacuum cleaner": "n04517823", "Wheel": "n02974003"} -------------------------------------------------------------------------------- /utils/mappings/objectnet_im1k_folder.json: -------------------------------------------------------------------------------- 1 | {"Alarm clock": "n02708093", "Backpack": "n02769748", "Banana": "n07753592", "Band Aid": "n02786058", "Basket": "n04204238", "Bath towel": "n02808304", "Beer bottle": "n02823428", "Bench": "n03891251", "Bicycle": "n03792782", "Binder (closed)": "n02840245", "Bottle cap": "n02877765", "Bread loaf": "n07684084", "Broom": "n02906734", "Bucket": "n02909870", "Butcher's knife": "n03041632", "Can opener": "n02951585", "Candle": "n02948072", "Cellphone": "n02992529", "Chair": "n02791124", "Clothes hamper": "n03482405", "Coffee/French press": "n03297495", "Combination lock": "n03075370", "Computer mouse": "n03793489", "Desk lamp": "n04380533", "Dishrag or hand towel": "n03207743", "Doormat": "n03223299", "Dress shoe (men)": "n03680355", "Drill": "n03995372", "Drinking Cup": "n07930864", "Drying rack for plates": "n03961711", "Envelope": "n03291819", "Fan": "n03271574", "Frying pan": "n03400231", "Dress": "n03450230", "Hair dryer": "n03483316", "Hammer": "n03481172", "Helmet": "n03379051", "Iron (for clothes)": "n03584829", "Jeans": "n03594734", "Keyboard": "n03085013", "Ladle": "n03633091", "Lampshade": "n03637318", "Laptop (open)": "n03642806", "Lemon": "n07749582", "Letter opener": "n03658185", "Lighter": "n03666591", "Lipstick": "n03676483", "Match": "n03729826", "Measuring cup": "n03733805", "Microwave": "n03761084", "Mixing / Salad Bowl": "n03775546", "Monitor": "n03782006", "Mug": "n03063599", "Nail (fastener)": "n03804744", "Necklace": "n03814906", "Orange": "n07747607", "Padlock": "n03874599", 
"Paintbrush": "n03876231", "Paper towel": "n03887697", "Pen": "n02783161", "Pill bottle": "n03937543", "Pillow": "n03938244", "Pitcher": "n03950228", "Plastic bag": "n03958227", "Plate": "n07579787", "Plunger": "n03970156", "Pop can": "n03983396", "Portable heater": "n04265275", "Printer": "n04004767", "Remote control": "n04074963", "Ruler": "n04118776", "Running shoe": "n04120489", "Safety pin": "n04127249", "Salt shaker": "n04131690", "Sandal": "n04133789", "Screw": "n04153751", "Shovel": "n04208210", "Skirt": "n03534580", "Sleeping bag": "n04235860", "Soap dispenser": "n04254120", "Sock": "n04254777", "Soup Bowl": "n04263257", "Spatula": "n04270147", "Speaker": "n03691459", "Still Camera": "n03976467", "Strainer": "n04332243", "Stuffed animal": "n04399382", "Suit jacket": "n04350905", "Sunglasses": "n04356056", "Sweater": "n04370456", "Swimming trunks": "n04371430", "T-shirt": "n03595614", "TV": "n04404412", "Teapot": "n04398044", "Tennis racket": "n04039381", "Tie": "n02883205", "Toaster": "n04442312", "Toilet paper roll": "n15075141", "Trash bin": "n02747177", "Tray": "n04476259", "Umbrella": "n04507155", "Vacuum cleaner": "n04517823", "Vase": "n04522168", "Wallet": "n04548362", "Watch": "n03197337", "Water bottle": "n04557648", "Weight (exercise)": "n03255030", "Weight scale": "n04141975", "Wheel": "n02974003", "Whistle": "n04579432", "Wine bottle": "n04591713", "Winter glove": "n03775071", "Wok": "n04596742"} -------------------------------------------------------------------------------- /utils/mappings/objectnet_to_im100.json: -------------------------------------------------------------------------------- 1 | {"Bench": "park bench", "Bottle cap": "bottlecap", "Chair": "rocking chair, rocker", "Helmet": "football helmet", "Iron (for clothes)": "iron, smoothing iron", "Jeans": "jean, blue jean, denim", "Keyboard": "computer keyboard, keypad", "Lampshade": "lampshade, lamp shade", "Laptop (open)": "laptop, laptop computer", "Mixing / Salad Bowl": "mixing bowl", "Safety pin": "safety pin", "Vacuum cleaner": "vacuum, vacuum cleaner", "Wheel": "car wheel"} -------------------------------------------------------------------------------- /utils/mappings/objectnet_to_imagenet_1k.json: -------------------------------------------------------------------------------- 1 | { 2 | "Alarm clock": "analog clock; digital clock", 3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack", 4 | "Banana": "banana", 5 | "Band Aid": "Band Aid", 6 | "Basket": "shopping basket", 7 | "Bath towel": "bath towel", 8 | "Beer bottle": "beer bottle", 9 | "Bench": "park bench", 10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem", 11 | "Binder (closed)": "binder, ring-binder", 12 | "Bottle cap": "bottlecap", 13 | "Bread loaf": "French loaf", 14 | "Broom": "broom", 15 | "Bucket": "bucket, pail", 16 | "Butcher's knife": "cleaver, meat cleaver, chopper", 17 | "Can opener": "can opener, tin opener", 18 | "Candle": "candle, taper, wax light", 19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone", 20 | "Chair": "barber chair; folding chair; rocking chair, rocker", 21 | "Clothes hamper": "hamper", 22 | "Coffee/French press": "espresso maker", 23 | "Combination lock": "combination lock", 24 | "Computer mouse": "mouse, computer mouse", 25 | "Desk lamp": "table lamp", 26 | "Dishrag or hand towel": "dishrag, dishcloth", 27 | "Doormat": "doormat, welcome mat", 28 | "Dress shoe (men)": "Loafer", 29 | "Drill": "power 
drill", 30 | "Drinking Cup": "cup", 31 | "Drying rack for plates": "plate rack", 32 | "Envelope": "envelope", 33 | "Fan": "electric fan, blower", 34 | "Frying pan": "frying pan, frypan, skillet", 35 | "Dress": "gown", 36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier", 37 | "Hammer": "hammer", 38 | "Helmet": "football helmet; crash helmet", 39 | "Iron (for clothes)": "iron, smoothing iron", 40 | "Jeans": "jean, blue jean, denim", 41 | "Keyboard": "computer keyboard, keypad", 42 | "Ladle": "ladle", 43 | "Lampshade": "lampshade, lamp shade", 44 | "Laptop (open)": "laptop, laptop computer", 45 | "Lemon": "lemon", 46 | "Letter opener": "letter opener, paper knife, paperknife", 47 | "Lighter": "lighter, light, igniter, ignitor", 48 | "Lipstick": "lipstick, lip rouge", 49 | "Match": "matchstick", 50 | "Measuring cup": "measuring cup", 51 | "Microwave": "microwave, microwave oven", 52 | "Mixing / Salad Bowl": "mixing bowl", 53 | "Monitor": "monitor", 54 | "Mug": "coffee mug", 55 | "Nail (fastener)": "nail", 56 | "Necklace": "necklace", 57 | "Orange": "orange", 58 | "Padlock": "padlock", 59 | "Paintbrush": "paintbrush", 60 | "Paper towel": "paper towel", 61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen", 62 | "Pill bottle": "pill bottle", 63 | "Pillow": "pillow", 64 | "Pitcher": "pitcher, ewer", 65 | "Plastic bag": "plastic bag", 66 | "Plate": "plate", 67 | "Plunger": "plunger, plumber's helper", 68 | "Pop can": "pop bottle, soda bottle", 69 | "Portable heater": "space heater", 70 | "Printer": "printer", 71 | "Remote control": "remote control, remote", 72 | "Ruler": "rule, ruler", 73 | "Running shoe": "running shoe", 74 | "Safety pin": "safety pin", 75 | "Salt shaker": "saltshaker, salt shaker", 76 | "Sandal": "sandal", 77 | "Screw": "screw", 78 | "Shovel": "shovel", 79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt", 80 | "Sleeping bag": "sleeping bag", 81 | "Soap dispenser": "soap dispenser", 82 | "Sock": "sock", 83 | "Soup Bowl": "soup bowl", 84 | "Spatula": "spatula", 85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", 86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera", 87 | "Strainer": "strainer", 88 | "Stuffed animal": "teddy, teddy bear", 89 | "Suit jacket": "suit, suit of clothes", 90 | "Sunglasses": "sunglasses, dark glasses, shades", 91 | "Sweater": "sweatshirt", 92 | "Swimming trunks": "swimming trunks, bathing trunks", 93 | "T-shirt": "jersey, T-shirt, tee shirt", 94 | "TV": "television, television system", 95 | "Teapot": "teapot", 96 | "Tennis racket": "racket, racquet", 97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie", 98 | "Toaster": "toaster", 99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue", 100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", 101 | "Tray": "tray", 102 | "Umbrella": "umbrella", 103 | "Vacuum cleaner": "vacuum, vacuum cleaner", 104 | "Vase": "vase", 105 | "Wallet": "wallet, billfold, notecase, pocketbook", 106 | "Watch": "digital watch", 107 | "Water bottle": "water bottle", 108 | "Weight (exercise)": "dumbbell", 109 | "Weight scale": "scale, weighing machine", 110 | "Wheel": "car wheel; paddlewheel, paddle wheel", 111 | "Whistle": "whistle", 112 | "Wine bottle": "wine bottle", 113 | "Winter glove": "mitten", 114 | "Wok": "wok" 115 | } -------------------------------------------------------------------------------- 
/utils/rename_imagenet_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser(description='ImageNetV2 Rename') 8 | parser.add_argument('--imagenetv2_path', metavar='IMAGENET_DIR', 9 | help='path to the existing full ImageNetV2 dataset') 10 | parser.add_argument('--datafile', help='path to labels file containing class name and folder name mapping') 11 | args = parser.parse_args() 12 | 13 | folder_names = os.listdir(args.imagenetv2_path) 14 | 15 | df = pd.read_csv(args.datafile) 16 | 17 | classnames = [classname.split('/')[1] for classname in list(df['image'])] 18 | classnames = [i for n, i in enumerate(classnames) if i not in classnames[:n]] 19 | 20 | for i in tqdm(range(1000)): 21 | os.rename(os.path.join(args.imagenetv2_path, f'{i}'), os.path.join(args.imagenetv2_path, f'{classnames[i]}')) -------------------------------------------------------------------------------- /utils/subset_objectnet_im.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | def create_subset(args): 8 | 9 | dirname = 'Objectnet/val' 10 | os.makedirs(os.path.join(args.im100_root, dirname), exist_ok = True) 11 | 12 | with open(args.folder_to_label, 'r') as f: 13 | folder_to_label = json.load(f) 14 | label_to_folder = {v:k for k,v in folder_to_label.items()} 15 | 16 | with open(args.objectnet_to_im100_folder, 'r') as f: 17 | objectnet_to_im100_folder = json.load(f) 18 | 19 | for label in tqdm(objectnet_to_im100_folder): 20 | folder_name = label_to_folder[label] 21 | images = os.listdir(os.path.join(args.objectnet_root, folder_name)) 22 | im100_folder_name = objectnet_to_im100_folder[label] 23 | os.makedirs(os.path.join(args.im100_root, dirname, im100_folder_name), exist_ok = True) 24 | print(im100_folder_name) 25 | for image in images: 26 | shutil.copy(os.path.join(args.objectnet_root, folder_name, image), os.path.join(args.im100_root, dirname, im100_folder_name)) 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser(description='Create ObjectNet subset restricted to ImageNet-100 classes') 30 | parser.add_argument('--objectnet_root', type=str, default=None) 31 | parser.add_argument('--im100_root', type=str, default=None) 32 | parser.add_argument('--folder_to_label', type=str, default=None) 33 | parser.add_argument('--objectnet_to_im100_folder', type=str, default=None) 34 | args = parser.parse_args() 35 | 36 | create_subset(args) --------------------------------------------------------------------------------
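The .ffcv files produced by train/write_imagenet.py and write_imagenet.sh above are meant to be consumed by train/train_imagenet.py. As a minimal sketch of that consumption, assuming the standard FFCV Loader API from the ffcv-imagenet codebase the writer is taken from, the example below opens a written .ffcv file and iterates over batches; the path, crop size, batch size, and device are illustrative placeholders, and train_imagenet.py defines its own, more complete pipelines (including normalization and distributed loading).

# Sketch (not repo code): reading a dataset written by train/write_imagenet.py
# with the FFCV Loader. Path, crop size, batch size, and device are placeholders.
import torch
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Squeeze

device = torch.device('cuda:0')
loader = Loader(
    '/path/to/train_500_0.50_90.ffcv',   # placeholder for a file written by write_imagenet.sh
    batch_size=128,
    num_workers=12,
    order=OrderOption.RANDOM,
    pipelines={
        'image': [RandomResizedCropRGBImageDecoder((176, 176)),
                  ToTensor(), ToDevice(device, non_blocking=True), ToTorchImage()],
        'label': [IntDecoder(), ToTensor(), Squeeze(), ToDevice(device, non_blocking=True)],
    },
)

for images, labels in loader:
    pass  # feed batches to the model (normalization omitted here for brevity)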