├── .gitignore ├── requirements.txt ├── datasets └── download_datasets.sh ├── scripts ├── swav │ └── train_interpreter.sh ├── swav_w2 │ └── train_interpreter.sh ├── mae │ └── train_interpreter.sh ├── datasetDDPM │ ├── train_interpreter.sh │ ├── generate_dataset.sh │ └── train_deeplab.sh └── ddpm │ └── train_interpreter.sh ├── checkpoints ├── ddpm │ └── download_checkpoint.sh ├── mae │ └── download_checkpoint.sh ├── swav │ └── download_checkpoint.sh └── swav_w2 │ └── download_checkpoint.sh ├── synthetic_datasets ├── ddpm │ └── download_synthetic_dataset.sh └── gan │ └── download_synthetic_dataset.sh ├── .gitmodules ├── experiments ├── cat_15 │ ├── swav.json │ ├── mae.json │ ├── swav_w2.json │ ├── ddpm.json │ └── datasetDDPM.json ├── ffhq_34 │ ├── mae.json │ ├── swav.json │ ├── swav_w2.json │ ├── ddpm.json │ └── datasetDDPM.json ├── celeba_19 │ ├── swav.json │ ├── mae.json │ ├── swav_w2.json │ └── ddpm.json ├── bedroom_28 │ ├── mae.json │ ├── swav.json │ ├── swav_w2.json │ ├── ddpm.json │ └── datasetDDPM.json ├── ade_bedroom_30 │ ├── mae.json │ ├── swav.json │ ├── swav_w2.json │ └── ddpm.json └── horse_21 │ ├── mae.json │ ├── swav.json │ ├── swav_w2.json │ ├── datasetDDPM.json │ └── ddpm.json ├── LICENSE ├── src ├── utils.py ├── datasets.py ├── pixel_classifier.py ├── data_util.py └── feature_extractors.py ├── generate_dataset.py ├── train_interpreter.py ├── train_deeplab.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *__pycache__/ 3 | *.pt 4 | *.pth 5 | pixel_classifiers 6 | 7 | *.npy 8 | *.jpg 9 | *.png 10 | *.npz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.0 2 | torchvision==0.11.1 3 | blobfile==1.2.7 4 | tqdm==4.62.3 5 | opencv-python==4.5.4.60 6 | mpi4py 7 | timm==0.4.12 -------------------------------------------------------------------------------- /datasets/download_datasets.sh: -------------------------------------------------------------------------------- 1 | wget -c https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/datasets.tar.gz 2 | tar -xzf datasets.tar.gz -C datasets/ 3 | rm datasets.tar.gz -------------------------------------------------------------------------------- /scripts/swav/train_interpreter.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 2 | 3 | python train_interpreter.py --exp experiments/${DATASET}/swav.json -------------------------------------------------------------------------------- /scripts/swav_w2/train_interpreter.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 2 | 3 | python train_interpreter.py --exp experiments/${DATASET}/swav_w2.json -------------------------------------------------------------------------------- /scripts/mae/train_interpreter.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=1 2 | 3 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 4 | 5 | python train_interpreter.py --exp experiments/${DATASET}/mae.json -------------------------------------------------------------------------------- 
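Note: the train_interpreter.sh scripts above (and the ddpm/datasetDDPM variants further down) simply forward their dataset argument to train_interpreter.py together with a config from experiments/<dataset>/<model>.json. The following is a minimal, illustrative Python sketch — not a file from this repository; experiments/ffhq_34/ddpm.json is only used here as an example config — of how such a config is loaded and merged with command-line flags, mirroring the pattern used in train_interpreter.py and generate_dataset.py further below.

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='experiments/ffhq_34/ddpm.json')
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()

# The JSON config supplies model_type, dataset paths, feature dimensions, steps/blocks, etc.
opts = json.load(open(args.exp, 'r'))
# Command-line flags (e.g. the diffusion MODEL_FLAGS) are merged on top of the JSON fields.
opts.update(vars(args))
# dim = [H, W, C]: working resolution and per-pixel feature dimension.
opts['image_size'] = opts['dim'][0]

print(opts['model_type'], opts['steps'], opts['blocks'], opts['dim'][-1])

For the DDPM-based configs, steps and blocks select the diffusion timesteps and UNet blocks whose activations are collected and concatenated into the dim[-1]-dimensional per-pixel features on which the pixel classifiers are trained.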
/checkpoints/ddpm/download_checkpoint.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 # Available datasets: lsun_bedroom, ffhq, lsun_cat, lsun_horse 2 | 3 | wget https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/ddpm_checkpoints/${DATASET}.pt -P checkpoints/ddpm/ -------------------------------------------------------------------------------- /checkpoints/mae/download_checkpoint.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 # Available datasets: lsun_bedroom, ffhq, lsun_cat, lsun_horse 2 | 3 | wget https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/mae_checkpoints/${DATASET}.pth -P checkpoints/mae -------------------------------------------------------------------------------- /checkpoints/swav/download_checkpoint.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 # Available datasets: lsun_bedroom, ffhq, lsun_cat, lsun_horse 2 | 3 | wget https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_checkpoints/${DATASET}.pth -P checkpoints/swav -------------------------------------------------------------------------------- /checkpoints/swav_w2/download_checkpoint.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 # Available datasets: lsun_bedroom, ffhq, lsun_cat, lsun_horse 2 | 3 | wget https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_w2_checkpoints/${DATASET}.pth -P checkpoints/swav_w2 -------------------------------------------------------------------------------- /synthetic_datasets/ddpm/download_synthetic_dataset.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 2 | 3 | wget -c https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/synthetic-datasets/ddpm/${DATASET}.tar.gz 4 | tar -xzf ${DATASET}.tar.gz -C synthetic_datasets/ddpm/ 5 | rm ${DATASET}.tar.gz -------------------------------------------------------------------------------- /synthetic_datasets/gan/download_synthetic_dataset.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 2 | 3 | wget -c https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/synthetic-datasets/gan/${DATASET}.tar.gz 4 | tar -xzf ${DATASET}.tar.gz -C synthetic_datasets/gan/ 5 | rm ${DATASET}.tar.gz -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "guided_diffusion"] 2 | path = guided_diffusion 3 | url = https://github.com/openai/guided-diffusion 4 | [submodule "swav"] 5 | path = swav 6 | url = https://github.com/facebookresearch/swav 7 | [submodule "mae"] 8 | path = mae 9 | url = https://github.com/dbaranchuk/mae 10 | -------------------------------------------------------------------------------- /scripts/datasetDDPM/train_interpreter.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 2 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21 3 | 4 | python 
train_interpreter.py --exp experiments/${DATASET}/datasetDDPM.json $MODEL_FLAGS 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /scripts/ddpm/train_interpreter.sh: -------------------------------------------------------------------------------- 1 | # Note: do not forget to change MODEL_FLAGS if other pretrained DDPMs are used. 2 | 3 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 4 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 5 | 6 | python train_interpreter.py --exp experiments/${DATASET}/ddpm.json $MODEL_FLAGS -------------------------------------------------------------------------------- /experiments/cat_15/swav.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/cat_15/swav", 3 | "model_type": "swav", 4 | 5 | "category": "cat_15", 6 | "number_class": 15, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/cat_15/real/train", 10 | "validation_path": "datasets/cat_15/ddpm", 11 | "testing_path": "datasets/cat_15/real/test", 12 | "model_path": "checkpoints/swav/lsun_cat.pth", 13 | 14 | "dim": [256, 256, 6720], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/cat_15/mae.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/cat_15/mae", 3 | "model_type": "mae", 4 | 5 | "category": "cat_15", 6 | "number_class": 15, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/cat_15/real/train", 10 | "validation_path": "datasets/cat_15/ddpm", 11 | "testing_path": "datasets/cat_15/real/test", 12 | "model_path": "checkpoints/mae/lsun_cat.pth", 13 | 14 | "dim": [256, 256, 12288], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": false 27 | } 28 | -------------------------------------------------------------------------------- /experiments/ffhq_34/mae.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ffhq_34/mae", 3 | "model_type": "mae", 4 | 5 | "category": "ffhq_34", 6 | "number_class": 34, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ffhq_34/real/train", 10 | "validation_path": "datasets/ffhq_34/ddpm", 11 | "testing_path": "datasets/ffhq_34/real/test", 12 | "model_path": "checkpoints/mae/ffhq.pth", 13 | 14 | "dim": [256, 256, 12288], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": false 27 | } 28 | -------------------------------------------------------------------------------- /experiments/ffhq_34/swav.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ffhq_34/swav", 3 | "model_type": "swav", 4 | 5 | "category": "ffhq_34", 6 | "number_class": 34, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ffhq_34/real/train", 10 | "validation_path": "datasets/ffhq_34/ddpm", 11 | "testing_path": "datasets/ffhq_34/real/test", 12 | "model_path": "checkpoints/swav/ffhq.pth", 13 | 14 | "dim": [256, 256, 6720], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } 28 | -------------------------------------------------------------------------------- /experiments/cat_15/swav_w2.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/cat_15/swav_w2", 3 | "model_type": "swav_w2", 4 | 5 | "category": "cat_15", 6 | "number_class": 15, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/cat_15/real/train", 10 | "validation_path": "datasets/cat_15/ddpm", 11 | "testing_path": "datasets/cat_15/real/test", 12 | "model_path": "checkpoints/swav_w2/lsun_cat.pth", 13 | 14 | "dim": [256, 256, 13440], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/celeba_19/swav.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/celeba_19/swav", 3 | "model_type": "swav", 4 | 5 | "category": "celeba_19", 6 | "number_class": 19, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/celeba_19/real/train", 10 | "validation_path": "datasets/celeba_19/real/train", 11 | "testing_path": "datasets/celeba_19/real/test", 12 | "model_path": "checkpoints/swav/ffhq.pth", 13 | 14 | "dim": [256, 256, 6720], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 500, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/ffhq_34/swav_w2.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ffhq_34/swav_W2", 3 | "model_type": "swav_w2", 4 | 5 | "category": "ffhq_34", 6 | "number_class": 34, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ffhq_34/real/train", 10 | "validation_path": "datasets/ffhq_34/ddpm", 11 | "testing_path": "datasets/ffhq_34/real/test", 12 | "model_path": "checkpoints/swav_w2/ffhq.pth", 13 | 14 | "dim": [256, 256, 13440], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } 28 | -------------------------------------------------------------------------------- /experiments/bedroom_28/mae.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": 
"pixel_classifiers/bedroom_28/mae", 3 | "model_type": "mae", 4 | 5 | "category": "bedroom_28", 6 | "number_class": 29, 7 | "ignore_label": 0, 8 | 9 | "training_path": "datasets/bedroom_28/real/train", 10 | "validation_path": "datasets/bedroom_28/ddpm", 11 | "testing_path": "datasets/bedroom_28/real/test", 12 | "model_path": "checkpoints/mae/lsun_bedroom.pth", 13 | 14 | "dim": [256, 256, 12288], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 40, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": false 27 | } -------------------------------------------------------------------------------- /experiments/celeba_19/mae.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/celeba_19/mae", 3 | "model_type": "mae", 4 | 5 | "category": "celeba_19", 6 | "number_class": 19, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/celeba_19/real/train", 10 | "validation_path": "datasets/celeba_19/real/train", 11 | "testing_path": "datasets/celeba_19/real/test", 12 | "model_path": "checkpoints/mae/ffhq.pth", 13 | 14 | "dim": [256, 256, 12288], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 500, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": false 27 | } 28 | -------------------------------------------------------------------------------- /experiments/bedroom_28/swav.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/bedroom_28/swav", 3 | "model_type": "swav", 4 | 5 | "category": "bedroom_28", 6 | "number_class": 29, 7 | "ignore_label": 0, 8 | 9 | "training_path": "datasets/bedroom_28/real/train", 10 | "validation_path": "datasets/bedroom_28/ddpm", 11 | "testing_path": "datasets/bedroom_28/real/test", 12 | "model_path": "checkpoints/swav/lsun_bedroom.pth", 13 | 14 | "dim": [256, 256, 6720], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 40, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/celeba_19/swav_w2.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/celeba_19/swav_w2", 3 | "model_type": "swav_w2", 4 | 5 | "category": "celeba_19", 6 | "number_class": 19, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/celeba_19/real/train", 10 | "validation_path": "datasets/celeba_19/real/train", 11 | "testing_path": "datasets/celeba_19/real/test", 12 | "model_path": "checkpoints/swav_w2/ffhq.pth", 13 | 14 | "dim": [256, 256, 13440], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 500, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/bedroom_28/swav_w2.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/bedroom_28/swav_w2", 3 | "model_type": "swav_w2", 4 | 5 | 
"category": "bedroom_28", 6 | "number_class": 29, 7 | "ignore_label": 0, 8 | 9 | "training_path": "datasets/bedroom_28/real/train", 10 | "validation_path": "datasets/bedroom_28/ddpm", 11 | "testing_path": "datasets/bedroom_28/real/test", 12 | "model_path": "checkpoints/swav_w2/lsun_bedroom.pth", 13 | 14 | "dim": [256, 256, 13440], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 40, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/ffhq_34/ddpm.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ffhq_34/ddpm", 3 | "model_type": "ddpm", 4 | 5 | "category": "ffhq_34", 6 | "number_class": 34, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ffhq_34/real/train", 10 | "validation_path": "datasets/ffhq_34/ddpm", 11 | "testing_path": "datasets/ffhq_34/real/test", 12 | "model_path": "checkpoints/ddpm/ffhq.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "share_noise": true, 27 | "input_activations": false 28 | } -------------------------------------------------------------------------------- /experiments/ade_bedroom_30/mae.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ade_bedroom_30/mae", 3 | "model_type": "mae", 4 | 5 | "category": "ade_bedroom_30", 6 | "number_class": 30, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ade_bedroom_30/real/train", 10 | "validation_path": "datasets/ade_bedroom_30/real/train", 11 | "testing_path": "datasets/ade_bedroom_30/real/test", 12 | "model_path": "checkpoints/mae/lsun_bedroom.pth", 13 | 14 | "dim": [256, 256, 12288], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "upsample_mode":"bilinear", 23 | "training_number": 50, 24 | "testing_number": 650, 25 | 26 | "input_activations": false 27 | } -------------------------------------------------------------------------------- /experiments/ade_bedroom_30/swav.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ade_bedroom_30/swav", 3 | "model_type": "swav", 4 | 5 | "category": "ade_bedroom_30", 6 | "number_class": 30, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ade_bedroom_30/real/train", 10 | "validation_path": "datasets/ade_bedroom_30/real/train", 11 | "testing_path": "datasets/ade_bedroom_30/real/test", 12 | "model_path": "checkpoints/swav/lsun_bedroom.pth", 13 | 14 | "dim": [256, 256, 6720], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "upsample_mode":"bilinear", 23 | "training_number": 50, 24 | "testing_number": 650, 25 | 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/cat_15/ddpm.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/cat_15/ddpm", 3 | "model_type": "ddpm", 4 | 5 | 
"category": "cat_15", 6 | "number_class": 15, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/cat_15/real/train", 10 | "validation_path": "datasets/cat_15/ddpm", 11 | "testing_path": "datasets/cat_15/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_cat.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "share_noise": true, 27 | "input_activations": false 28 | } 29 | -------------------------------------------------------------------------------- /experiments/horse_21/mae.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/horse_21/mae", 3 | "model_type": "mae", 4 | 5 | "category": "horse_21", 6 | "number_class": 21, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/horse_21/real/train", 10 | "validation_path": "datasets/horse_21/ddpm", 11 | "testing_path": "datasets/horse_21/real/test", 12 | "model_path": "checkpoints/mae/lsun_horse.pth", 13 | 14 | "dim": [256, 256, 12288], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 30, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": false 27 | } 28 | -------------------------------------------------------------------------------- /experiments/ade_bedroom_30/swav_w2.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ade_bedroom_30/swav_w2", 3 | "model_type": "swav_w2", 4 | 5 | "category": "ade_bedroom_30", 6 | "number_class": 30, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ade_bedroom_30/real/train", 10 | "validation_path": "datasets/ade_bedroom_30/real/train", 11 | "testing_path": "datasets/ade_bedroom_30/real/test", 12 | "model_path": "checkpoints/swav_w2/lsun_bedroom.pth", 13 | 14 | "dim": [256, 256, 13440], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "upsample_mode":"bilinear", 23 | "training_number": 50, 24 | "testing_number": 650, 25 | 26 | "input_activations": true 27 | } -------------------------------------------------------------------------------- /experiments/celeba_19/ddpm.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/celeba_19/ddpm", 3 | "model_type": "ddpm", 4 | 5 | "category": "celeba_19", 6 | "number_class": 19, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/celeba_19/real/train", 10 | "validation_path": "datasets/celeba_19/real/train", 11 | "testing_path": "datasets/celeba_19/real/test", 12 | "model_path": "checkpoints/ddpm/ffhq.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 500, 24 | 25 | "upsample_mode":"bilinear", 26 | "share_noise": true, 27 | "input_activations": false 28 | } -------------------------------------------------------------------------------- /experiments/horse_21/swav.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/horse_21/swav", 3 | 
"model_type": "swav", 4 | 5 | "category": "horse_21", 6 | "number_class": 21, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/horse_21/real/train", 10 | "validation_path": "datasets/horse_21/ddpm", 11 | "testing_path": "datasets/horse_21/real/test", 12 | "model_path": "checkpoints/swav/lsun_horse.pth", 13 | 14 | "dim": [256, 256, 6720], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 30, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } 28 | -------------------------------------------------------------------------------- /experiments/horse_21/swav_w2.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/horse_21/swav_w2", 3 | "model_type": "swav_w2", 4 | 5 | "category": "horse_21", 6 | "number_class": 21, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/horse_21/real/train", 10 | "validation_path": "datasets/horse_21/ddpm", 11 | "testing_path": "datasets/horse_21/real/test", 12 | "model_path": "checkpoints/swav_w2/lsun_horse.pth", 13 | 14 | "dim": [256, 256, 13440], 15 | "steps": [], 16 | "blocks": [], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 30, 24 | 25 | "upsample_mode":"bilinear", 26 | "input_activations": true 27 | } 28 | -------------------------------------------------------------------------------- /experiments/bedroom_28/ddpm.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/bedroom_28/ddpm", 3 | "model_type": "ddpm", 4 | 5 | "category": "bedroom_28", 6 | "number_class": 29, 7 | "ignore_label": 0, 8 | 9 | "training_path": "datasets/bedroom_28/real/train", 10 | "validation_path": "datasets/bedroom_28/ddpm", 11 | "testing_path": "datasets/bedroom_28/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_bedroom.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 40, 23 | "testing_number": 20, 24 | 25 | "upsample_mode":"bilinear", 26 | "share_noise": true, 27 | "input_activations": false 28 | } 29 | 30 | -------------------------------------------------------------------------------- /experiments/cat_15/datasetDDPM.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/cat_15/datasetDDPM", 3 | "model_type": "ddpm", 4 | 5 | "category": "cat_15", 6 | "number_class": 15, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/cat_15/ddpm", 10 | "validation_path": "datasets/cat_15/real/train", 11 | "testing_path": "datasets/cat_15/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_cat.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 20, 24 | 25 | "deeplab_res": 256, 26 | "upsample_mode":"bilinear", 27 | "share_noise": true, 28 | "input_activations": false 29 | } 30 | -------------------------------------------------------------------------------- /experiments/ade_bedroom_30/ddpm.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"exp_dir": "pixel_classifiers/ade_bedroom_30/ddpm", 3 | "model_type": "ddpm", 4 | 5 | "category": "ade_bedroom_30", 6 | "number_class": 30, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ade_bedroom_30/real/train", 10 | "validation_path": "datasets/ade_bedroom_30/real/train", 11 | "testing_path": "datasets/ade_bedroom_30/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_bedroom.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "upsample_mode":"bilinear", 23 | "training_number": 50, 24 | "testing_number": 650, 25 | 26 | "share_noise": true, 27 | "input_activations": false 28 | } -------------------------------------------------------------------------------- /experiments/horse_21/datasetDDPM.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/horse_21/datasetDDPM", 3 | "model_type": "ddpm", 4 | 5 | "category": "horse_21", 6 | "number_class": 21, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/horse_21/ddpm", 10 | "validation_path": "datasets/horse_21/real/train", 11 | "testing_path": "datasets/horse_21/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_horse.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 30, 24 | 25 | "deeplab_res": 256, 26 | "upsample_mode":"bilinear", 27 | "share_noise": true, 28 | "input_activations": false 29 | } 30 | -------------------------------------------------------------------------------- /experiments/bedroom_28/datasetDDPM.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/bedroom_28/datasetDDPM", 3 | "model_type": "ddpm", 4 | 5 | "category": "bedroom_28", 6 | "number_class": 29, 7 | "ignore_label": 0, 8 | 9 | "training_path": "datasets/bedroom_28/ddpm", 10 | "validation_path": "datasets/bedroom_28/real/train", 11 | "testing_path": "datasets/bedroom_28/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_bedroom.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 40, 23 | "testing_number": 20, 24 | 25 | "deeplab_res": 256, 26 | "upsample_mode":"bilinear", 27 | "share_noise": true, 28 | "input_activations": false 29 | } -------------------------------------------------------------------------------- /experiments/horse_21/ddpm.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/horse_21/ddpm", 3 | "model_type": "ddpm", 4 | 5 | "category": "horse_21", 6 | "number_class": 21, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/horse_21/real/train", 10 | "validation_path": "datasets/horse_21/ddpm", 11 | "testing_path": "datasets/horse_21/real/test", 12 | "model_path": "checkpoints/ddpm/lsun_horse.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 30, 23 | "testing_number": 30, 24 | 25 | "upsample_mode":"bilinear", 26 | "share_noise": true, 27 | "input_activations": false 28 | } 
29 | -------------------------------------------------------------------------------- /experiments/ffhq_34/datasetDDPM.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "pixel_classifiers/ffhq_34/datasetDDPM", 3 | "model_type": "ddpm", 4 | 5 | "category": "ffhq_34", 6 | "number_class": 34, 7 | "ignore_label": 255, 8 | 9 | "training_path": "datasets/ffhq_34/ddpm", 10 | "validation_path": "datasets/ffhq_34/real/train", 11 | "testing_path": "datasets/ffhq_34/real/test", 12 | "model_path": "checkpoints/ddpm/ffhq.pt", 13 | 14 | "dim": [256, 256, 8448], 15 | "steps": [50, 150, 250], 16 | "blocks": [5, 6, 7, 8, 12], 17 | 18 | "model_num": 10, 19 | "batch_size": 64, 20 | "max_training": 30, 21 | 22 | "training_number": 20, 23 | "testing_number": 20, 24 | 25 | "deeplab_res": 256, 26 | "upsample_mode":"bilinear", 27 | "share_noise": true, 28 | "input_activations": false 29 | } -------------------------------------------------------------------------------- /scripts/datasetDDPM/generate_dataset.sh: -------------------------------------------------------------------------------- 1 | # Notes: 2 | # * This setting assumes a single machine with 8xA100 80GB GPUs. Please adjust the arguments for your infrastructure. 3 | # * The synthetic dataset is saved to the experiment folder, where the ensemble models are placed. (Produced by train_interpreter.sh) 4 | 5 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 6 | SAMPLE_FLAGS="--batch_size 100 --num_samples 10000 --timestep_respacing 1000 --use_ddim False" 7 | 8 | NUM_GPUS="8" 9 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21 10 | 11 | echo "Generating a synthetic dataset..." 12 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS generate_dataset.py --exp experiments/${DATASET}/datasetDDPM.json $MODEL_FLAGS $SAMPLE_FLAGS 13 | 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yandex Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /scripts/datasetDDPM/train_deeplab.sh: -------------------------------------------------------------------------------- 1 | # Arguments: 2 | 3 | # --data_path: your path to the synthetic data 4 | # Our released dataset: synthetic_datasets/ddpm/bedroom_28/samples_256x256x3.npz 5 | # Produced by generate_dataset.sh: set nothing or "". By default, it uses the synthetic dataset from the experiment directory. 6 | 7 | # --max_data: number of synthetic images to use for training (default: 50000). 8 | # One can consider increasing it up to 50000 to get an extra 2-4% mIoU. 9 | 10 | # --uncertainty_portion: the portion of samples with the most uncertain predictions to remove (default: 0.1) 11 | # 0.2-0.25 can sometimes provide slightly better performance. 12 | 13 | # Note: One can use this script for evaluation as well. 14 | # The evaluation is performed right after the training. 15 | # The script checks whether all checkpoints are available. 16 | # If yes, the evaluation starts immediately without retraining. 17 | 18 | DATASET=$1 # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21 19 | 20 | CUDA_VISIBLE_DEVICES=7 python train_deeplab.py \ 21 | --data_path synthetic_datasets/ddpm/${DATASET}/samples_256x256x3.npz \ 22 | --max_data 50000 \ 23 | --uncertainty_portion 0.1 \ 24 | --exp experiments/${DATASET}/datasetDDPM.json -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 | """ 22 | 23 | import torch 24 | from PIL import Image 25 | import numpy as np 26 | import random 27 | 28 | 29 | def multi_acc(y_pred, y_test): 30 | y_pred_softmax = torch.log_softmax(y_pred, dim=1) 31 | _, y_pred_tags = torch.max(y_pred_softmax, dim=1) 32 | 33 | correct_pred = (y_pred_tags == y_test).float() 34 | acc = correct_pred.sum() / len(correct_pred) 35 | 36 | acc = acc * 100 37 | 38 | return acc 39 | 40 | 41 | def oht_to_scalar(y_pred): 42 | y_pred_softmax = torch.log_softmax(y_pred, dim=1) 43 | _, y_pred_tags = torch.max(y_pred_softmax, dim=1) 44 | 45 | return y_pred_tags 46 | 47 | 48 | def colorize_mask(mask, palette): 49 | # mask: numpy array of the mask 50 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 51 | new_mask.putpalette(palette) 52 | return np.array(new_mask.convert('RGB')) 53 | 54 | 55 | def to_labels(masks, palette): 56 | results = np.zeros((len(masks), 256, 256), dtype=np.int32) 57 | label = 0 58 | for color in palette: 59 | idxs = np.where((masks == color).all(-1)) 60 | results[idxs] = label 61 | label += 1 62 | return results 63 | 64 | 65 | def setup_seed(seed): 66 | print('Seed: ', seed) 67 | random.seed(seed) 68 | np.random.seed(seed) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed_all(seed) -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from guided_diffusion.guided_diffusion.image_datasets import _list_image_files_recursively 8 | 9 | 10 | def make_transform(model_type: str, resolution: int): 11 | """ Define input transforms for pretrained models """ 12 | if model_type == 'ddpm': 13 | transform = transforms.Compose([ 14 | transforms.Resize(resolution), 15 | transforms.ToTensor(), 16 | lambda x: 2 * x - 1 17 | ]) 18 | elif model_type in ['mae', 'swav', 'swav_w2', 'deeplab']: 19 | transform = transforms.Compose([ 20 | transforms.Resize(resolution), 21 | transforms.ToTensor(), 22 | transforms.Normalize( 23 | mean=[0.485, 0.456, 0.406], 24 | std=[0.229, 0.224, 0.225] 25 | ) 26 | ]) 27 | else: 28 | raise Exception(f"Wrong model type: {model_type}") 29 | return transform 30 | 31 | 32 | class FeatureDataset(Dataset): 33 | ''' 34 | Dataset of the pixel representations and their labels. 35 | 36 | :param X_data: pixel representations [num_pixels, feature_dim] 37 | :param y_data: pixel labels [num_pixels] 38 | ''' 39 | def __init__( 40 | self, 41 | X_data: torch.Tensor, 42 | y_data: torch.Tensor 43 | ): 44 | self.X_data = X_data 45 | self.y_data = y_data 46 | 47 | def __getitem__(self, index): 48 | return self.X_data[index], self.y_data[index] 49 | 50 | def __len__(self): 51 | return len(self.X_data) 52 | 53 | 54 | class ImageLabelDataset(Dataset): 55 | ''' 56 | :param data_dir: path to a folder with images and their annotations. 57 | Annotations are supposed to be in *.npy format. 58 | :param resolution: image and mask output resolution. 59 | :param num_images: restrict a number of images in the dataset. 60 | :param transform: image transforms. 
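Note: for each image, the label is read from a .npy file stored next to it under the same name (e.g. an image img_0.png is paired with img_0.npy); see self.label_paths below.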
61 | ''' 62 | def __init__( 63 | self, 64 | data_dir: str, 65 | resolution: int, 66 | num_images= -1, 67 | transform=None, 68 | ): 69 | super().__init__() 70 | self.resolution = resolution 71 | self.transform = transform 72 | self.image_paths = _list_image_files_recursively(data_dir) 73 | self.image_paths = sorted(self.image_paths) 74 | 75 | if num_images > 0: 76 | print(f"Take first {num_images} images...") 77 | self.image_paths = self.image_paths[:num_images] 78 | 79 | self.label_paths = [ 80 | '.'.join(image_path.split('.')[:-1] + ['npy']) 81 | for image_path in self.image_paths 82 | ] 83 | 84 | def __len__(self): 85 | return len(self.image_paths) 86 | 87 | def __getitem__(self, idx): 88 | # Load an image 89 | image_path = self.image_paths[idx] 90 | pil_image = Image.open(image_path) 91 | pil_image = pil_image.convert("RGB") 92 | assert pil_image.size[0] == pil_image.size[1], \ 93 | f"Only square images are supported: ({pil_image.size[0]}, {pil_image.size[1]})" 94 | 95 | tensor_image = self.transform(pil_image) 96 | # Load a corresponding mask and resize it to (self.resolution, self.resolution) 97 | label_path = self.label_paths[idx] 98 | label = np.load(label_path).astype('uint8') 99 | label = cv2.resize( 100 | label, (self.resolution, self.resolution), interpolation=cv2.INTER_NEAREST 101 | ) 102 | tensor_label = torch.from_numpy(label) 103 | return tensor_image, tensor_label 104 | 105 | 106 | class InMemoryImageLabelDataset(Dataset): 107 | ''' 108 | 109 | Same as ImageLabelDataset but images and labels are already loaded into RAM. 110 | It handles DDPM/GAN-produced datasets and is used to train DeepLabV3. 111 | 112 | :param images: np.array of image samples [num_images, H, W, 3]. 113 | :param labels: np.array of corresponding masks [num_images, H, W]. 114 | :param resolution: image and mask output resolution. 115 | 116 | :param transform: image transforms. 117 | ''' 118 | 119 | def __init__( 120 | self, 121 | images: np.ndarray, 122 | labels: np.ndarray, 123 | resolution=256, 124 | transform=None 125 | ): 126 | super().__init__() 127 | assert len(images) == len(labels) 128 | self.images = images 129 | self.labels = labels 130 | self.resolution = resolution 131 | self.transform = transform 132 | 133 | def __len__(self): 134 | return len(self.images) 135 | 136 | def __getitem__(self, idx): 137 | image = Image.fromarray(self.images[idx]) 138 | assert image.size[0] == image.size[1], \ 139 | f"Only square images are supported: ({image.size[0]}, {image.size[1]})" 140 | 141 | tensor_image = self.transform(image) 142 | label = self.labels[idx] 143 | label = cv2.resize( 144 | label, (self.resolution, self.resolution), interpolation=cv2.INTER_NEAREST 145 | ) 146 | tensor_label = torch.from_numpy(label) 147 | return tensor_image, tensor_label 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /generate_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation.
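In this repository the script additionally runs the trained pixel-classifier ensemble on every generated sample and stores the images, their predicted masks, and per-sample uncertainty scores together in a single .npz file (see save_samples below).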
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | from tqdm import tqdm 10 | 11 | import json 12 | import random 13 | import numpy as np 14 | import torch as th 15 | import torch.distributed as dist 16 | 17 | 18 | from guided_diffusion.guided_diffusion import dist_util, logger 19 | from guided_diffusion.guided_diffusion.script_util import ( 20 | model_and_diffusion_defaults, 21 | add_dict_to_argparser 22 | ) 23 | 24 | from src.pixel_classifier import load_ensemble, predict_labels 25 | from src.feature_extractors import create_feature_extractor, collect_features 26 | 27 | 28 | def setup_dist(local_rank): 29 | dist.init_process_group(backend='nccl', init_method='env://') 30 | th.cuda.set_device(local_rank) 31 | 32 | 33 | def save_samples(num_samples, all_images, all_img_segs, all_uncertainties): 34 | arr = np.concatenate(all_images, axis=0).astype('uint8') 35 | arr = arr[: num_samples] 36 | 37 | seg_arr = np.concatenate(all_img_segs, axis=0).astype('uint8') 38 | seg_arr = seg_arr[: num_samples] 39 | 40 | uncertainties = np.concatenate(all_uncertainties, axis=0) 41 | uncertainties = uncertainties[: num_samples] 42 | 43 | if dist.get_rank() == 0: 44 | shape_str = "x".join([str(x) for x in arr.shape[1:]]) 45 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 46 | logger.log(f"saving to {out_path}") 47 | np.savez(out_path, arr, seg_arr, uncertainties) 48 | 49 | 50 | def main(): 51 | args = create_argparser().parse_args() 52 | opts = json.load(open(args.exp, 'r')) 53 | opts.update(vars(args)) 54 | opts['image_size'] = opts['dim'][0] 55 | 56 | if len(opts['steps']) > 0: 57 | suffix = '_'.join([str(step) for step in opts['steps']]) 58 | suffix += '_' + '_'.join([str(step) for step in opts['blocks']]) 59 | opts['exp_dir'] = os.path.join(opts['exp_dir'], suffix) 60 | 61 | os.environ['OPENAI_LOGDIR'] = opts['exp_dir'] 62 | setup_dist(args.local_rank) 63 | logger.configure() 64 | feature_extractor = create_feature_extractor(**opts) 65 | model, diffusion = feature_extractor.model, feature_extractor.diffusion 66 | 67 | logger.log("loading pretrained classifiers...") 68 | classifiers = load_ensemble(opts, device=dist_util.dev()) 69 | 70 | logger.log("Sample noise for feature extraction...") 71 | if opts['share_noise']: 72 | rnd_gen = th.Generator(device=dist_util.dev()).manual_seed(args.seed) 73 | seg_noise = th.randn(1, 3, opts['image_size'], opts['image_size'], 74 | generator=rnd_gen, device=dist_util.dev()) 75 | else: 76 | seg_noise = None 77 | 78 | logger.log("sampling...") 79 | all_images = [] 80 | all_img_segs = [] 81 | all_uncertainties = [] 82 | 83 | while len(all_images) * args.batch_size < args.num_samples: 84 | sample_fn = diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 85 | output = sample_fn( 86 | model, 87 | (args.batch_size, 3, opts['image_size'], opts['image_size']), 88 | clip_denoised=args.clip_denoised 89 | ) 90 | 91 | logger.log("predicting segmentation...") 92 | img_segs = th.zeros(args.batch_size, opts['image_size'], opts['image_size']) 93 | img_segs = img_segs.to(th.uint8).to(dist_util.dev()) 94 | uncertainties = th.zeros(args.batch_size).to(dist_util.dev()) 95 | 96 | for sample_idx in tqdm(range(args.batch_size)): 97 | img = output[sample_idx][None].clamp(-1, 1) 98 | features = feature_extractor(img, noise=seg_noise) 99 | features = collect_features(opts, features) 100 | 101 | x = features.view(opts['dim'][-1], -1).permute(1, 0) 102 | img_seg, uncertainty = predict_labels( 103 | classifiers, x, size=opts['dim'][:-1] 104 | ) 105 | 
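# predict_labels (defined in src/pixel_classifier.py below) returns the ensemble's
# majority-vote per-pixel segmentation plus a scalar uncertainty score (the mean
# divergence over the ~10% most uncertain pixels); the score is stored per sample so
# that train_deeplab.py can later drop the most uncertain fraction via --uncertainty_portion.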
img_segs[sample_idx] = img_seg.to(th.uint8) 106 | uncertainties[sample_idx] = uncertainty.item() 107 | 108 | sample = ((output + 1) * 127.5).clamp(0, 255).to(th.uint8) 109 | sample = sample.permute(0, 2, 3, 1) 110 | sample = sample.contiguous() 111 | 112 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 113 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 114 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 115 | 116 | gathered_img_segs = [th.zeros_like(img_segs) for _ in range(dist.get_world_size())] 117 | dist.all_gather(gathered_img_segs, img_segs) # gather not supported with NCCL 118 | all_img_segs.extend([img_seg.cpu().numpy() for img_seg in gathered_img_segs]) 119 | 120 | gathered_uncertainties = [th.zeros_like(uncertainties) for _ in range(dist.get_world_size())] 121 | dist.all_gather(gathered_uncertainties, uncertainties) # gather not supported with NCCL 122 | all_uncertainties.extend([uncertainty.cpu().numpy() for uncertainty in gathered_uncertainties]) 123 | 124 | logger.log(f"created {len(all_images) * args.batch_size} samples") 125 | save_samples(args.num_samples, all_images, all_img_segs, all_uncertainties) 126 | 127 | dist.barrier() 128 | logger.log("sampling complete") 129 | 130 | 131 | def create_argparser(): 132 | parser = argparse.ArgumentParser() 133 | add_dict_to_argparser(parser, model_and_diffusion_defaults()) 134 | 135 | parser.add_argument('--exp', type=str) 136 | parser.add_argument('--seed', type=int, default=0) 137 | parser.add_argument('--batch_size', type=int, default=100) 138 | parser.add_argument('--num_samples', type=int, default=10000) 139 | parser.add_argument('--local_rank', type=int) 140 | 141 | parser.add_argument('--clip_denoised', type=bool, default=True) 142 | parser.add_argument('--use_ddim', type=bool, default=False) 143 | return parser 144 | 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /src/pixel_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from collections import Counter 6 | 7 | from torch.distributions import Categorical 8 | from src.utils import colorize_mask, oht_to_scalar 9 | from src.data_util import get_palette, get_class_names 10 | from PIL import Image 11 | 12 | 13 | # Adopted from https://github.com/nv-tlabs/datasetGAN_release/blob/d9564d4d2f338eaad78132192b865b6cc1e26cac/datasetGAN/train_interpreter.py#L68 14 | class pixel_classifier(nn.Module): 15 | def __init__(self, numpy_class, dim): 16 | super(pixel_classifier, self).__init__() 17 | if numpy_class < 30: 18 | self.layers = nn.Sequential( 19 | nn.Linear(dim, 128), 20 | nn.ReLU(), 21 | nn.BatchNorm1d(num_features=128), 22 | nn.Linear(128, 32), 23 | nn.ReLU(), 24 | nn.BatchNorm1d(num_features=32), 25 | nn.Linear(32, numpy_class) 26 | ) 27 | else: 28 | self.layers = nn.Sequential( 29 | nn.Linear(dim, 256), 30 | nn.ReLU(), 31 | nn.BatchNorm1d(num_features=256), 32 | nn.Linear(256, 128), 33 | nn.ReLU(), 34 | nn.BatchNorm1d(num_features=128), 35 | nn.Linear(128, numpy_class) 36 | ) 37 | 38 | def init_weights(self, init_type='normal', gain=0.02): 39 | ''' 40 | initialize network's weights 41 | init_type: normal | xavier | kaiming | orthogonal 42 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 43 | ''' 
44 | 45 | def init_func(m): 46 | classname = m.__class__.__name__ 47 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 48 | if init_type == 'normal': 49 | nn.init.normal_(m.weight.data, 0.0, gain) 50 | elif init_type == 'xavier': 51 | nn.init.xavier_normal_(m.weight.data, gain=gain) 52 | elif init_type == 'kaiming': 53 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 54 | elif init_type == 'orthogonal': 55 | nn.init.orthogonal_(m.weight.data, gain=gain) 56 | 57 | if hasattr(m, 'bias') and m.bias is not None: 58 | nn.init.constant_(m.bias.data, 0.0) 59 | 60 | elif classname.find('BatchNorm2d') != -1: 61 | nn.init.normal_(m.weight.data, 1.0, gain) 62 | nn.init.constant_(m.bias.data, 0.0) 63 | 64 | self.apply(init_func) 65 | 66 | def forward(self, x): 67 | return self.layers(x) 68 | 69 | 70 | def predict_labels(models, features, size): 71 | if isinstance(features, np.ndarray): 72 | features = torch.from_numpy(features) 73 | 74 | mean_seg = None 75 | all_seg = [] 76 | all_entropy = [] 77 | seg_mode_ensemble = [] 78 | 79 | softmax_f = nn.Softmax(dim=1) 80 | with torch.no_grad(): 81 | for MODEL_NUMBER in range(len(models)): 82 | preds = models[MODEL_NUMBER](features.cuda()) 83 | entropy = Categorical(logits=preds).entropy() 84 | all_entropy.append(entropy) 85 | all_seg.append(preds) 86 | 87 | if mean_seg is None: 88 | mean_seg = softmax_f(preds) 89 | else: 90 | mean_seg += softmax_f(preds) 91 | 92 | img_seg = oht_to_scalar(preds) 93 | img_seg = img_seg.reshape(*size) 94 | img_seg = img_seg.cpu().detach() 95 | 96 | seg_mode_ensemble.append(img_seg) 97 | 98 | mean_seg = mean_seg / len(all_seg) 99 | 100 | full_entropy = Categorical(mean_seg).entropy() 101 | 102 | js = full_entropy - torch.mean(torch.stack(all_entropy), 0) 103 | top_k = js.sort()[0][- int(js.shape[0] / 10):].mean() 104 | 105 | img_seg_final = torch.stack(seg_mode_ensemble, dim=-1) 106 | img_seg_final = torch.mode(img_seg_final, 2)[0] 107 | return img_seg_final, top_k 108 | 109 | 110 | def save_predictions(args, image_paths, preds): 111 | palette = get_palette(args['category']) 112 | os.makedirs(os.path.join(args['exp_dir'], 'predictions'), exist_ok=True) 113 | os.makedirs(os.path.join(args['exp_dir'], 'visualizations'), exist_ok=True) 114 | 115 | for i, pred in enumerate(preds): 116 | filename = image_paths[i].split('/')[-1].split('.')[0] 117 | pred = np.squeeze(pred) 118 | np.save(os.path.join(args['exp_dir'], 'predictions', filename + '.npy'), pred) 119 | 120 | mask = colorize_mask(pred, palette) 121 | Image.fromarray(mask).save( 122 | os.path.join(args['exp_dir'], 'visualizations', filename + '.jpg') 123 | ) 124 | 125 | 126 | def compute_iou(args, preds, gts, print_per_class_ious=True): 127 | class_names = get_class_names(args['category']) 128 | 129 | ids = range(args['number_class']) 130 | 131 | unions = Counter() 132 | intersections = Counter() 133 | 134 | for pred, gt in zip(preds, gts): 135 | for target_num in ids: 136 | if target_num == args['ignore_label']: 137 | continue 138 | preds_tmp = (pred == target_num).astype(int) 139 | gts_tmp = (gt == target_num).astype(int) 140 | unions[target_num] += (preds_tmp | gts_tmp).sum() 141 | intersections[target_num] += (preds_tmp & gts_tmp).sum() 142 | 143 | ious = [] 144 | for target_num in ids: 145 | if target_num == args['ignore_label']: 146 | continue 147 | iou = intersections[target_num] / (1e-8 + unions[target_num]) 148 | ious.append(iou) 149 | if print_per_class_ious: 150 | print(f"IOU for {class_names[target_num]} 
{iou:.4}") 151 | return np.array(ious).mean() 152 | 153 | 154 | def load_ensemble(args, device='cpu'): 155 | models = [] 156 | for i in range(args['model_num']): 157 | model_path = os.path.join(args['exp_dir'], f'model_{i}.pth') 158 | state_dict = torch.load(model_path)['model_state_dict'] 159 | model = nn.DataParallel(pixel_classifier(args["number_class"], args['dim'][-1])) 160 | model.load_state_dict(state_dict) 161 | model = model.module.to(device) 162 | models.append(model.eval()) 163 | return models 164 | -------------------------------------------------------------------------------- /train_interpreter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from tqdm import tqdm 4 | import json 5 | import os 6 | import gc 7 | 8 | from torch.utils.data import DataLoader 9 | 10 | import argparse 11 | from src.utils import setup_seed, multi_acc 12 | from src.pixel_classifier import load_ensemble, compute_iou, predict_labels, save_predictions, save_predictions, pixel_classifier 13 | from src.datasets import ImageLabelDataset, FeatureDataset, make_transform 14 | from src.feature_extractors import create_feature_extractor, collect_features 15 | 16 | from guided_diffusion.guided_diffusion.script_util import model_and_diffusion_defaults, add_dict_to_argparser 17 | from guided_diffusion.guided_diffusion.dist_util import dev 18 | 19 | 20 | def prepare_data(args): 21 | feature_extractor = create_feature_extractor(**args) 22 | 23 | print(f"Preparing the train set for {args['category']}...") 24 | dataset = ImageLabelDataset( 25 | data_dir=args['training_path'], 26 | resolution=args['image_size'], 27 | num_images=args['training_number'], 28 | transform=make_transform( 29 | args['model_type'], 30 | args['image_size'] 31 | ) 32 | ) 33 | X = torch.zeros((len(dataset), *args['dim'][::-1]), dtype=torch.float) 34 | y = torch.zeros((len(dataset), *args['dim'][:-1]), dtype=torch.uint8) 35 | 36 | if 'share_noise' in args and args['share_noise']: 37 | rnd_gen = torch.Generator(device=dev()).manual_seed(args['seed']) 38 | noise = torch.randn(1, 3, args['image_size'], args['image_size'], 39 | generator=rnd_gen, device=dev()) 40 | else: 41 | noise = None 42 | 43 | for row, (img, label) in enumerate(tqdm(dataset)): 44 | img = img[None].to(dev()) 45 | features = feature_extractor(img, noise=noise) 46 | X[row] = collect_features(args, features).cpu() 47 | 48 | for target in range(args['number_class']): 49 | if target == args['ignore_label']: continue 50 | if 0 < (label == target).sum() < 20: 51 | print(f'Delete small annotation from image {dataset.image_paths[row]} | label {target}') 52 | label[label == target] = args['ignore_label'] 53 | y[row] = label 54 | 55 | d = X.shape[1] 56 | print(f'Total dimension {d}') 57 | X = X.permute(1,0,2,3).reshape(d, -1).permute(1, 0) 58 | y = y.flatten() 59 | return X[y != args['ignore_label']], y[y != args['ignore_label']] 60 | 61 | 62 | def evaluation(args, models): 63 | feature_extractor = create_feature_extractor(**args) 64 | dataset = ImageLabelDataset( 65 | data_dir=args['testing_path'], 66 | resolution=args['image_size'], 67 | num_images=args['testing_number'], 68 | transform=make_transform( 69 | args['model_type'], 70 | args['image_size'] 71 | ) 72 | ) 73 | 74 | if 'share_noise' in args and args['share_noise']: 75 | rnd_gen = torch.Generator(device=dev()).manual_seed(args['seed']) 76 | noise = torch.randn(1, 3, args['image_size'], args['image_size'], 77 | generator=rnd_gen, device=dev()) 78 | 
else: 79 | noise = None 80 | 81 | preds, gts, uncertainty_scores = [], [], [] 82 | for img, label in tqdm(dataset): 83 | img = img[None].to(dev()) 84 | features = feature_extractor(img, noise=noise) 85 | features = collect_features(args, features) 86 | 87 | x = features.view(args['dim'][-1], -1).permute(1, 0) 88 | pred, uncertainty_score = predict_labels( 89 | models, x, size=args['dim'][:-1] 90 | ) 91 | gts.append(label.numpy()) 92 | preds.append(pred.numpy()) 93 | uncertainty_scores.append(uncertainty_score.item()) 94 | 95 | save_predictions(args, dataset.image_paths, preds) 96 | miou = compute_iou(args, preds, gts) 97 | print(f'Overall mIoU: ', miou) 98 | print(f'Mean uncertainty: {sum(uncertainty_scores) / len(uncertainty_scores)}') 99 | 100 | 101 | # Adopted from https://github.com/nv-tlabs/datasetGAN_release/blob/d9564d4d2f338eaad78132192b865b6cc1e26cac/datasetGAN/train_interpreter.py#L434 102 | def train(args): 103 | features, labels = prepare_data(args) 104 | train_data = FeatureDataset(features, labels) 105 | 106 | print(f" ********* max_label {args['number_class']} *** ignore_label {args['ignore_label']} ***********") 107 | print(f" *********************** Current number data {len(features)} ***********************") 108 | 109 | train_loader = DataLoader(dataset=train_data, batch_size=args['batch_size'], shuffle=True, drop_last=True) 110 | 111 | print(" *********************** Current dataloader length " + str(len(train_loader)) + " ***********************") 112 | for MODEL_NUMBER in range(args['start_model_num'], args['model_num'], 1): 113 | 114 | gc.collect() 115 | classifier = pixel_classifier(numpy_class=(args['number_class']), dim=args['dim'][-1]) 116 | classifier.init_weights() 117 | 118 | classifier = nn.DataParallel(classifier).cuda() 119 | criterion = nn.CrossEntropyLoss() 120 | optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001) 121 | classifier.train() 122 | 123 | iteration = 0 124 | break_count = 0 125 | best_loss = 10000000 126 | stop_sign = 0 127 | for epoch in range(100): 128 | for X_batch, y_batch in train_loader: 129 | X_batch, y_batch = X_batch.to(dev()), y_batch.to(dev()) 130 | y_batch = y_batch.type(torch.long) 131 | 132 | optimizer.zero_grad() 133 | y_pred = classifier(X_batch) 134 | loss = criterion(y_pred, y_batch) 135 | acc = multi_acc(y_pred, y_batch) 136 | 137 | loss.backward() 138 | optimizer.step() 139 | 140 | iteration += 1 141 | if iteration % 1000 == 0: 142 | print('Epoch : ', str(epoch), 'iteration', iteration, 'loss', loss.item(), 'acc', acc) 143 | 144 | if epoch > 3: 145 | if loss.item() < best_loss: 146 | best_loss = loss.item() 147 | break_count = 0 148 | else: 149 | break_count += 1 150 | 151 | if break_count > 50: 152 | stop_sign = 1 153 | print("*************** Break, Total iters,", iteration, ", at epoch", str(epoch), "***************") 154 | break 155 | 156 | if stop_sign == 1: 157 | break 158 | 159 | model_path = os.path.join(args['exp_dir'], 160 | 'model_' + str(MODEL_NUMBER) + '.pth') 161 | MODEL_NUMBER += 1 162 | print('save to:',model_path) 163 | torch.save({'model_state_dict': classifier.state_dict()}, 164 | model_path) 165 | 166 | 167 | if __name__ == '__main__': 168 | parser = argparse.ArgumentParser() 169 | add_dict_to_argparser(parser, model_and_diffusion_defaults()) 170 | 171 | parser.add_argument('--exp', type=str) 172 | parser.add_argument('--seed', type=int, default=0) 173 | 174 | args = parser.parse_args() 175 | setup_seed(args.seed) 176 | 177 | # Load the experiment config 178 | opts = json.load(open(args.exp, 
'r')) 179 | opts.update(vars(args)) 180 | opts['image_size'] = opts['dim'][0] 181 | 182 | # Prepare the experiment folder 183 | if len(opts['steps']) > 0: 184 | suffix = '_'.join([str(step) for step in opts['steps']]) 185 | suffix += '_' + '_'.join([str(step) for step in opts['blocks']]) 186 | opts['exp_dir'] = os.path.join(opts['exp_dir'], suffix) 187 | 188 | path = opts['exp_dir'] 189 | os.makedirs(path, exist_ok=True) 190 | print('Experiment folder: %s' % (path)) 191 | os.system('cp %s %s' % (args.exp, opts['exp_dir'])) 192 | 193 | # Check whether all models in ensemble are trained 194 | pretrained = [os.path.exists(os.path.join(opts['exp_dir'], f'model_{i}.pth')) 195 | for i in range(opts['model_num'])] 196 | 197 | if not all(pretrained): 198 | # train all remaining models 199 | opts['start_model_num'] = sum(pretrained) 200 | train(opts) 201 | 202 | print('Loading pretrained models...') 203 | models = load_ensemble(opts, device='cuda') 204 | evaluation(opts, models) 205 | -------------------------------------------------------------------------------- /train_deeplab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | 8 | import numpy as np 9 | import glob 10 | import json 11 | 12 | from src.utils import setup_seed 13 | from src.pixel_classifier import compute_iou, save_predictions 14 | from src.datasets import ImageLabelDataset, InMemoryImageLabelDataset, make_transform 15 | 16 | 17 | def eval_checkpoint(ckp_path, model, dataset, args, **kwargs): 18 | """ Evaluate DeepLabV3 checkpoint located in ckp_path. 19 | :param ckp_path: path to the checkpoint (.pth file) 20 | :param model: DeepLabV3 pixel classifier 21 | :param dataset: validation or test dataset 22 | :param args: experiment configuration described in the corresponding json file 23 | """ 24 | checkpoint = torch.load(ckp_path) 25 | model.load_state_dict(checkpoint['model_state_dict']) 26 | model.cuda().eval() 27 | 28 | preds, gts = [], [] 29 | for img, gt in dataset: 30 | with torch.no_grad(): 31 | pred = model(img[None].cuda())['out'] 32 | pred = torch.log_softmax(pred, dim=1) 33 | _, pred = torch.max(pred, dim=1) 34 | pred = pred.cpu().detach().numpy() 35 | preds.append(pred) 36 | gts.append(gt.numpy()) 37 | 38 | save_predictions(args, dataset.image_paths, preds) 39 | miou = compute_iou(args, preds, gts, **kwargs) 40 | return miou 41 | 42 | 43 | # Based on https://github.com/nv-tlabs/datasetGAN_release/blob/d9564d4d2f338eaad78132192b865b6cc1e26cac/datasetGAN/train_deeplab.py#L82 44 | def train(data_path, args, resume, 45 | max_data, uncertainty_portion, 46 | learning_rate, batch_size, num_epoch): 47 | """ Train DeepLabV3 on the DDPM-produced dataset. 
48 | :param data_path: path to the synthetic dataset (.npz file) 49 | :param args: experiment configuration described in the corresponding json file 50 | :param resume: path to the checkpoint to resume the training from 51 | 52 | :param max_data: size of the synthetic data 53 | :param uncertainty_portion: portion of samples with most uncertain predictions to remove 54 | """ 55 | arr = np.load(data_path).values() 56 | if len(arr) == 3: 57 | images, labels, uncertainty_scores = arr 58 | else: # Needed to handle datasetGAN 59 | images, labels, latents, uncertainty_scores = arr 60 | 61 | if max_data > 0: 62 | images = images[:max_data] 63 | labels = labels[:max_data] 64 | uncertainty_scores = uncertainty_scores[:max_data] 65 | 66 | if uncertainty_portion > 0: 67 | idxs = np.argsort(uncertainty_scores) 68 | filter_out_num = int(len(idxs) * uncertainty_portion) 69 | idxs = idxs[30: -filter_out_num + 30] 70 | images = images[idxs] 71 | labels = labels[idxs] 72 | 73 | dataset = InMemoryImageLabelDataset( 74 | images=images, 75 | labels=labels, 76 | resolution=args['deeplab_res'], 77 | transform=make_transform( 78 | 'deeplab', args['deeplab_res'] 79 | ) 80 | ) 81 | 82 | train_data = DataLoader(dataset, batch_size=batch_size, num_workers=12, shuffle=True, drop_last=True) 83 | classifier = torchvision.models.segmentation.deeplabv3_resnet101( 84 | pretrained=False, progress=False, num_classes=args['number_class'], aux_loss=None 85 | ) 86 | if resume != "": 87 | checkpoint = torch.load(resume) 88 | start_epoch = int(resume.split('.')[-2].split('_')[-1]) + 1 89 | classifier.load_state_dict(checkpoint['model_state_dict']) 90 | else: 91 | start_epoch = 0 92 | 93 | classifier.cuda() 94 | criterion = nn.CrossEntropyLoss() 95 | optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate) 96 | 97 | for epoch in range(start_epoch, num_epoch, 1): 98 | for i, (img, label) in enumerate(train_data): 99 | classifier.train() 100 | optimizer.zero_grad() 101 | pred = classifier(img.cuda())['out'] 102 | loss = criterion(pred, label.to(torch.long).cuda()) 103 | loss.backward() 104 | optimizer.step() 105 | 106 | if i % 10 == 0: 107 | print(epoch, 'epoch', 'iteration', i, 'loss', loss.item()) 108 | 109 | model_path = os.path.join(base_path, f'deeplab_epoch_{epoch}.pth') 110 | 111 | print('Save to:', model_path) 112 | torch.save({'model_state_dict': classifier.state_dict()}, 113 | model_path) 114 | 115 | 116 | # Based on https://github.com/nv-tlabs/datasetGAN_release/blob/d9564d4d2f338eaad78132192b865b6cc1e26cac/datasetGAN/test_deeplab_cross_validation.py#L262 117 | def test(ckp_path, args): 118 | """ Select the best checkpoint with the highest mIoU on the hold-out validation set and evaluate it on the test set. 
119 | :param ckp_path: path to the pretrained DeepLab checkpoints 120 | :param args: experiment configuration described in the corresponding .json file 121 | """ 122 | cps_all = glob.glob(ckp_path + "/*") 123 | ckp_list = sorted([data for data in cps_all if '.pth' in data]) 124 | 125 | classifier = torchvision.models.segmentation.deeplabv3_resnet101( 126 | pretrained=False, progress=False, 127 | num_classes=args['number_class'], aux_loss=None 128 | ) 129 | 130 | val_dataset = ImageLabelDataset( 131 | data_dir=args['validation_path'], 132 | resolution=args['deeplab_res'], 133 | transform=make_transform( 134 | 'deeplab', args['deeplab_res'] 135 | ) 136 | ) 137 | 138 | test_dataset = ImageLabelDataset( 139 | data_dir=args['testing_path'], 140 | resolution=args['deeplab_res'], 141 | transform=make_transform( 142 | 'deeplab', args['deeplab_res'] 143 | ) 144 | ) 145 | 146 | best_val_miou = 0 147 | for resume in ckp_list: 148 | mean_iou_val = eval_checkpoint(resume, classifier, val_dataset, 149 | args, print_per_class_ious=False) 150 | if mean_iou_val > best_val_miou: 151 | best_val_miou = mean_iou_val 152 | best_test_miou = eval_checkpoint(resume, classifier, test_dataset, args) 153 | print("Best IOU ,", str(best_test_miou)) 154 | print("Checkpoint: ", resume) 155 | 156 | print("Validation mIOU:", best_val_miou) 157 | print("Testing mIOU:" , best_test_miou ) 158 | result = {"Validation": best_val_miou, "Testing": best_test_miou} 159 | with open(os.path.join(ckp_path, 'test_val_miou.json'), 'w') as f: 160 | json.dump(result, f) 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('--exp', type=str) 166 | parser.add_argument('--resume', type=str, default="") 167 | parser.add_argument('--data_path', type=str, default="") 168 | parser.add_argument('--seed', type=int, default=0) 169 | 170 | parser.add_argument('--max_data', type=int, default=0) 171 | parser.add_argument('--uncertainty_portion', type=float, default=0.1) 172 | 173 | parser.add_argument('--learning_rate', type=float, default=0.001) 174 | parser.add_argument('--batch_size', type=int, default=8) 175 | parser.add_argument('--num_epoch', type=int, default=20) 176 | 177 | args = parser.parse_args() 178 | setup_seed(args.seed) 179 | 180 | opts = json.load(open(args.exp, 'r')) 181 | opts['image_size'] = opts['dim'][0] 182 | 183 | # Prepare the experiment folder 184 | if len(opts['steps']) > 0: 185 | suffix = '_'.join([str(step) for step in opts['steps']]) 186 | suffix += '_' + '_'.join([str(step) for step in opts['blocks']]) 187 | opts['exp_dir'] = os.path.join(opts['exp_dir'], suffix) 188 | 189 | if not args.data_path: 190 | data_filename = f"samples_{opts['image_size']}x{opts['image_size']}x3.npz" 191 | data_path = os.path.join(opts['exp_dir'], data_filename) 192 | else: 193 | data_path = args.data_path 194 | 195 | base_path = os.path.join( 196 | opts['exp_dir'], "deeplab_class_%d_checkpoint_%d_filter_out_%f" \ 197 | %(opts['number_class'], args.max_data, args.uncertainty_portion) 198 | ) 199 | os.makedirs(base_path, exist_ok=True) 200 | print('Experiment folder: %s' % (base_path)) 201 | 202 | # Check whether DeepLabV3 is trained 203 | pretrained = all([os.path.exists(os.path.join(base_path, f'deeplab_epoch_{i}.pth')) 204 | for i in range(args.num_epoch)]) 205 | 206 | if not pretrained: 207 | print("training DeepLabV3...") 208 | train(data_path, opts, args.resume, 209 | args.max_data, args.uncertainty_portion, 210 | args.learning_rate, args.batch_size, args.num_epoch) 211 | 212 | 
print("evaluating DeepLabV3...") 213 | test(base_path, opts) 214 | 215 | 216 | -------------------------------------------------------------------------------- /src/data_util.py: -------------------------------------------------------------------------------- 1 | def get_palette(category): 2 | if category == 'ffhq_34': 3 | return ffhq_palette 4 | elif category == 'bedroom_28': 5 | return bedroom_palette 6 | elif category == 'cat_15': 7 | return cat_palette 8 | elif category == 'horse_21': 9 | return horse_palette 10 | elif category == 'ade_bedroom_30': 11 | return ade_bedroom_30_palette 12 | elif category == 'celeba_19': 13 | return celeba_palette 14 | 15 | 16 | def get_class_names(category): 17 | if category == 'ffhq_34': 18 | return ffhq_class 19 | elif category == 'bedroom_28': 20 | return bedroom_class 21 | elif category == 'cat_15': 22 | return cat_class 23 | elif category == 'horse_21': 24 | return horse_class 25 | elif category == 'ade_bedroom_30': 26 | return ade_bedroom_30_class 27 | elif category == 'celeba_19': 28 | return celeba_class 29 | 30 | 31 | ############### 32 | # Class names # 33 | ############### 34 | 35 | 36 | bedroom_class = ['background', 'bed', 'bed***footboard', 'bed***headboard', 'bed***side rail', 37 | 'carpet', 'ceiling', 'chandelier / ceiling fan blade', 'curtain', 'cushion', 'floor', 38 | 'table/nightstand/dresser', 'table/nightstand/dresser***top', 'picture / mirrow', 'pillow', 39 | 'lamp***column', 'lamp***shade', 'wall', 'window', 'curtain rod', 'window***frame', 'chair', 40 | 'picture / mirror***frame', 'plinth', 'door / door frame', 'pouf', 'wardrobe', 'plant', 'table staff' 41 | ] 42 | 43 | 44 | ffhq_class = ['background', 'head', 'head***cheek', 'head***chin', 'head***ear', 'head***ear***helix', 45 | 'head***ear***lobule', 'head***eye***bottom lid', 'head***eye***eyelashes', 'head***eye***iris', 46 | 'head***eye***pupil', 'head***eye***sclera', 'head***eye***tear duct', 'head***eye***top lid', 47 | 'head***eyebrow', 'head***forehead', 'head***frown', 'head***hair', 'head***hair***sideburns', 48 | 'head***jaw', 'head***moustache', 'head***mouth***inferior lip', 'head***mouth***oral commissure', 49 | 'head***mouth***superior lip', 'head***mouth***teeth', 'head***neck', 'head***nose', 50 | 'head***nose***ala of nose', 'head***nose***bridge', 'head***nose***nose tip', 'head***nose***nostril', 51 | 'head***philtrum', 'head***temple', 'head***wrinkles'] 52 | 53 | 54 | cat_class = ['background', 'back', 'belly', 'chest', 'leg', 'paw', 55 | 'head', 'ear', 'eye', 'mouth', 'tongue', 'nose', 'tail', 'whiskers', 'neck'] 56 | 57 | 58 | horse_class = ['background', 'person', 'back', 'barrel', 'bridle', 'chest', 'ear', 'eye', 'forelock', 'head', 59 | 'hoof', 'leg', 'mane', 'muzzle', 'neck', 'nostril', 'tail', 'thigh', 'saddle', 'shoulder', 'leg protection'] 60 | 61 | 62 | celeba_class = ['background', 'cloth', 'ear_r', 'eye_g', 'hair', 'hat', 'l_brow', 63 | 'l_ear', 'l_eye', 'l_lip', 'mouth', 'neck', 'neck_l', 'nose', 'r_brow', 64 | 'r_ear', 'r_eye', 'skin', 'u_lip'] 65 | 66 | 67 | ade_bedroom_50_class = ["wall", "bed", "floor", "table", "lamp", "ceiling", "painting", "windowpane", 68 | "pillow", "curtain", "cushion", "door", "chair", "cabinet", "chest", "mirror", "rug", "armchair", "book", 69 | "sconce", "plant", "wardrobe", "clock", "light", "flower", "vase", "fan", "box", "shelf", "television", 70 | "blind", "pot", "ottoman", "sofa", "desk", "basket", "blanket", "coffee", "plaything", "radiator", 71 | "tray", "stool", "bottle", "chandelier", "fireplacel", 
"towel", "railing", "canopy", "glass", "plate"] 72 | 73 | ade_bedroom_40_class = ade_bedroom_50_class[:40] 74 | ade_bedroom_30_class = ade_bedroom_50_class[:30] 75 | 76 | 77 | ########### 78 | # Palette # 79 | ########### 80 | 81 | 82 | ffhq_palette = [ 1.0000, 1.0000 , 1.0000, 83 | 0.4420, 0.5100 , 0.4234, 84 | 0.8562, 0.9537 , 0.3188, 85 | 0.2405, 0.4699 , 0.9918, 86 | 0.8434, 0.9329 ,0.7544, 87 | 0.3748, 0.7917 , 0.3256, 88 | 0.0190, 0.4943 , 0.3782, 89 | 0.7461 , 0.0137 , 0.5684, 90 | 0.1644, 0.2402 , 0.7324, 91 | 0.0200 , 0.4379 , 0.4100, 92 | 0.5853 , 0.8880 , 0.6137, 93 | 0.7991 , 0.9132 , 0.9720, 94 | 0.6816 , 0.6237 ,0.8562, 95 | 0.9981 , 0.4692 , 0.3849, 96 | 0.5351 , 0.8242 , 0.2731, 97 | 0.1747 , 0.3626 , 0.8345, 98 | 0.5323 , 0.6668 , 0.4922, 99 | 0.2122 , 0.3483 , 0.4707, 100 | 0.6844, 0.1238 , 0.1452, 101 | 0.3882 , 0.4664 , 0.1003, 102 | 0.2296, 0.0401 , 0.3030, 103 | 0.5751 , 0.5467 , 0.9835, 104 | 0.1308 , 0.9628, 0.0777, 105 | 0.2849 ,0.1846 , 0.2625, 106 | 0.9764 , 0.9420 , 0.6628, 107 | 0.3893 , 0.4456 , 0.6433, 108 | 0.8705 , 0.3957 , 0.0963, 109 | 0.6117 , 0.9702 , 0.0247, 110 | 0.3668 , 0.6694 , 0.3117, 111 | 0.6451 , 0.7302, 0.9542, 112 | 0.6171 , 0.1097, 0.9053, 113 | 0.3377 , 0.4950, 0.7284, 114 | 0.1655, 0.9254, 0.6557, 115 | 0.9450 ,0.6721, 0.6162] 116 | 117 | ffhq_palette = [int(item * 255) for item in ffhq_palette] 118 | 119 | 120 | bedroom_palette = [ 121 | 255, 255, 255, # bg 122 | 238, 229, 102, # bed 123 | 255, 72, 69, # bed footboard 124 | 124, 99 , 34, # bed headboard 125 | 193 , 127, 15, # bed side rail 126 | 106, 177, 21, # carpet 127 | 248 ,213 , 43, # ceiling 128 | 252 , 155, 83, # chandelier / ceiling fan blade 129 | 220 ,147 , 77, # curtain 130 | 99 , 83 , 3, # cushion 131 | 116 , 116 , 138, # floor 132 | 63 ,182 , 24, # table/nightstand/dresser 133 | 200 ,226 , 37, # table/nightstand/dresser top 134 | 225 , 184 , 161, # picture / mirrow 135 | 233 , 5 ,219, # pillow 136 | 142 , 172 ,248, # lamp column 137 | 153 , 112 , 146, # lamp shade 138 | 38 ,112 , 254, # wall 139 | 229 , 30 ,141, # window 140 | 99, 205, 255, # curtain rod 141 | 74, 59, 83, # window frame 142 | 186, 9, 0, # chair 143 | 107, 121, 0, # picture / mirrow frame 144 | 0, 194, 160, # plinth 145 | 255, 170, 146, # door / door frame 146 | 255, 144, 201, # pouf 147 | 185, 3, 170, # wardrobe 148 | 221, 239, 255, # plant 149 | 0, 0, 53, # table staff 150 | ] 151 | 152 | 153 | cat_palette = [255, 255, 255, 154 | 190, 153, 153, 155 | 250, 170, 30, 156 | 220, 220, 0, 157 | 107, 142, 35, 158 | 102, 102, 156, 159 | 152, 251, 152, 160 | 119, 11, 32, 161 | 244, 35, 232, 162 | 220, 20, 60, 163 | 52 , 83 ,84, 164 | 194 , 87 , 125, 165 | 143 , 176 , 255, 166 | 31 , 102 , 211, 167 | 104 , 131 , 101 168 | ] 169 | 170 | 171 | horse_palette = [255, 255, 255, 172 | 255, 74, 70, 173 | 0, 137, 65, 174 | 0, 111, 166, 175 | 163, 0, 89, 176 | 255, 219, 229, 177 | 122, 73, 0, 178 | 0, 0, 166, 179 | 99, 255, 172, 180 | 183, 151, 98, 181 | 0, 77, 67, 182 | 143, 176, 255, 183 | 241, 38, 110, 184 | 27, 210, 105, 185 | 128, 150, 147, 186 | 228, 230, 158, 187 | 160, 136, 106, 188 | 79, 198, 1, 189 | 59, 93, 255, 190 | 115, 214, 209, 191 | 255, 47, 128 192 | ] 193 | 194 | 195 | celeba_palette =[ 255, 255, 255, # 0 background 196 | 238, 229, 102,# 1 cloth 197 | 250, 150, 50,# 2 ear_r 198 | 124, 99 , 34, # 3 eye_g 199 | 193 , 127, 15,# 4 hair 200 | 225, 96 ,18, # 5 hat 201 | 220 ,147 , 77, # 6 l_brow 202 | 99 , 83 , 3, # 7 l_ear 203 | 116 , 116 , 138, # 8 l_eye 204 | 200 ,226 , 37, # 9 l_lip 205 | 225 , 184 , 161, # 
10 mouth 206 | 142 , 172 ,248, # 11 neck 207 | 153 , 112 , 146, # 12 neck_l 208 | 38 ,112 , 254, # 13 nose 209 | 229 , 30 ,141, # 14 r_brow 210 | 52 , 83 ,84, # 15 r_ear 211 | 194 , 87 , 125, # 16 r_eye 212 | 248 ,213 , 42, # 17 skin 213 | 31 , 102 , 211, # 18 u_lip 214 | ] 215 | 216 | 217 | ade_bedroom_50_palette = [240, 156, 206, 69, 88, 93, 240, 49, 184, 27, 107, 126, 50, 82, 241, 218 | 54, 250, 147, 156, 213, 3, 176, 108, 79, 251, 150, 149, 66, 51, 34, 219 | 210, 97, 53, 30, 53, 102, 232, 164, 118, 204, 150, 17, 101, 86, 178, 220 | 249, 20, 213, 54, 35, 82, 157, 68, 216, 58, 161, 73, 174, 67, 67, 193, 221 | 181, 78, 169, 60, 178, 220, 204, 166, 4, 127, 85, 245, 106, 216, 222, 222 | 172, 168, 84, 148, 105, 137, 220, 89, 68, 252, 126, 29, 193, 187, 74, 223 | 40, 101, 52, 71, 61, 38, 92, 205, 40, 104, 224, 146, 74, 160, 69, 43, 224 | 220, 70, 78, 213, 249, 93, 254, 235, 71, 119, 193, 255, 102, 152, 55, 225 | 238, 133, 12, 223, 106, 116, 123, 86, 14, 174, 244, 160, 161, 142, 226 | 105, 60, 153, 61, 124, 195, 156, 253, 241, 84, 222, 202, 171, 227] 227 | 228 | ade_bedroom_40_palette = ade_bedroom_50_palette[:120] 229 | ade_bedroom_30_palette = ade_bedroom_50_palette[:90] 230 | -------------------------------------------------------------------------------- /src/feature_extractors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from torch import nn 4 | from typing import List 5 | 6 | 7 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 8 | 9 | 10 | def create_feature_extractor(model_type, **kwargs): 11 | """ Create the feature extractor for architecture. """ 12 | if model_type == 'ddpm': 13 | print("Creating DDPM Feature Extractor...") 14 | feature_extractor = FeatureExtractorDDPM(**kwargs) 15 | elif model_type == 'mae': 16 | print("Creating MAE Feature Extractor...") 17 | feature_extractor = FeatureExtractorMAE(**kwargs) 18 | elif model_type == 'swav': 19 | print("Creating SwAV Feature Extractor...") 20 | feature_extractor = FeatureExtractorSwAV(**kwargs) 21 | elif model_type == 'swav_w2': 22 | print("Creating SwAVw2 Feature Extractor...") 23 | feature_extractor = FeatureExtractorSwAVw2(**kwargs) 24 | else: 25 | raise Exception(f"Wrong model type: {model_type}") 26 | return feature_extractor 27 | 28 | 29 | def save_tensors(module: nn.Module, features, name: str): 30 | """ Process and save activations in the module. """ 31 | if type(features) in [list, tuple]: 32 | features = [f.detach().float() if f is not None else None 33 | for f in features] 34 | setattr(module, name, features) 35 | elif isinstance(features, dict): 36 | features = {k: f.detach().float() for k, f in features.items()} 37 | setattr(module, name, features) 38 | else: 39 | setattr(module, name, features.detach().float()) 40 | 41 | 42 | def save_out_hook(self, inp, out): 43 | save_tensors(self, out, 'activations') 44 | return out 45 | 46 | 47 | def save_input_hook(self, inp, out): 48 | save_tensors(self, inp[0], 'activations') 49 | return out 50 | 51 | 52 | class FeatureExtractor(nn.Module): 53 | def __init__(self, model_path: str, input_activations: bool, **kwargs): 54 | ''' 55 | Parent feature extractor class. 
56 | 57 | param: model_path: path to the pretrained model 58 | param: input_activations: 59 | If True, features are input activations of the corresponding blocks 60 | If False, features are output activations of the corresponding blocks 61 | ''' 62 | super().__init__() 63 | self._load_pretrained_model(model_path, **kwargs) 64 | print(f"Pretrained model is successfully loaded from {model_path}") 65 | self.save_hook = save_input_hook if input_activations else save_out_hook 66 | self.feature_blocks = [] 67 | 68 | def _load_pretrained_model(self, model_path: str, **kwargs): 69 | pass 70 | 71 | 72 | class FeatureExtractorDDPM(FeatureExtractor): 73 | ''' 74 | Wrapper to extract features from pretrained DDPMs. 75 | 76 | :param steps: list of diffusion steps t. 77 | :param blocks: list of the UNet decoder blocks. 78 | ''' 79 | 80 | def __init__(self, steps: List[int], blocks: List[int], **kwargs): 81 | super().__init__(**kwargs) 82 | self.steps = steps 83 | 84 | # Save decoder activations 85 | for idx, block in enumerate(self.model.output_blocks): 86 | if idx in blocks: 87 | block.register_forward_hook(self.save_hook) 88 | self.feature_blocks.append(block) 89 | 90 | def _load_pretrained_model(self, model_path, **kwargs): 91 | import inspect 92 | import guided_diffusion.guided_diffusion.dist_util as dist_util 93 | from guided_diffusion.guided_diffusion.script_util import create_model_and_diffusion 94 | 95 | # Needed to pass only expected args to the function 96 | argnames = inspect.getfullargspec(create_model_and_diffusion)[0] 97 | expected_args = {name: kwargs[name] for name in argnames} 98 | self.model, self.diffusion = create_model_and_diffusion(**expected_args) 99 | 100 | self.model.load_state_dict( 101 | dist_util.load_state_dict(model_path, map_location="cpu") 102 | ) 103 | self.model.to(dist_util.dev()) 104 | if kwargs['use_fp16']: 105 | self.model.convert_to_fp16() 106 | self.model.eval() 107 | 108 | @torch.no_grad() 109 | def forward(self, x, noise=None): 110 | activations = [] 111 | for t in self.steps: 112 | # Compute x_t and run DDPM 113 | t = torch.tensor([t]).to(x.device) 114 | noisy_x = self.diffusion.q_sample(x, t, noise=noise) 115 | self.model(noisy_x, self.diffusion._scale_timesteps(t)) 116 | 117 | # Extract activations 118 | for block in self.feature_blocks: 119 | activations.append(block.activations) 120 | block.activations = None 121 | 122 | # Per-layer list of activations [N, C, H, W] 123 | return activations 124 | 125 | 126 | class FeatureExtractorMAE(FeatureExtractor): 127 | ''' 128 | Wrapper to extract features from pretrained MAE 129 | ''' 130 | def __init__(self, num_blocks=12, **kwargs): 131 | super().__init__(**kwargs) 132 | 133 | # Save features from deep encoder blocks 134 | for layer in self.model.blocks[-num_blocks:]: 135 | layer.register_forward_hook(self.save_hook) 136 | self.feature_blocks.append(layer) 137 | 138 | def _load_pretrained_model(self, model_path, **kwargs): 139 | import mae 140 | from functools import partial 141 | sys.path.append(mae.__path__[0]) 142 | from mae.models_mae import MaskedAutoencoderViT 143 | 144 | # Create MAE with ViT-L-8 backbone 145 | model = MaskedAutoencoderViT( 146 | img_size=256, patch_size=8, embed_dim=1024, depth=24, num_heads=16, 147 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 148 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=True 149 | ) 150 | 151 | checkpoint = torch.load(model_path, map_location='cpu') 152 | model.load_state_dict(checkpoint['model']) 153 | self.model = 
model.eval().to(device) 154 | 155 | @torch.no_grad() 156 | def forward(self, x, **kwargs): 157 | _, _, ids_restore = self.model.forward_encoder(x, mask_ratio=0) 158 | ids_restore = ids_restore.unsqueeze(-1) 159 | sqrt_num_patches = int(self.model.patch_embed.num_patches ** 0.5) 160 | activations = [] 161 | for block in self.feature_blocks: 162 | # remove cls token 163 | a = block.activations[:, 1:] 164 | # unshuffle patches 165 | a = torch.gather(a, dim=1, index=ids_restore.repeat(1, 1, a.shape[2])) 166 | # reshape to obtain spatial feature maps 167 | a = a.permute(0, 2, 1) 168 | a = a.view(*a.shape[:2], sqrt_num_patches, sqrt_num_patches) 169 | 170 | activations.append(a) 171 | block.activations = None 172 | # Per-layer list of activations [N, C, H, W] 173 | return activations 174 | 175 | 176 | class FeatureExtractorSwAV(FeatureExtractor): 177 | ''' 178 | Wrapper to extract features from pretrained SwAVs 179 | ''' 180 | def __init__(self, **kwargs): 181 | super().__init__(**kwargs) 182 | 183 | layers = [self.model.layer1, self.model.layer2, 184 | self.model.layer3, self.model.layer4] 185 | 186 | # Save features from sublayers 187 | for layer in layers: 188 | for l in layer[::2]: 189 | l.register_forward_hook(self.save_hook) 190 | self.feature_blocks.append(l) 191 | 192 | def _load_pretrained_model(self, model_path, **kwargs): 193 | import swav 194 | sys.path.append(swav.__path__[0]) 195 | from swav.hubconf import resnet50 196 | 197 | model = resnet50(pretrained=False).to(device).eval() 198 | model.fc = nn.Identity() 199 | model = torch.nn.DataParallel(model) 200 | state_dict = torch.load(model_path)['state_dict'] 201 | model.load_state_dict(state_dict, strict=False) 202 | self.model = model.module.eval() 203 | 204 | @torch.no_grad() 205 | def forward(self, x, **kwargs): 206 | self.model(x) 207 | 208 | activations = [] 209 | for block in self.feature_blocks: 210 | activations.append(block.activations) 211 | block.activations = None 212 | 213 | # Per-layer list of activations [N, C, H, W] 214 | return activations 215 | 216 | 217 | class FeatureExtractorSwAVw2(FeatureExtractorSwAV): 218 | ''' 219 | Wrapper to extract features from twice wider pretrained SwAVs 220 | ''' 221 | def _load_pretrained_model(self, model_path, **kwargs): 222 | import swav 223 | sys.path.append(swav.__path__[0]) 224 | from swav.hubconf import resnet50w2 225 | 226 | model = resnet50w2(pretrained=False).to(device).eval() 227 | model.fc = nn.Identity() 228 | model = torch.nn.DataParallel(model) 229 | state_dict = torch.load(model_path)['state_dict'] 230 | model.load_state_dict(state_dict, strict=False) 231 | self.model = model.module.eval() 232 | 233 | 234 | def collect_features(args, activations: List[torch.Tensor], sample_idx=0): 235 | """ Upsample activations and concatenate them to form a feature tensor """ 236 | assert all([isinstance(acts, torch.Tensor) for acts in activations]) 237 | size = tuple(args['dim'][:-1]) 238 | resized_activations = [] 239 | for feats in activations: 240 | feats = feats[sample_idx][None] 241 | feats = nn.functional.interpolate( 242 | feats, size=size, mode=args["upsample_mode"] 243 | ) 244 | resized_activations.append(feats[0]) 245 | 246 | return torch.cat(resized_activations, dim=0) 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Label-Efficient Semantic Segmentation with Diffusion Models 2 | 3 | **ICLR'2022** [[Project 
page]](https://yandex-research.github.io/ddpm-segmentation/) 4 | 5 | Official implementation of the paper [Label-Efficient Semantic Segmentation with Diffusion Models](https://arxiv.org/pdf/2112.03126.pdf) 6 | 7 | This code is based on [datasetGAN](https://github.com/nv-tlabs/datasetGAN_release) and [guided-diffusion](https://github.com/openai/guided-diffusion). 8 | 9 | **Note:** use **--recurse-submodules** when clone. 10 | 11 |   12 | ## Overview 13 | 14 | The paper investigates the representations learned by the state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple semantic segmentation approach that exploits these representations and outperforms the alternatives in the few-shot operating point. 15 | 16 |
17 | *[Figure: DDPM-based Segmentation]* 18 |
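To make the approach concrete, below is a minimal inference sketch that mirrors the `evaluation()` path in `train_interpreter.py`: extract activations from the frozen backbone, upsample and concatenate them into per-pixel feature vectors, and classify every pixel with the trained MLP ensemble. It assumes the ensemble has already been trained into `opts['exp_dir']` and that the chosen experiment config provides every key the extractor expects (`model_type`, `model_path`, `dim`, `upsample_mode`, and the diffusion flags for DDPM); the config path and the random input below are only placeholders.

```python
import json
import torch

from src.feature_extractors import create_feature_extractor, collect_features
from src.pixel_classifier import load_ensemble, predict_labels

# Example config; any experiments/<dataset_name>/<model>.json works as long as it
# carries all keys required by the corresponding feature extractor.
opts = json.load(open('experiments/ffhq_34/swav.json', 'r'))
opts['image_size'] = opts['dim'][0]

feature_extractor = create_feature_extractor(**opts)   # DDPM / MAE / SwAV wrapper
models = load_ensemble(opts, device='cuda')             # ensemble of pixel classifiers

img = torch.randn(1, 3, opts['image_size'], opts['image_size']).cuda()  # stand-in for a normalized test image
activations = feature_extractor(img)                    # list of per-block activations
features = collect_features(opts, activations)          # upsample + concatenate -> [C, H, W]

pixels = features.view(opts['dim'][-1], -1).permute(1, 0)  # one feature vector per pixel
mask, uncertainty = predict_labels(models, pixels, size=opts['dim'][:-1])
print(mask.shape, uncertainty.item())
```

In practice the input image comes from `ImageLabelDataset` with `make_transform(model_type, image_size)`, exactly as in `evaluation()`.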
19 | 20 |   21 | ## Updates 22 | 23 | **3/9/2022:** 24 | 25 | 1) Improved performance of DDPM-based segmentation by changing:\ 26 |   Diffusion steps: [50,150,250,350] --> [50,150,250];\ 27 |   UNet blocks: [6,7,8,9] --> [5,6,7,8,12]; 28 | 3) Trained a bit better DDPM on FFHQ-256; 29 | 4) Added [MAE](https://github.com/facebookresearch/mae) for comparison. 30 | 31 |   32 | ## Datasets 33 | 34 | The evaluation is performed on 6 collected datasets with a few annotated images in the training set: 35 | Bedroom-18, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The number corresponds to the number of semantic classes. 36 | 37 | [datasets.tar.gz](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/datasets.tar.gz) (~47Mb) 38 | 39 | 40 |   41 | ## DDPM 42 | 43 | ### Pretrained DDPMs 44 | 45 | The models trained on LSUN are adopted from [guided-diffusion](https://github.com/openai/guided-diffusion). 46 | FFHQ-256 is trained by ourselves using the same model parameters as for the LSUN models. 47 | 48 | *LSUN-Bedroom:* [lsun_bedroom.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_bedroom.pt)\ 49 | *FFHQ-256:* [ffhq.pt](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/ddpm_checkpoints/ffhq.pt) (Updated 3/8/2022)\ 50 | *LSUN-Cat:* [lsun_cat.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_cat.pt)\ 51 | *LSUN-Horse:* [lsun_horse.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_horse.pt) 52 | 53 | ### Run 54 | 55 | 1. Download the datasets:\ 56 |   ```bash datasets/download_datasets.sh``` 57 | 2. Download the DDPM checkpoint:\ 58 |    ```bash checkpoints/ddpm/download_checkpoint.sh ``` 59 | 3. Check paths in ```experiments//ddpm.json``` 60 | 4. Run: ```bash scripts/ddpm/train_interpreter.sh ``` 61 | 62 | **Available checkpoint names:** lsun_bedroom, ffhq, lsun_cat, lsun_horse\ 63 | **Available dataset names:** bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 64 | 65 | **Note:** ```train_interpreter.sh``` is RAM consuming since it keeps all training pixel representations in memory. For ex, it requires ~210Gb for 50 training images of 256x256. (See [issue](https://github.com/nv-tlabs/datasetGAN_release/issues/34)) 66 | 67 | **Pretrained pixel classifiers** and test predictions are [here](https://www.dropbox.com/s/kap229jvmhfwh7i/pixel_classifiers.tar?dl=0). 68 | 69 | ### How to improve the performance 70 | 71 | * Tune for a particular task what diffusion steps and UNet blocks to use. 72 | 73 | 74 |   75 | ## DatasetDDPM 76 | 77 | 78 | ### Synthetic datasets 79 | 80 | To download DDPM-produced synthetic datasets (50000 samples, ~7Gb) (updated 3/8/2022):\ 81 | ```bash synthetic-datasets/ddpm/download_synthetic_dataset.sh ``` 82 | 83 | ### Run | Option #1 84 | 85 | 1. Download the synthetic dataset:\ 86 |    ```bash synthetic-datasets/ddpm/download_synthetic_dataset.sh ``` 87 | 2. Check paths in ```experiments//datasetDDPM.json``` 88 | 3. Run: ```bash scripts/datasetDDPM/train_deeplab.sh ``` 89 | 90 | ### Run | Option #2 91 | 92 | 1. Download the datasets:\ 93 |    ```bash datasets/download_datasets.sh``` 94 | 2. Download the DDPM checkpoint:\ 95 |    ```bash checkpoints/ddpm/download_checkpoint.sh ``` 96 | 3. Check paths in ```experiments//datasetDDPM.json``` 97 | 4. Train an interpreter on a few DDPM-produced annotated samples:\ 98 |    ```bash scripts/datasetDDPM/train_interpreter.sh ``` 99 | 5. 
Generate a synthetic dataset:\ 100 |    ```bash scripts/datasetDDPM/generate_dataset.sh ```\ 101 |     Please specify the hyperparameters in this script for the available resources.\ 102 |     On 8xA100 80Gb, it takes about 12 hours to generate 10000 samples. 103 | 104 | 5. Run: ```bash scripts/datasetDDPM/train_deeplab.sh ```\ 105 |    One needs to specify the path to the generated data. See comments in the script. 106 | 107 | **Available checkpoint names:** lsun_bedroom, ffhq, lsun_cat, lsun_horse\ 108 | **Available dataset names:** bedroom_28, ffhq_34, cat_15, horse_21 109 | 110 |   111 | ## MAE 112 | 113 | ### Pretrained MAEs 114 | 115 | We pretrain MAE models using the [official implementation](https://github.com/facebookresearch/mae) on the LSUN and FFHQ-256 datasets: 116 | 117 | *LSUN-Bedroom:* [lsun_bedroom.pth](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/mae_checkpoints/lsun_bedroom.pth)\ 118 | *FFHQ-256:* [ffhq.pth](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/mae_checkpoints/ffhq.pth)\ 119 | *LSUN-Cat:* [lsun_cat.pth](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/mae_checkpoints/lsun_cat.pth)\ 120 | *LSUN-Horse:* [lsun_horse.pth](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/mae_checkpoints/lsun_horse.pth) 121 | 122 | **Training setups**: 123 | 124 | | Dataset | Backbone | epochs | batch-size | mask-ratio | 125 | |-------------------|-------------------|---------------------|--------------------|--------------------| 126 | | LSUN Bedroom | ViT-L-8 | 150 | 1024 | 0.75 | 127 | | LSUN Cat | ViT-L-8 | 200 | 1024 | 0.75 | 128 | | LSUN Horse | ViT-L-8 | 200 | 1024 | 0.75 | 129 | | FFHQ-256 | ViT-L-8 | 400 | 1024 | 0.75 | 130 | 131 | ### Run 132 | 133 | 1. Download the datasets:\ 134 |    ```bash datasets/download_datasets.sh``` 135 | 2. Download the MAE checkpoint:\ 136 |    ```bash checkpoints/mae/download_checkpoint.sh ``` 137 | 3. Check paths in ```experiments//mae.json``` 138 | 4. 
Run: ```bash scripts/mae/train_interpreter.sh ``` 139 | 140 | **Available checkpoint names:** lsun_bedroom, ffhq, lsun_cat, lsun_horse\ 141 | **Available dataset names:** bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 142 | 143 |   144 | ## SwAV 145 | 146 | ### Pretrained SwAVs 147 | 148 | We pretrain SwAV models using the [official implementation](https://github.com/facebookresearch/swav) on the LSUN and FFHQ-256 datasets: 149 | 150 | | LSUN-Bedroom | FFHQ-256 | LSUN-Cat | LSUN-Horse | 151 | |-------------------|-------------------|---------------------|--------------------| 152 | | [SwAV](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_checkpoints/lsun_bedroom.pth) | [SwAV](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_checkpoints/ffhq.pth) | [SwAV](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_checkpoints/lsun_cat.pth) | [SwAV](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_checkpoints/lsun_horse.pth) | 153 | | [SwAVw2](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_w2_checkpoints/lsun_bedroom.pth) | [SwAVw2](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_w2_checkpoints/ffhq.pth) | [SwAVw2](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_w2_checkpoints/lsun_cat.pth) | [SwAVw2](https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/models/swav_w2_checkpoints/lsun_horse.pth) | 154 | 155 | **Training setups**: 156 | 157 | | Dataset | Backbone | epochs | batch-size | multi-crop | num-prototypes | 158 | |-------------------|-------------------|---------------------|--------------------|--------------------|--------------------| 159 | | LSUN | RN50 | 200 | 1792 | 2x256 + 6x108 | 1000 | 160 | | FFHQ-256 | RN50 | 400 | 2048 | 2x224 + 6x96 | 200 | 161 | | LSUN | RN50w2 | 200 | 1920 | 2x256 + 4x108 | 1000 | 162 | | FFHQ-256 | RN50w2 | 400 | 2048 | 2x224 + 4x96 | 200 | 163 | 164 | ### Run 165 | 166 | 1. Download the datasets:\ 167 |    ```bash datasets/download_datasets.sh``` 168 | 2. Download the SwAV checkpoint:\ 169 |    ```bash checkpoints/{swav|swav_w2}/download_checkpoint.sh ``` 170 | 3. Check paths in ```experiments//{swav|swav_w2}.json``` 171 | 4. Run: ```bash scripts/{swav|swav_w2}/train_interpreter.sh ``` 172 | 173 | **Available checkpoint names:** lsun_bedroom, ffhq, lsun_cat, lsun_horse\ 174 | **Available dataset names:** bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30 175 | 176 | 177 |   178 | ## DatasetGAN 179 | 180 | Opposed to the [official implementation](https://github.com/nv-tlabs/datasetGAN_release), more recent StyleGAN2(-ADA) models are used. 181 | 182 | ### Synthetic datasets 183 | 184 | To download GAN-produced synthetic datasets (50000 samples): 185 | 186 | ```bash synthetic-datasets/gan/download_synthetic_dataset.sh ``` 187 | 188 | ### Run 189 | 190 | Since we almost fully adopt the [official implementation](https://github.com/nv-tlabs/datasetGAN_release), we don't provide our reimplementation here. 191 | However, one can still reproduce our results: 192 | 193 | 1. Download the synthetic dataset:\ 194 |   ```bash synthetic-datasets/gan/download_synthetic_dataset.sh ``` 195 | 2. Change paths in ```experiments//datasetDDPM.json``` 196 | 3. 
Change paths and run: ```bash scripts/datasetDDPM/train_deeplab.sh ``` 197 | 198 | **Available dataset names:** bedroom_28, ffhq_34, cat_15, horse_21 199 | 200 | 201 |   202 | ## Results 203 | 204 | * Performance in terms of mean IoU: 205 | 206 | | Method | Bedroom-28 | FFHQ-34 | Cat-15 | Horse-21 | CelebA-19 | ADE-Bedroom-30 | 207 | |:------------- |:-------------- |:--------------- |:--------------- |:--------------- |:--------------- |:--------------- | 208 | | ALAE | 20.0 ± 1.0 | 48.1 ± 1.3 | -- | -- | 49.7 ± 0.7 | 15.0 ± 0.5 | 209 | | VDVAE | -- | 57.3 ± 1.1 | -- | -- | 54.1 ± 1.0 | -- | 210 | | GAN Inversion | 13.9 ± 0.6 | 51.7 ± 0.8 | 21.4 ± 1.7 | 17.7 ± 0.4 | 51.5 ± 2.3 | 11.1 ± 0.2 | 211 | | GAN Encoder | 22.4 ± 1.6 | 53.9 ± 1.3 | 32.0 ± 1.8 | 26.7 ± 0.7 | 53.9 ± 0.8 | 15.7 ± 0.3 | 212 | | SwAV | 41.0 ± 2.3 | 54.7 ± 1.4 | 44.1 ± 2.1 | 51.7 ± 0.5 | 53.2 ± 1.0 | 30.3 ± 1.5 | 213 | | SwAVw2 | 42.4 ± 1.7 | 56.9 ± 1.3 | 45.1 ± 2.1 | 54.0 ± 0.9 | 52.4 ± 1.3 | 30.6 ± 1.0 | 214 | | MAE | 45.0 ± 2.0 | **58.8 ± 1.1** | **52.4 ± 2.3** | 63.4 ± 1.4 | 57.8 ± 0.4 | 31.7 ± 1.8 | 215 | | DatasetGAN | 31.3 ± 2.7 | 57.0 ± 1.0 | 36.5 ± 2.3 | 45.4 ± 1.4 | -- | -- | 216 | | DatasetDDPM | 47.9 ± 2.9 | 56.0 ± 0.9 | 47.6 ± 1.5 | 60.8 ± 1.0 | -- | -- | 217 | | **DDPM** | **49.4 ± 1.9** | **59.1 ± 1.4** | **53.7 ± 3.3** | **65.0 ± 0.8** | **59.9 ± 1.0** | **34.6 ± 1.7** | 218 | 219 |   220 | * Examples of segmentation masks predicted by the DDPM-based method: 221 | 222 |
223 | *[Figure: DDPM-based Segmentation, example predictions]* 224 |
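For reference, the mIoU values in the table above are computed by `compute_iou` from `src/pixel_classifier.py`: per-class intersections and unions are accumulated over the entire test set, and the resulting IoUs are averaged over all classes except `ignore_label`. A condensed, self-contained sketch of that metric (a restatement of the repository code, assuming `preds` and `gts` are lists of integer label maps of matching shape):

```python
import numpy as np
from collections import Counter

def mean_iou(preds, gts, num_classes, ignore_label):
    intersections, unions = Counter(), Counter()
    for pred, gt in zip(preds, gts):
        for c in range(num_classes):
            if c == ignore_label:
                continue
            p, g = (pred == c), (gt == c)
            intersections[c] += np.logical_and(p, g).sum()  # accumulated over the whole test set
            unions[c] += np.logical_or(p, g).sum()
    ious = [intersections[c] / (unions[c] + 1e-8)
            for c in range(num_classes) if c != ignore_label]
    return float(np.mean(ious))
```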
225 | 226 | 227 |   228 | ## Cite 229 | 230 | ``` 231 | @misc{baranchuk2021labelefficient, 232 | title={Label-Efficient Semantic Segmentation with Diffusion Models}, 233 | author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko}, 234 | year={2021}, 235 | eprint={2112.03126}, 236 | archivePrefix={arXiv}, 237 | primaryClass={cs.CV} 238 | } 239 | ``` 240 | --------------------------------------------------------------------------------