├── .gitignore
├── LICENSE
├── README.md
├── asset
│   ├── bloodmnist_vis.png
│   ├── cifar100_vis.png
│   ├── cifar10_vis.png
│   ├── dermamnist_vis.png
│   ├── eurosat_vis.png
│   ├── imagenette_vis.png
│   ├── pathmnist_vis.png
│   ├── stl10_vis.png
│   └── workflow_box.png
├── environment.yml
└── src
    ├── classfier_tuning
    │   ├── clip
    │   │   ├── README.md
    │   │   ├── bpe_simple_vocab_16e6.txt.gz
    │   │   ├── clip.py
    │   │   ├── model.py
    │   │   └── tokenizer.py
    │   └── src
    │       ├── args.py
    │       ├── clip_filtering.py
    │       ├── clip_filtering_eurosat.py
    │       ├── ct_fsl.py
    │       ├── ct_zsl.py
    │       ├── datasets
    │       │   ├── __init__.py
    │       │   ├── cifar10.py
    │       │   ├── cifar100.py
    │       │   ├── common.py
    │       │   ├── eurosat.py
    │       │   ├── fmow.py
    │       │   ├── imagenet.py
    │       │   ├── imagenet_a.py
    │       │   ├── imagenet_classnames.py
    │       │   ├── imagenet_r.py
    │       │   ├── imagenet_sketch.py
    │       │   ├── imagenet_vid_robust.py
    │       │   ├── imagenetv2.py
    │       │   ├── iwildcam.py
    │       │   ├── iwildcam_metadata
    │       │   │   └── labels.csv
    │       │   ├── objectnet.py
    │       │   ├── objectnet_metadata
    │       │   │   ├── folder_to_objectnet_label.json
    │       │   │   ├── imagenet_to_label_2012_v2
    │       │   │   ├── objectnet_to_imagenet_1k.json
    │       │   │   └── pytorch_to_imagenet_2012_id.json
    │       │   ├── transfer_ds
    │       │   │   ├── Randaug.py
    │       │   │   ├── __init__.py
    │       │   │   ├── aircraft.py
    │       │   │   ├── cal_mean_std.py
    │       │   │   ├── caltech.py
    │       │   │   ├── constants.py
    │       │   │   ├── cub.py
    │       │   │   ├── cub_for_robust_codebase.py
    │       │   │   ├── cub_transform.py
    │       │   │   ├── dtd.py
    │       │   │   ├── fine_tunify.py
    │       │   │   ├── food_101.py
    │       │   │   ├── imbalance_cifar.py
    │       │   │   ├── process_dataset
    │       │   │   │   ├── pro_aircraft.py
    │       │   │   │   ├── pro_caltech101.py
    │       │   │   │   ├── pro_cars.py
    │       │   │   │   ├── pro_flowers.py
    │       │   │   │   ├── pro_imgnet.py
    │       │   │   │   ├── pro_pool15.py
    │       │   │   │   └── process_pets.py
    │       │   │   ├── transfer_datasets.py
    │       │   │   ├── transform_ckpt.py
    │       │   │   └── utils.py
    │       │   ├── ytbb-robust_metadata
    │       │   │   ├── anchor_labels.json
    │       │   │   ├── class_idx_map.json
    │       │   │   ├── pmk_labels.json
    │       │   │   ├── rev_class_idx_map.json
    │       │   │   ├── ytbb_class_index.json
    │       │   │   └── ytbb_robustness_test_anchors_full.csv
    │       │   └── ytbb_robust.py
    │       ├── eurosat_text_feature.pt
    │       ├── get_classifier_weights.py
    │       ├── imagenet_text_feature.pt
    │       ├── models
    │       │   ├── __init__.py
    │       │   ├── eval.py
    │       │   ├── finetune.py
    │       │   ├── finetune_retina.py
    │       │   ├── modeling.py
    │       │   ├── utils.py
    │       │   ├── zeroshot.py
    │       │   └── zeroshot_retina.py
    │       ├── select_glide_ims_by_clip.py
    │       └── templates
    │           ├── __init__.py
    │           ├── fmow_template.py
    │           ├── iwildcam_template.py
    │           ├── openai_imagenet_template.py
    │           ├── simple_template.py
    │           ├── transfer_ds_template.py
    │           └── utils.py
    ├── dataset.py
    ├── diffuser_inversion.py
    ├── models
    │   ├── __init__.py
    │   ├── densenet.py
    │   ├── dla.py
    │   ├── dla_simple.py
    │   ├── dpn.py
    │   ├── efficientnet.py
    │   ├── googlenet.py
    │   ├── lenet.py
    │   ├── mobilenet.py
    │   ├── mobilenetv2.py
    │   ├── pnasnet.py
    │   ├── preact_resnet.py
    │   ├── regnet.py
    │   ├── resnet.py
    │   ├── resnext.py
    │   ├── senet.py
    │   ├── shufflenet.py
    │   ├── shufflenetv2.py
    │   └── vgg.py
    ├── pipeline_emb.py
    ├── sample_dataset.py
    ├── script
    │   ├── clip_retrieval_stl10.py
    │   ├── compute_statistics.py
    │   ├── convert_vae.py
    │   ├── split_euro_dataset.py
    │   └── split_zhou_EuroSAT.json
    └── train_net.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | .DS_Store
30 | *.pyc
31 | *.log
32 | .vscode/
33 | src/wandb/
34 |
35 | checkpoint/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yongchao Zhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Inversion
2 |
3 | [Project Page](https://sites.google.com/view/diffusion-inversion)
4 | | [ArXiv](https://arxiv.org/abs/2305.15316)
5 |
 6 | This repo contains code for steering the Stable Diffusion model to generate data for downstream classifier training. Please see our paper and project page for more results.
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Abstract
20 | Acquiring high-quality data for training discriminative models is a crucial yet challenging aspect of building effective predictive systems. In this paper, we present Diffusion Inversion, a simple yet effective method that leverages the pre-trained generative model, Stable Diffusion, to generate diverse, high-quality training data for image classification. Our approach captures the original data distribution and ensures data coverage by inverting images to the latent space of Stable Diffusion, and generates diverse novel training images by conditioning the generative model on noisy versions of these vectors. We identify three key components that allow our generated images to successfully supplant the original dataset, leading to a 2-3x enhancement in sample complexity and a 6.5x decrease in sampling time. Moreover, our approach consistently outperforms generic prompt-based steering methods and KNN retrieval baseline across a wide range of datasets. Additionally, we demonstrate the compatibility of our approach with widely-used data augmentation techniques, as well as the reliability of the generated data in supporting various neural architectures and enhancing few-shot learning.
21 |
22 | ## Method
23 | Stable Diffusion, a model trained on billions of image-text pairs, boasts a wealth of generalizable knowledge. To harness this knowledge for specific classification tasks, we propose a two-stage method that guides a pre-trained generator, $G$, towards the target domain dataset. In the first stage, we map each image to the model's latent space, generating a dataset of latent embedding vectors. Then, we produce novel image variants by running the inverse diffusion process conditioned on perturbed versions of these vectors. We illustrate our approach in the figure below.
24 |
25 |
26 |
27 |
28 | ## Reproducing
29 | ### Environment
30 |
31 | - You can set up the environment using the command below.
32 |
33 | ```bash
34 | conda env create -f environment.yml
35 | conda activate di
36 | ```
37 |
38 | ### Training
39 |
40 | ```bash
41 | path="--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4 --output_dir=$PROJDIR/diffusion_inversion/logs/stl10 --dataset_name=stl10 --data_dir=~/tensorflow_datasets"
42 | args="--gradient_accumulation_steps=1 --num_tokens=5 --resolution=256 --train_batch_size=50 --num_emb=100 --max_train_steps=8000"
43 | lr="--lr_warmup_steps=0 --interpolation=bicubic --lr_scheduler=constant --learning_rate=3e-02"
44 | log="--checkpointing_steps=1000 --save_steps=1000 --save_image_steps=400 --resume_from_checkpoint=latest"
45 |
46 | accelerate launch src/diffuser_inversion.py $path $args $lr $log --group_id=0
47 | ...
48 | accelerate launch src/diffuser_inversion.py $path $args $lr $log --group_id=50
49 | ```
50 |
51 | ### Sampling
52 |
53 | ```bash
54 | path="--dataset_name=stl10 --model_root_dir=$PROJDIR/diffusion_inversion/logs/stl10/res256_bicubic/emb100_token5_lr0.03_constant --dm_name=CompVis/stable-diffusion-v1-4"
55 | train_config="--emb_ch=768 --num_tokens=5 --num_classes=10 --num_emb=100 --sampling_resolution=256 --save_resolution=96 --outdir=$PROJDIR/inversion_data/stl10/scaling"
56 | sampling_config="--num_inference_steps=100 --batch_size=100 --interpolation_strength=0.1 --num_samples=5 --emb_noise=0.1 --train_steps=3000 --seed=42"
57 |
58 | python sample_dataset.py $path $train_config $sampling_config --group_id=0
59 | ...
60 | python sample_dataset.py $path $train_config $sampling_config --group_id=50
61 | ```
62 |
63 | ### Evaluation
64 | ```bash
65 | path="--output=$PROJDIR/project/diffusion_inversion/arch"
66 | stl10="--dataset-name=stl10 --group-size=100 --num-steps=50000"
67 | pstl10="--syn-data-dir=$PROJDIR/inversion_data/stl10/scaling/res96_bicubic --syn-pattern=tstep[0-9]*_infstep100_gs[0-9]*_noise0.[0-9]*_itep0.[0-9]*_seed[0-9]*"
68 | args="--batch-size=128 --warmup-steps=1000 --num-data=5000 --num-steps=50000 --optimizer=sgd --weight-decay=5e-4 --real-bs=0 --syn-bs=128"
69 | log="--num-evals=20 --seed=42 --wandb-name=DI-stl10 --log-wandb"
70 |
71 | # Train on real data
72 | python train_net.py $path $stl10 $args $log --model=resnet18 --lr=1e-1
73 |
74 | # Train on synthetic data
75 | python train_net.py $path $stl10 $pstl10 $args $log --model=resnet18 --lr=1e-1
76 | ```
77 |
78 |
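79 | ### Conceptual sketch
80 |
81 | The two stages described in the Method section are implemented in `src/diffuser_inversion.py` (stage 1: learning one conditioning embedding per image) and `src/sample_dataset.py` (stage 2: sampling novel variants). The snippet below is only a rough sketch of the idea, not this repository's code: it assumes a diffusers `StableDiffusionPipeline` recent enough to accept precomputed `prompt_embeds`, and the random tensor stands in for an embedding learned in stage one.
82 |
83 | ```python
84 | import torch
85 | from diffusers import StableDiffusionPipeline
86 |
87 | pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
88 |
89 | # Stage 1 (sketch): each training image gets a learned conditioning embedding of
90 | # shape [num_tokens, emb_ch] (cf. --num_tokens=5 and --emb_ch=768 above); a random
91 | # tensor stands in for the result of that optimization here.
92 | emb = torch.randn(1, 5, 768)
93 |
94 | # Stage 2 (sketch): perturb the learned embedding (cf. --emb_noise=0.1) and run the
95 | # reverse diffusion process conditioned on the noisy vector to obtain a novel variant.
96 | noisy_emb = emb + 0.1 * torch.randn_like(emb)
97 | image = pipe(prompt_embeds=noisy_emb, num_inference_steps=100).images[0]
98 | ```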
--------------------------------------------------------------------------------
/asset/bloodmnist_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/bloodmnist_vis.png
--------------------------------------------------------------------------------
/asset/cifar100_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/cifar100_vis.png
--------------------------------------------------------------------------------
/asset/cifar10_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/cifar10_vis.png
--------------------------------------------------------------------------------
/asset/dermamnist_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/dermamnist_vis.png
--------------------------------------------------------------------------------
/asset/eurosat_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/eurosat_vis.png
--------------------------------------------------------------------------------
/asset/imagenette_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/imagenette_vis.png
--------------------------------------------------------------------------------
/asset/pathmnist_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/pathmnist_vis.png
--------------------------------------------------------------------------------
/asset/stl10_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/stl10_vis.png
--------------------------------------------------------------------------------
/asset/workflow_box.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/workflow_box.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: di
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.10.6
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - fire=0.4.0
12 | - imageio=2.22.4
13 | - einops=0.6.0
14 | - pip
15 | - pip:
16 | - tensorflow~=2.10
17 | - tensorflow-datasets~=4.7.0
18 | - ml_collections~=0.1.1
19 | - seaborn~=0.12.1
20 | - kornia~=0.6.8
21 | - timm~=0.6.12
22 | - transformers~=4.25.1
23 | - safetensors~=0.2.5
24 | - accelerate~=0.15
25 | - datasets~=2.7.1
26 | - diffusers
27 | - clu
28 | - gsutil
29 | - medmnist
30 |
--------------------------------------------------------------------------------
/src/classfier_tuning/clip/README.md:
--------------------------------------------------------------------------------
1 | This folder is a lightly modified version of https://github.com/openai/CLIP.
2 |
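3 | For quick reference, the bundled tokenizer can be exercised on its own. A minimal sketch (assuming `ftfy` and `regex` are installed and the snippet is run from `src/classfier_tuning`, so that the `clip` package resolves):
4 |
5 | ```python
6 | from clip.tokenizer import SimpleTokenizer
7 |
8 | tok = SimpleTokenizer()                # loads bpe_simple_vocab_16e6.txt.gz from this folder
9 | ids = tok.encode("a photo of a dog")   # BPE token ids (start/end special tokens are not added)
10 | print(tok.decode(ids))                 # decodes back to the normalized text
11 | ```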
--------------------------------------------------------------------------------
/src/classfier_tuning/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/classfier_tuning/clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
 22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
 70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
 73 |         if not special_tokens:
 74 |             special_tokens = ['<start_of_text>', '<end_of_text>']
 75 |         else:
 76 |             special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
77 | vocab.extend(special_tokens)
78 | self.encoder = dict(zip(vocab, range(len(vocab))))
79 | self.decoder = {v: k for k, v in self.encoder.items()}
80 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
81 | self.cache = {t:t for t in special_tokens}
82 | special = "|".join(special_tokens)
83 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
84 |
85 | self.vocab_size = len(self.encoder)
86 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
87 |
88 | def bpe(self, token):
89 | if token in self.cache:
90 | return self.cache[token]
 91 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
92 | pairs = get_pairs(word)
93 |
94 | if not pairs:
 95 |             return token+'</w>'
96 |
97 | while True:
98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
99 | if bigram not in self.bpe_ranks:
100 | break
101 | first, second = bigram
102 | new_word = []
103 | i = 0
104 | while i < len(word):
105 | try:
106 | j = word.index(first, i)
107 | new_word.extend(word[i:j])
108 | i = j
109 | except:
110 | new_word.extend(word[i:])
111 | break
112 |
113 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
114 | new_word.append(first+second)
115 | i += 2
116 | else:
117 | new_word.append(word[i])
118 | i += 1
119 | new_word = tuple(new_word)
120 | word = new_word
121 | if len(word) == 1:
122 | break
123 | else:
124 | pairs = get_pairs(word)
125 | word = ' '.join(word)
126 | self.cache[token] = word
127 | return word
128 |
129 | def encode(self, text):
130 | bpe_tokens = []
131 | text = whitespace_clean(basic_clean(text)).lower()
132 | for token in re.findall(self.pat, text):
133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
135 | return bpe_tokens
136 |
137 | def decode(self, tokens):
138 | text = ''.join([self.decoder[token] for token in tokens])
139 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
140 | return text
--------------------------------------------------------------------------------
/src/classfier_tuning/src/ct_fsl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4 |
5 | import wandb
6 | from src.models.finetune import finetune_fsl
7 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
8 | from src.models.zeroshot import get_zeroshot_classifier
9 | from src.args import parse_arguments
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip
11 | from PIL import Image
12 |
13 | def _convert_to_rgb(image):
14 | return image.convert('RGB')
15 |
16 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
17 |
18 | def classifier_tuning(args):
19 | assert args.save is not None, 'Please provide a path to store models'
20 | print('import success')
21 |
22 | classification_head_path = os.path.join(args.save, f'head.pt')
23 | args.data_location_real = os.path.join(args.data_location_real, f'shot_{args.shot}')
24 | args.data_location_syn = os.path.join(args.data_location_syn, f'shot_{args.shot}')
25 | args.save = os.path.join(args.save, f'shot_{args.shot}', args.cache_name_syn)
26 | args.cache_dir = os.path.join(args.cache_dir, f'shot_{args.shot}')
27 | args.results_db = os.path.join(args.results_db, f'shot_{args.shot}', args.cache_name_syn)
28 |
29 | # Build and save zero-shot model
30 | image_encoder = ImageEncoder(args, keep_lang=True)
31 | if not os.path.exists(classification_head_path):
32 | classification_head = get_zeroshot_classifier(args, image_encoder.model)
33 | classification_head.save(classification_head_path)
34 | else:
35 | classification_head = ClassificationHead.load(classification_head_path)
36 | delattr(image_encoder.model, 'transformer')
37 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
38 |
39 | zeroshot_checkpoint = os.path.join(args.save, 'zeroshot'+args.train_dataset+'.pt')
40 | classifier.save(zeroshot_checkpoint)
41 |
42 | # Standard fine-tuning
43 | args.load = zeroshot_checkpoint
44 | args.save = os.path.join(args.save, 'finetuned')
45 |
46 | # Mimic eurosat low-res images, val data aug
47 | train_data_aug = Compose([
48 | # Resize(64), # resize to 32/64 for Cifar / Eurosat
49 | Resize(224, interpolation=Image.BICUBIC),
50 | CenterCrop(224),
51 | _convert_to_rgb,
52 | ToTensor(),
53 | normalize,
54 | ])
55 |
56 | wandb.init(project=f"eurosat-fewshot", config=args)
57 | finetuned_checkpoint = finetune_fsl(args, train_data_aug)
58 |
59 |
60 | if __name__ == '__main__':
61 | args = parse_arguments()
62 | classifier_tuning(args)
63 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/ct_zsl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
 3 | sys.path.append(
 4 |     os.path.join(os.path.dirname(__file__), '..'))
5 |
6 | from src.models.finetune import finetune_zsl
7 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
8 | from src.models.zeroshot import get_zeroshot_classifier
9 | from src.args import parse_arguments
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip
11 | from PIL import Image
12 |
13 | def _convert_to_rgb(image):
14 | return image.convert('RGB')
15 |
16 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
17 |
18 | def classifier_tuning(args):
19 | assert args.save is not None, 'Please provide a path to store models'
20 | print('import success')
21 |
22 | # Build and save zero-shot model
23 | image_encoder = ImageEncoder(args, keep_lang=True)
24 | classification_head = get_zeroshot_classifier(args, image_encoder.model)
25 | delattr(image_encoder.model, 'transformer')
26 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
27 |
28 | zeroshot_checkpoint = os.path.join(args.save, 'zeroshot'+args.train_dataset+'.pt')
29 | classifier.save(zeroshot_checkpoint)
30 |
31 | # Standard fine-tuning
32 | args.load = zeroshot_checkpoint
33 | args.save = os.path.join(args.save, 'finetuned')
34 |
35 | # Mimic eurosat low-res images, val data aug
36 | train_data_aug = Compose([
37 | # Resize(64), # resize to 32/64 for Cifar / Eurosat
38 | Resize(224, interpolation=Image.BICUBIC),
39 | CenterCrop(224),
40 | _convert_to_rgb,
41 | ToTensor(),
42 | normalize,
43 | ])
44 |
45 | finetuned_checkpoint = finetune_zsl(args, train_data_aug)
46 |
47 |
48 |
49 | if __name__ == '__main__':
50 | args = parse_arguments()
51 | classifier_tuning(args)
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar10 import *
2 | from .cifar100 import *
3 | # from .fmow import FMOWID, FMOWOOD, FMOW
4 | from .imagenet import ImageNet as ImageNet_ft
5 | # from .imagenetv2 import ImageNetV2
6 | # from .imagenet_a import ImageNetAValClasses, ImageNetA
7 | # from .imagenet_r import ImageNetRValClasses, ImageNetR
8 | # from .imagenet_sketch import ImageNetSketch
9 | # from .imagenet_vid_robust import ImageNetVidRobustValClasses, ImageNetVidRobust
10 | # from .iwildcam import IWildCamID, IWildCamOOD, IWildCamIDNonEmpty, IWildCamOODNonEmpty, IWildCam
11 | # from .objectnet import ObjectNetValClasses, ObjectNet
12 | # from .ytbb_robust import YTBBRobustValClasses, YTBBRobust
13 | # from .transfer_ds.transfer_datasets import *
14 | from .eurosat import EuroSAT
15 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import torch
4 | import numpy as np
5 | import torchvision
6 | from torchvision import transforms
7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
8 | from torchvision.datasets import VisionDataset
9 |
10 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
11 |
12 | class CIFAR10_theirs:
13 | def __init__(self, preprocess,
14 | location=os.path.expanduser('~/data'),
15 | batch_size=128,
16 | num_workers=16,
17 | classnames=None):
18 |
19 |
20 | self.train_dataset = PyTorchCIFAR10(
21 | root=location, download=True, train=True, transform=preprocess
22 | )
23 |
24 | self.train_loader = torch.utils.data.DataLoader(
25 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
26 | )
27 |
28 | self.test_dataset = PyTorchCIFAR10(
29 | root=location, download=True, train=False, transform=preprocess
30 | )
31 |
32 | self.test_loader = torch.utils.data.DataLoader(
33 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
34 | )
35 |
36 | self.classnames = self.test_dataset.classes
37 |
38 | def convert(x):
39 | if isinstance(x, np.ndarray):
40 | return torchvision.transforms.functional.to_pil_image(x)
41 | return x
42 |
43 | class BasicVisionDataset(VisionDataset):
44 | def __init__(self, images, targets, transform=None, target_transform=None):
45 | if transform is not None:
46 | transform.transforms.insert(0, convert)
47 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
48 | assert len(images) == len(targets)
49 |
50 | self.images = images
51 | self.targets = targets
52 |
53 | def __getitem__(self, index):
54 | return self.transform(self.images[index]), self.targets[index]
55 |
56 | def __len__(self):
57 | return len(self.targets)
58 |
59 | class CIFAR101:
60 | def __init__(self,
61 | preprocess,
62 | location=os.path.expanduser('~/data'),
63 | batch_size=128,
64 | num_workers=16,
65 | classnames=None):
66 |
67 | data_root = os.path.join(location, "CIFAR-10.1")
68 | data = np.load(os.path.join(data_root, 'cifar10.1_v6_data.npy'), allow_pickle=True)
69 | labels = np.load(os.path.join(data_root, 'cifar10.1_v6_labels.npy'), allow_pickle=True)
70 |
71 | use_cuda = torch.cuda.is_available()
72 |
73 | # Data loading code
74 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
75 |
76 | self.train_loader = None
77 |
78 | self.test_dataset = BasicVisionDataset(
79 | images=data, targets=torch.Tensor(labels).long(),
80 | transform=preprocess,
81 | )
82 |
83 | self.test_loader = torch.utils.data.DataLoader(
84 | self.test_dataset, batch_size=batch_size, shuffle=False, **kwargs
85 | )
86 |
87 | self.classnames = cifar_classnames
88 |
89 |
90 | class CIFAR102:
91 | def __init__(self,
92 | preprocess,
93 | location=os.path.expanduser('~/data'),
94 | batch_size=128,
95 | num_workers=16,
96 | classnames=None):
97 |
98 | train_data = np.load(os.path.join(location, "CIFAR-10.2", 'cifar102_train.npy'), allow_pickle=True).item()
99 | test_data = np.load(os.path.join(location, "CIFAR-10.2", 'cifar102_test.npy'), allow_pickle=True).item()
100 |
101 |
102 | train_data_images = train_data['images']
103 | train_data_labels = train_data['labels']
104 |
105 | test_data_images = test_data['images']
106 | test_data_labels = test_data['labels']
107 |
108 | use_cuda = torch.cuda.is_available()
109 |
110 | # Data loading code
111 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
112 |
113 | self.test_dataset = BasicVisionDataset(
114 | images=test_data_images, targets=torch.Tensor(test_data_labels).long(),
115 | transform=preprocess,
116 | )
117 |
118 | self.test_loader = torch.utils.data.DataLoader(
119 | self.test_dataset, batch_size=batch_size, shuffle=False, **kwargs
120 | )
121 |
122 | self.classnames = cifar_classnames
123 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100
4 |
5 | class CIFAR100_theirs:
6 | def __init__(self,
7 | preprocess,
8 | location=os.path.expanduser('~/data'),
9 | batch_size=128,
10 | num_workers=16,
11 | classnames=None):
12 |
13 | self.train_dataset = PyTorchCIFAR100(
14 | root=location, download=True, train=True, transform=preprocess
15 | )
16 |
17 | self.train_loader = torch.utils.data.DataLoader(
18 | self.train_dataset, batch_size=batch_size, num_workers=num_workers
19 | )
20 |
21 | self.test_dataset = PyTorchCIFAR100(
22 | root=location, download=True, train=False, transform=preprocess
23 | )
24 |
25 | self.test_loader = torch.utils.data.DataLoader(
26 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
27 | )
28 |
29 | self.classnames = self.test_dataset.classes
30 |
31 |
32 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import glob
5 | import collections
6 | import random
7 |
8 | import numpy as np
9 |
10 | from tqdm import tqdm
11 |
12 | import torchvision.datasets as datasets
13 | from torch.utils.data import Dataset, DataLoader, Sampler
14 |
15 |
16 | class SubsetSampler(Sampler):
17 | def __init__(self, indices):
18 | self.indices = indices
19 |
20 | def __iter__(self):
21 | return (i for i in self.indices)
22 |
23 | def __len__(self):
24 | return len(self.indices)
25 |
26 | class ImageFolderWithPaths(datasets.ImageFolder):
27 | def __init__(self, path, transform, flip_label_prob=0.0):
28 | super().__init__(path, transform)
29 | self.flip_label_prob = flip_label_prob
30 | if self.flip_label_prob > 0:
31 | print(f'Flipping labels with probability {self.flip_label_prob}')
32 | num_classes = len(self.classes)
33 | for i in range(len(self.samples)):
34 | if random.random() < self.flip_label_prob:
35 | new_label = random.randint(0, num_classes-1)
36 | self.samples[i] = (
37 | self.samples[i][0],
38 | new_label
39 | )
40 |
41 | def __getitem__(self, index):
42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index)
43 | return {
44 | 'images': image,
45 | 'labels': label,
46 | 'image_paths': self.samples[index][0]
47 | }
48 |
49 |
50 | def maybe_dictionarize(batch):
51 | if isinstance(batch, dict):
52 | return batch
53 |
54 | if len(batch) == 2:
55 | batch = {'images': batch[0], 'labels': batch[1]}
56 | elif len(batch) == 3:
57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
58 | else:
59 | raise ValueError(f'Unexpected number of elements: {len(batch)}')
60 |
61 | return batch
62 |
63 |
64 | def get_features_helper(image_encoder, dataloader, device):
65 | all_data = collections.defaultdict(list)
66 |
67 | image_encoder = image_encoder.to(device)
68 | # image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
69 | image_encoder.eval()
70 |
71 | with torch.no_grad():
72 | for batch in tqdm(dataloader):
73 | batch = maybe_dictionarize(batch)
74 | features = image_encoder(batch['images'].cuda())
75 |
76 | all_data['features'].append(features.cpu())
77 |
78 | for key, val in batch.items():
79 | if key == 'images':
80 | continue
81 | if hasattr(val, 'cpu'):
82 | val = val.cpu()
83 | all_data[key].append(val)
84 | else:
85 | all_data[key].extend(val)
86 | del batch, features
87 |
88 | for key, val in all_data.items():
89 | if torch.is_tensor(val[0]):
90 | all_data[key] = torch.cat(val).numpy()
91 |
92 | return all_data
93 |
94 |
95 | def get_features(is_train, image_encoder, dataset, device, is_real=True, args=None):
96 | split = 'train' if is_train else 'val'
97 | if is_real is False:
98 | split= args.cache_name_syn
99 | dname = type(dataset).__name__
100 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
101 | if image_encoder.cache_dir is not None:
102 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
103 | cached_files = glob.glob(f'{cache_dir}/*')
104 | if image_encoder.cache_dir is not None and len(cached_files) > 0:
105 | print(f'Getting features from {cache_dir}')
106 | data = {}
107 | for cached_file in cached_files:
108 | name = os.path.splitext(os.path.basename(cached_file))[0]
109 | data[name] = torch.load(cached_file)
110 | else:
111 | # import ipdb
112 | # ipdb.set_trace(context=20)
113 | print(f'Did not find cached features at {cache_dir}. Building from scratch.')
114 | loader = dataset.train_loader if is_train else dataset.test_loader
115 | data = get_features_helper(image_encoder, loader, device)
116 | if image_encoder.cache_dir is None:
117 | print('Not caching because no cache directory was passed.')
118 | else:
119 | os.makedirs(cache_dir, exist_ok=True)
120 | print(f'Caching data at {cache_dir}')
121 | for name, val in data.items():
122 | torch.save(val, f'{cache_dir}/{name}.pt', pickle_protocol=4)
123 | return data
124 |
125 |
126 | class FeatureDataset(Dataset):
127 | def __init__(self, is_train, image_encoder, dataset, device, is_real=True, args=None):
128 | self.data = get_features(is_train, image_encoder, dataset, device, is_real, args=args)
129 |
130 | def __len__(self):
131 | return len(self.data['features'])
132 |
133 | def __getitem__(self, idx):
134 | data = {k: v[idx] for k, v in self.data.items()}
135 | data['features'] = torch.from_numpy(data['features']).float()
136 | return data
137 |
138 |
139 | def get_dataloader(dataset, is_train, args, image_encoder=None, is_real=True):
140 | if image_encoder is not None:
141 | feature_dataset = FeatureDataset(
142 | is_train, image_encoder, dataset, args.device, is_real, args=args)
143 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
144 | else:
145 | dataloader = dataset.train_loader if is_train else dataset.test_loader
146 | return dataloader
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from .common import ImageFolderWithPaths, SubsetSampler
5 | import numpy as np
6 |
7 | class EuroSAT:
8 | def __init__(self,
9 | preprocess,
10 | location=os.path.expanduser('~/data'),
11 | batch_size=32,
12 | num_workers=32,
13 | classnames='eurosat'):
14 | self.preprocess = preprocess
15 | self.location = location
16 | self.batch_size = batch_size
17 | self.num_workers = num_workers
18 | self.classnames = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land',
19 | 'Highway or Road', 'Industrial Building', 'Pasture Land',
20 | 'Permanent Crop Land', 'Residential Building', 'River', 'Sea or Lake']
21 |
22 | self.populate_train()
23 | try:
24 | self.populate_test()
25 | except:
26 | self.test_dataset = None
27 | self.test_loader = None
28 |
29 | def populate_train(self):
30 | traindir = os.path.join(self.location, 'train')
31 | self.train_dataset = ImageFolderWithPaths(
32 | traindir,
33 | transform=self.preprocess)
34 | sampler = self.get_train_sampler()
35 | kwargs = {'shuffle' : True} if sampler is None else {}
36 | self.train_loader = torch.utils.data.DataLoader(
37 | self.train_dataset,
38 | sampler=sampler,
39 | batch_size=self.batch_size,
40 | num_workers=self.num_workers,
41 | **kwargs,
42 | )
43 |
44 | def populate_test(self):
45 | self.test_dataset = self.get_test_dataset()
46 | self.test_loader = torch.utils.data.DataLoader(
47 | self.test_dataset,
48 | batch_size=self.batch_size,
49 | num_workers=self.num_workers,
50 | sampler=self.get_test_sampler()
51 | )
52 |
53 | def get_test_path(self):
54 | test_path = os.path.join(self.location, 'val_in_folder')
55 | if not os.path.exists(test_path):
56 | test_path = os.path.join(self.location, 'val')
57 | return test_path
58 |
59 | def get_train_sampler(self):
60 | return None
61 |
62 | def get_test_sampler(self):
63 | return None
64 |
65 | def get_test_dataset(self):
66 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
67 |
68 | def name(self):
69 | return 'eurosat'
70 |
71 | class EuroSATrain(EuroSAT):
72 |
73 | def get_test_dataset(self):
74 | pass
75 |
76 | class EuroSATK(EuroSAT):
77 |
78 | def get_train_sampler(self):
79 | idxs = np.zeros(len(self.train_dataset.targets))
80 | target_array = np.array(self.train_dataset.targets)
 81 |         for c in range(len(self.classnames)):  # one mask per EuroSAT class
82 | m = target_array == c
83 | n = len(idxs[m])
84 | arr = np.zeros(n)
85 | arr[:self.k()] = 1
86 | np.random.shuffle(arr)
87 | idxs[m] = arr
88 |
89 | idxs = idxs.astype('int')
90 | sampler = SubsetSampler(np.where(idxs)[0])
91 | return sampler
92 |
93 |
94 | def project_logits(logits, class_sublist_mask, device):
95 | if isinstance(logits, list):
96 | return [project_logits(l, class_sublist_mask, device) for l in logits]
97 | if logits.size(1) > sum(class_sublist_mask):
98 | return logits[:, class_sublist_mask].to(device)
99 | else:
100 | return logits.to(device)
101 |
102 | class EuroSATSubsample(EuroSAT):
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask()
106 | self.classnames = [self.classnames[i] for i in class_sublist]
107 |
108 | def get_class_sublist_and_mask(self):
109 | raise NotImplementedError()
110 |
111 | def populate_train(self):
112 | pass
113 |
114 | def project_logits(self, logits, device):
115 | return project_logits(logits, self.class_sublist_mask, device)
116 |
117 | class EuroSATSubsampleValClasses(EuroSAT):
118 | def get_class_sublist_and_mask(self):
119 | raise NotImplementedError()
120 |
121 | def populate_train(self):
122 | pass
123 |
124 | def get_test_sampler(self):
125 | self.class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask()
126 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.class_sublist]
127 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist])
128 |
129 | sampler = SubsetSampler(idx_subsample_list)
130 | return sampler
131 |
132 | def project_labels(self, labels, device):
133 | projected_labels = [self.class_sublist.index(int(label)) for label in labels]
134 | return torch.LongTensor(projected_labels).to(device)
135 |
136 | def project_logits(self, logits, device):
137 | return project_logits(logits, self.class_sublist_mask, device)
138 |
139 | # ks = [1, 2, 4, 8, 16, 25, 32, 50, 64, 128, 600]
140 |
141 | # for k in ks:
142 | # cls_name = f"ImageNet{k}"
143 | # dyn_cls = type(cls_name, (ImageNetK, ), {
144 | # "k": lambda self, num_samples=k: num_samples,
145 | # })
146 | # globals()[cls_name] = dyn_cls
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/fmow.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import wilds
4 |
5 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
6 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
7 |
8 | class FMOW:
9 | test_subset = None
10 |
11 | def __init__(self,
12 | preprocess,
13 | location=os.path.expanduser('~/data'),
14 | batch_size=128,
15 | num_workers=16,
16 | subset='test',
17 | classnames=None,
18 | **kwargs):
19 |
20 | self.dataset = wilds.get_dataset(dataset='fmow', root_dir=location)
21 |
22 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess)
23 | self.train_loader = get_train_loader("standard", self.train_dataset, num_workers=num_workers, batch_size=batch_size)
24 |
25 | self.test_dataset = self.dataset.get_subset(self.test_subset, transform=preprocess)
26 | self.test_loader = get_eval_loader("standard", self.test_dataset, num_workers=num_workers, batch_size=batch_size)
27 |
28 | self.classnames = [
29 | "airport", "airport_hangar", "airport_terminal", "amusement_park", "aquaculture",
30 | "archaeological_site", "barn", "border_checkpoint", "burial_site", "car_dealership",
31 | "construction_site", "crop_field", "dam", "debris_or_rubble", "educational_institution",
32 | "electric_substation", "factory_or_powerplant", "fire_station", "flooded_road", "fountain",
33 | "gas_station", "golf_course", "ground_transportation_station", "helipad", "hospital",
34 | "impoverished_settlement", "interchange", "lake_or_pond", "lighthouse", "military_facility",
35 | "multi-unit_residential", "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park",
36 | "parking_lot_or_garage", "place_of_worship", "police_station", "port", "prison", "race_track",
37 | "railway_bridge", "recreational_facility", "road_bridge", "runway", "shipyard", "shopping_mall",
38 | "single-unit_residential", "smokestack", "solar_farm", "space_facility", "stadium", "storage_tank",
39 | "surface_mine", "swimming_pool", "toll_booth", "tower", "tunnel_opening", "waste_disposal",
40 | "water_treatment_facility", "wind_farm", "zoo"
41 | ]
42 |
43 | def post_loop_metrics(self, labels, preds, metadata, args):
44 | metadata = torch.stack(metadata)
45 | preds = preds.argmax(dim=1, keepdim=True).view_as(labels)
46 | results = self.dataset.eval(preds, labels, metadata)
47 | return results[0]
48 |
49 | class FMOWID(FMOW):
50 | def __init__(self, *args, **kwargs):
51 | self.test_subset = 'id_test'
52 | super().__init__(*args, **kwargs)
53 |
54 | class FMOWOOD(FMOW):
55 | def __init__(self, *args, **kwargs):
56 | self.test_subset = 'test'
57 | super().__init__(*args, **kwargs)
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from .common import ImageFolderWithPaths, SubsetSampler
5 | from .imagenet_classnames import get_classnames
6 | import numpy as np
7 |
8 | class ImageNet:
9 | def __init__(self,
10 | preprocess,
11 | location=os.path.expanduser('~/data'),
12 | batch_size=32,
13 | num_workers=32,
14 | classnames='openai'):
15 | self.preprocess = preprocess
16 | self.location = location
17 | self.batch_size = batch_size
18 | self.num_workers = num_workers
19 | self.classnames = get_classnames(classnames)
20 |
21 | self.populate_train()
22 | try:
23 | self.populate_test()
24 | except:
25 | self.test_dataset = None
26 | self.test_loader = None
27 |
28 | def populate_train(self):
29 | traindir = os.path.join(self.location, 'train')
30 | self.train_dataset = ImageFolderWithPaths(
31 | traindir,
32 | transform=self.preprocess)
33 | sampler = self.get_train_sampler()
34 | kwargs = {'shuffle' : True} if sampler is None else {}
35 | self.train_loader = torch.utils.data.DataLoader(
36 | self.train_dataset,
37 | sampler=sampler,
38 | batch_size=self.batch_size,
39 | num_workers=self.num_workers,
40 | **kwargs,
41 | )
42 |
43 | def populate_test(self):
44 | self.test_dataset = self.get_test_dataset()
45 | self.test_loader = torch.utils.data.DataLoader(
46 | self.test_dataset,
47 | batch_size=self.batch_size,
48 | num_workers=self.num_workers,
49 | sampler=self.get_test_sampler()
50 | )
51 |
52 | def get_test_path(self):
53 | test_path = os.path.join(self.location, 'val_in_folder')
54 | if not os.path.exists(test_path):
55 | test_path = os.path.join(self.location, 'val')
56 | return test_path
57 |
58 | def get_train_sampler(self):
59 | return None
60 |
61 | def get_test_sampler(self):
62 | return None
63 |
64 | def get_test_dataset(self):
65 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
66 |
67 | def name(self):
68 | return 'imagenet'
69 |
70 | class ImageNetTrain(ImageNet):
71 |
72 | def get_test_dataset(self):
73 | pass
74 |
75 | class ImageNetK(ImageNet):
76 |
77 | def get_train_sampler(self):
78 | idxs = np.zeros(len(self.train_dataset.targets))
79 | target_array = np.array(self.train_dataset.targets)
80 | for c in range(1000):
81 | m = target_array == c
82 | n = len(idxs[m])
83 | arr = np.zeros(n)
84 | arr[:self.k()] = 1
85 | np.random.shuffle(arr)
86 | idxs[m] = arr
87 |
88 | idxs = idxs.astype('int')
89 | sampler = SubsetSampler(np.where(idxs)[0])
90 | return sampler
91 |
92 |
93 | def project_logits(logits, class_sublist_mask, device):
94 | if isinstance(logits, list):
95 | return [project_logits(l, class_sublist_mask, device) for l in logits]
96 | if logits.size(1) > sum(class_sublist_mask):
97 | return logits[:, class_sublist_mask].to(device)
98 | else:
99 | return logits.to(device)
100 |
101 | class ImageNetSubsample(ImageNet):
102 | def __init__(self, *args, **kwargs):
103 | super().__init__(*args, **kwargs)
104 | class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask()
105 | self.classnames = [self.classnames[i] for i in class_sublist]
106 |
107 | def get_class_sublist_and_mask(self):
108 | raise NotImplementedError()
109 |
110 | def populate_train(self):
111 | pass
112 |
113 | def project_logits(self, logits, device):
114 | return project_logits(logits, self.class_sublist_mask, device)
115 |
116 | class ImageNetSubsampleValClasses(ImageNet):
117 | def get_class_sublist_and_mask(self):
118 | raise NotImplementedError()
119 |
120 | def populate_train(self):
121 | pass
122 |
123 | def get_test_sampler(self):
124 | self.class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask()
125 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.class_sublist]
126 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist])
127 |
128 | sampler = SubsetSampler(idx_subsample_list)
129 | return sampler
130 |
131 | def project_labels(self, labels, device):
132 | projected_labels = [self.class_sublist.index(int(label)) for label in labels]
133 | return torch.LongTensor(projected_labels).to(device)
134 |
135 | def project_logits(self, logits, device):
136 | return project_logits(logits, self.class_sublist_mask, device)
137 |
138 | ks = [1, 2, 4, 8, 16, 25, 32, 50, 64, 128, 600]
139 |
140 | for k in ks:
141 | cls_name = f"ImageNet{k}"
142 | dyn_cls = type(cls_name, (ImageNetK, ), {
143 | "k": lambda self, num_samples=k: num_samples,
144 | })
145 | globals()[cls_name] = dyn_cls
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/imagenet_a.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses
6 | import numpy as np
7 |
8 |
9 | CLASS_SUBLIST = [
10 | 6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107,
11 | 108, 110,
12 | 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307,
13 | 308, 309,
14 | 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397,
15 | 400, 401,
16 | 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488,
17 | 492, 496,
18 | 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614,
19 | 626, 627,
20 | 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773,
21 | 774, 776,
22 | 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870,
23 | 879, 880,
24 | 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980,
25 | 981, 984,
26 | 986, 987, 988]
27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)]
28 |
29 |
30 | class ImageNetAValClasses(ImageNetSubsampleValClasses):
31 | def get_class_sublist_and_mask(self):
32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
33 |
34 |
35 | class ImageNetA(ImageNetSubsample):
36 | def get_class_sublist_and_mask(self):
37 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
38 |
39 | def get_test_path(self):
40 | return os.path.join(self.location, 'imagenet-a')
41 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/imagenet_r.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses
6 | import numpy as np
7 |
8 |
9 | CLASS_SUBLIST = [
10 | 1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107,
11 | 113, 122,
12 | 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203,
13 | 207, 208, 219,
14 | 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289,
15 | 291, 292, 293,
16 | 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347,
17 | 353, 355, 361,
18 | 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447,
19 | 448, 457, 462,
20 | 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613,
21 | 617, 621, 629,
22 | 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852,
23 | 866, 875, 883,
24 | 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965,
25 | 967, 980, 981,
26 | 983, 988]
27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)]
28 |
29 |
30 | class ImageNetRValClasses(ImageNetSubsampleValClasses):
31 | def get_class_sublist_and_mask(self):
32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
33 |
34 | class ImageNetR(ImageNetSubsample):
35 | def get_class_sublist_and_mask(self):
36 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK
37 |
38 | def get_test_path(self):
39 | return os.path.join(self.location, 'imagenet-r')
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/imagenet_sketch.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .imagenet import ImageNet
3 |
4 |
5 | class ImageNetSketch(ImageNet):
6 |
7 | def populate_train(self):
8 | pass
9 |
10 | def get_test_path(self):
11 | return os.path.join(self.location, 'sketch')
12 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/imagenetv2.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | from imagenetv2_pytorch import ImageNetV2Dataset
4 |
5 | from .imagenet import ImageNet
6 |
7 | class ImageNetV2DatasetWithPaths(ImageNetV2Dataset):
8 | def __getitem__(self, i):
9 | img, label = Image.open(self.fnames[i]), int(self.fnames[i].parent.name)
10 | if self.transform is not None:
11 | img = self.transform(img)
12 | return {
13 | 'images': img,
14 | 'labels': label,
15 | 'image_paths': str(self.fnames[i])
16 | }
17 |
18 | class ImageNetV2(ImageNet):
19 | def get_test_dataset(self):
20 | return ImageNetV2DatasetWithPaths(transform=self.preprocess, location=self.location)
21 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/iwildcam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import json
4 | import numpy as np
5 | import pathlib
6 |
7 | import wilds
8 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
9 | from wilds.datasets.wilds_dataset import WILDSSubset
10 |
11 |
12 | def get_mask_non_empty(dataset):
13 | metadf = pd.read_csv(dataset._data_dir / 'metadata.csv')
14 | filename = os.path.expanduser(dataset._data_dir / 'iwildcam2020_megadetector_results.json')
15 | with open(filename, 'r') as f:
16 | md_data = json.load(f)
17 | id_to_maxdet = {x['id']: x['max_detection_conf'] for x in md_data['images']}
18 | threshold = 0.95
19 | mask_non_empty = [id_to_maxdet[x] >= threshold for x in metadf['image_id']]
20 | return mask_non_empty
21 |
22 |
23 | def get_nonempty_subset(dataset, split, frac=1.0, transform=None):
24 | if split not in dataset.split_dict:
25 | raise ValueError(f"Split {split} not found in dataset's split_dict.")
26 | split_mask = dataset.split_array == dataset.split_dict[split]
27 |
28 | # intersect split mask with non_empty. here is the only place this fn differs
29 | # from https://github.com/p-lambda/wilds/blob/main/wilds/datasets/wilds_dataset.py#L56
30 | mask_non_empty = get_mask_non_empty(dataset)
31 | split_mask = split_mask & mask_non_empty
32 |
33 | split_idx = np.where(split_mask)[0]
34 | if frac < 1.0:
35 | num_to_retain = int(np.round(float(len(split_idx)) * frac))
36 | split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])
37 | subset = WILDSSubset(dataset, split_idx, transform)
38 | return subset
39 |
40 |
41 | class IWildCam:
42 | def __init__(self,
43 | preprocess,
44 | location=os.path.expanduser('~/data'),
45 | remove_non_empty=False,
46 | batch_size=128,
47 | num_workers=16,
48 | classnames=None,
49 | subset='train'):
50 | self.dataset = wilds.get_dataset(dataset='iwildcam', root_dir=location)
51 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess)
52 | self.train_loader = get_train_loader("standard", self.train_dataset, num_workers=num_workers, batch_size=batch_size)
53 |
54 | if remove_non_empty:
55 | self.train_dataset = get_nonempty_subset(self.dataset, 'train', transform=preprocess)
56 | else:
57 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess)
58 |
59 | if remove_non_empty:
60 | self.test_dataset = get_nonempty_subset(self.dataset, subset, transform=preprocess)
61 | else:
62 | self.test_dataset = self.dataset.get_subset(subset, transform=preprocess)
63 |
64 | self.test_loader = get_eval_loader(
65 | "standard", self.test_dataset,
66 | num_workers=num_workers,
67 | batch_size=batch_size)
68 |
69 | labels_csv = pathlib.Path(__file__).parent / 'iwildcam_metadata' / 'labels.csv'
70 | df = pd.read_csv(labels_csv)
71 | df = df[df['y'] < 99999]
72 |
73 | self.classnames = [s.lower() for s in list(df['english'])]
74 |
75 | def post_loop_metrics(self, labels, preds, metadata, args):
76 | preds = preds.argmax(dim=1, keepdim=True).view_as(labels)
77 | results = self.dataset.eval(preds, labels, metadata)
78 | return results[0]
79 |
80 |
81 | class IWildCamID(IWildCam):
82 | def __init__(self, *args, **kwargs):
83 | kwargs['subset'] = 'id_test'
84 | super().__init__(*args, **kwargs)
85 |
86 |
87 | class IWildCamOOD(IWildCam):
88 | def __init__(self, *args, **kwargs):
89 | kwargs['subset'] = 'test'
90 | super().__init__(*args, **kwargs)
91 |
92 |
93 | class IWildCamNonEmpty(IWildCam):
94 | def __init__(self, *args, **kwargs):
95 | kwargs['subset'] = 'train'
96 | super().__init__(*args, **kwargs)
97 |
98 |
99 | class IWildCamIDNonEmpty(IWildCam):
100 | def __init__(self, *args, **kwargs):
101 | kwargs['subset'] = 'id_test'
102 | super().__init__(*args, **kwargs)
103 |
104 |
105 | class IWildCamOODNonEmpty(IWildCam):
106 | def __init__(self, *args, **kwargs):
107 | kwargs['subset'] = 'test'
108 | super().__init__(*args, **kwargs)
109 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/objectnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pathlib import Path
4 | import PIL
5 |
6 | import numpy as np
7 |
8 | import torch
9 | from torchvision import datasets
10 | from torchvision.transforms import Compose
11 |
12 | from .common import ImageFolderWithPaths, SubsetSampler
13 | from .imagenet import ImageNet, ImageNetSubsampleValClasses
14 |
15 |
16 | def get_metadata():
17 | metadata = Path(__file__).parent / 'objectnet_metadata'
18 |
19 | with open(metadata / 'folder_to_objectnet_label.json', 'r') as f:
20 | folder_map = json.load(f)
21 | folder_map = {v: k for k, v in folder_map.items()}
22 | with open(metadata / 'objectnet_to_imagenet_1k.json', 'r') as f:
23 | objectnet_map = json.load(f)
24 |
25 | with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r') as f:
26 | pytorch_map = json.load(f)
27 | pytorch_map = {v: k for k, v in pytorch_map.items()}
28 |
29 | with open(metadata / 'imagenet_to_label_2012_v2', 'r') as f:
30 | imagenet_map = {v.strip(): str(pytorch_map[i]) for i, v in enumerate(f)}
31 |
32 | folder_to_ids, class_sublist = {}, []
33 | classnames = []
34 | for objectnet_name, imagenet_names in objectnet_map.items():
35 | imagenet_names = imagenet_names.split('; ')
36 | imagenet_ids = [int(imagenet_map[imagenet_name]) for imagenet_name in imagenet_names]
37 | class_sublist.extend(imagenet_ids)
38 | folder_to_ids[folder_map[objectnet_name]] = imagenet_ids
39 |
40 | class_sublist = sorted(class_sublist)
41 | class_sublist_mask = [(i in class_sublist) for i in range(1000)]
42 | classname_map = {v: k for k, v in folder_map.items()}
43 | return class_sublist, class_sublist_mask, folder_to_ids, classname_map
44 |
45 |
46 | def crop(img):
47 | width, height = img.size
48 | cropArea = (2, 2, width - 2, height - 2)
49 | img = img.crop(cropArea)
50 | return img
51 |
52 |
53 | class ObjectNetDataset(datasets.ImageFolder):
54 |
55 | def __init__(self, label_map, path, transform):
56 | self.label_map = label_map
57 | super().__init__(path, transform=transform)
58 | self.samples = [
59 | d for d in self.samples
60 | if os.path.basename(os.path.dirname(d[0])) in self.label_map
61 | ]
62 | self.imgs = self.samples
63 |
64 | def __len__(self):
65 | return len(self.samples)
66 |
67 | def __getitem__(self, index):
68 | path, target = self.samples[index]
69 | sample = self.loader(path)
70 | if self.transform is not None:
71 | sample = self.transform(sample)
72 | label = os.path.basename(os.path.dirname(path))
73 | return {
74 | 'images': sample,
75 | 'labels': self.label_map[label],
76 | 'image_paths': path
77 | }
78 |
79 |
80 | class ObjectNetBase(ImageNet):
81 | def __init__(self, *args, **kwargs):
82 | (self._class_sublist,
83 | self.class_sublist_mask,
84 | self.folders_to_ids,
85 | self.classname_map) = get_metadata()
86 |
87 | super().__init__(*args, **kwargs)
88 |
89 | self.classnames = sorted(list(self.folders_to_ids.keys()))
90 | self.rev_class_idx_map = {}
91 | self.class_idx_map = {}
92 | for idx, name in enumerate(self.classnames):
93 | self.rev_class_idx_map[idx] = self.folders_to_ids[name]
94 | for imagenet_idx in self.rev_class_idx_map[idx]:
95 | self.class_idx_map[imagenet_idx] = idx
96 |
97 | self.crop = crop
98 | self.preprocess = Compose([crop, self.preprocess])
99 | self.classnames = [self.classname_map[c].lower() for c in self.classnames]
100 |
101 | def populate_train(self):
102 | pass
103 |
104 | def get_test_dataset(self):
105 | subdir = 'objectnet-1.0/images'
106 | valdir = os.path.join(self.location, subdir)
107 | label_map = {name: idx for idx, name in enumerate(sorted(list(self.folders_to_ids.keys())))}
108 | return ObjectNetDataset(label_map, valdir, transform=self.preprocess)
109 |
110 | def project_logits(self, logits, device):
111 | if isinstance(logits, list) or isinstance(logits, tuple):
112 | return [self.project_logits(l, device) for l in logits]
113 | if logits.shape[1] == 113:
114 | return logits
115 | if torch.is_tensor(logits):
116 | logits = logits.cpu().numpy()
117 | logits_projected = np.zeros((logits.shape[0], 113))
118 | for k, v in self.rev_class_idx_map.items():
119 | logits_projected[:, k] = np.max(logits[:, v], axis=1).squeeze()
120 | return torch.tensor(logits_projected).to(device)
121 |
122 | def scatter_weights(self, weights):
123 | if weights.size(1) == 1000:
124 | return weights
125 | new_weights = torch.ones((weights.size(0), 1000)).to(weights.device) * -10e8
126 | for k, v in self.rev_class_idx_map.items():
127 | for vv in v:
128 | new_weights[:, vv] = weights[:, k]
129 | return new_weights
130 |
131 |
132 |
133 | def accuracy(logits, targets, img_paths, args):
134 | assert logits.shape[1] == 113
135 | preds = logits.argmax(dim=1)
136 | if torch.is_tensor(preds):
137 | preds = preds.cpu().numpy()
138 | if torch.is_tensor(targets):
139 | targets = targets.cpu().numpy()
140 | return np.sum(preds == targets), len(preds)
141 |
142 |
143 | class ObjectNetValClasses(ObjectNetBase):
144 |
145 | def get_test_sampler(self):
146 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self._class_sublist]
147 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist])
148 |
149 | sampler = SubsetSampler(idx_subsample_list)
150 | return sampler
151 |
152 | def get_test_dataset(self):
153 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
154 |
155 | def project_labels(self, labels, device):
156 | projected_labels = [self.class_idx_map[int(label)] for label in labels]
157 | return torch.LongTensor(projected_labels).to(device)
158 |
159 |
160 | class ObjectNet(ObjectNetBase):
161 |
162 | def accuracy(self, logits, targets, img_paths, args):
163 | return accuracy(logits, targets, img_paths, args)
164 |
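Note (illustrative, not part of the repository): a small self-contained sketch of the projection performed by `project_logits` above. Each ObjectNet class keeps the maximum logit over the ImageNet classes mapped to it; the toy class map below is an assumption.

import numpy as np
import torch

rev_class_idx_map = {0: [10, 20], 1: [3], 2: [7, 8, 9]}  # toy ObjectNet -> ImageNet map (assumption)
logits = torch.randn(4, 1000)                            # a batch of ImageNet logits

projected = np.zeros((logits.shape[0], len(rev_class_idx_map)))
for objectnet_idx, imagenet_ids in rev_class_idx_map.items():
    projected[:, objectnet_idx] = logits.numpy()[:, imagenet_ids].max(axis=1)
projected = torch.from_numpy(projected)
print(projected.shape)  # torch.Size([4, 3])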
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/objectnet_metadata/objectnet_to_imagenet_1k.json:
--------------------------------------------------------------------------------
1 | {
2 | "Alarm clock": "analog clock; digital clock",
3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack",
4 | "Banana": "banana",
5 | "Band Aid": "Band Aid",
6 | "Basket": "shopping basket",
7 | "Bath towel": "bath towel",
8 | "Beer bottle": "beer bottle",
9 | "Bench": "park bench",
10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem",
11 | "Binder (closed)": "binder, ring-binder",
12 | "Bottle cap": "bottlecap",
13 | "Bread loaf": "French loaf",
14 | "Broom": "broom",
15 | "Bucket": "bucket, pail",
16 | "Butcher's knife": "cleaver, meat cleaver, chopper",
17 | "Can opener": "can opener, tin opener",
18 | "Candle": "candle, taper, wax light",
19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
20 | "Chair": "barber chair; folding chair; rocking chair, rocker",
21 | "Clothes hamper": "hamper",
22 | "Coffee/French press": "espresso maker",
23 | "Combination lock": "combination lock",
24 | "Computer mouse": "mouse, computer mouse",
25 | "Desk lamp": "table lamp",
26 | "Dishrag or hand towel": "dishrag, dishcloth",
27 | "Doormat": "doormat, welcome mat",
28 | "Dress shoe (men)": "Loafer",
29 | "Drill": "power drill",
30 | "Drinking Cup": "cup",
31 | "Drying rack for plates": "plate rack",
32 | "Envelope": "envelope",
33 | "Fan": "electric fan, blower",
34 | "Frying pan": "frying pan, frypan, skillet",
35 | "Dress": "gown",
36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
37 | "Hammer": "hammer",
38 | "Helmet": "football helmet; crash helmet",
39 | "Iron (for clothes)": "iron, smoothing iron",
40 | "Jeans": "jean, blue jean, denim",
41 | "Keyboard": "computer keyboard, keypad",
42 | "Ladle": "ladle",
43 | "Lampshade": "lampshade, lamp shade",
44 | "Laptop (open)": "laptop, laptop computer",
45 | "Lemon": "lemon",
46 | "Letter opener": "letter opener, paper knife, paperknife",
47 | "Lighter": "lighter, light, igniter, ignitor",
48 | "Lipstick": "lipstick, lip rouge",
49 | "Match": "matchstick",
50 | "Measuring cup": "measuring cup",
51 | "Microwave": "microwave, microwave oven",
52 | "Mixing / Salad Bowl": "mixing bowl",
53 | "Monitor": "monitor",
54 | "Mug": "coffee mug",
55 | "Nail (fastener)": "nail",
56 | "Necklace": "necklace",
57 | "Orange": "orange",
58 | "Padlock": "padlock",
59 | "Paintbrush": "paintbrush",
60 | "Paper towel": "paper towel",
61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen",
62 | "Pill bottle": "pill bottle",
63 | "Pillow": "pillow",
64 | "Pitcher": "pitcher, ewer",
65 | "Plastic bag": "plastic bag",
66 | "Plate": "plate",
67 | "Plunger": "plunger, plumber's helper",
68 | "Pop can": "pop bottle, soda bottle",
69 | "Portable heater": "space heater",
70 | "Printer": "printer",
71 | "Remote control": "remote control, remote",
72 | "Ruler": "rule, ruler",
73 | "Running shoe": "running shoe",
74 | "Safety pin": "safety pin",
75 | "Salt shaker": "saltshaker, salt shaker",
76 | "Sandal": "sandal",
77 | "Screw": "screw",
78 | "Shovel": "shovel",
79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt",
80 | "Sleeping bag": "sleeping bag",
81 | "Soap dispenser": "soap dispenser",
82 | "Sock": "sock",
83 | "Soup Bowl": "soup bowl",
84 | "Spatula": "spatula",
85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera",
87 | "Strainer": "strainer",
88 | "Stuffed animal": "teddy, teddy bear",
89 | "Suit jacket": "suit, suit of clothes",
90 | "Sunglasses": "sunglasses, dark glasses, shades",
91 | "Sweater": "sweatshirt",
92 | "Swimming trunks": "swimming trunks, bathing trunks",
93 | "T-shirt": "jersey, T-shirt, tee shirt",
94 | "TV": "television, television system",
95 | "Teapot": "teapot",
96 | "Tennis racket": "racket, racquet",
97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie",
98 | "Toaster": "toaster",
99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue",
100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
101 | "Tray": "tray",
102 | "Umbrella": "umbrella",
103 | "Vacuum cleaner": "vacuum, vacuum cleaner",
104 | "Vase": "vase",
105 | "Wallet": "wallet, billfold, notecase, pocketbook",
106 | "Watch": "digital watch",
107 | "Water bottle": "water bottle",
108 | "Weight (exercise)": "dumbbell",
109 | "Weight scale": "scale, weighing machine",
110 | "Wheel": "car wheel; paddlewheel, paddle wheel",
111 | "Whistle": "whistle",
112 | "Wine bottle": "wine bottle",
113 | "Winter glove": "mitten",
114 | "Wok": "wok"
115 | }
116 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/cal_mean_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 |
4 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/PASS_dataset/train',transform=transforms.ToTensor())
5 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/imagenet/ILSVRC2012_img_train/train',transform=transforms.ToTensor())
6 | dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/PASS_dataset/train',
7 | transform=transforms.Compose([transforms.Resize(256),
8 | transforms.CenterCrop(224),
9 | transforms.ToTensor()]))
10 |
11 | # --------- PASS
12 | # mean: tensor([0.4646, 0.4484, 0.4129])
13 | # std: tensor([0.2750, 0.2689, 0.2885])
14 |
15 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/imagenet/ILSVRC2012_img_train/train',
16 | # transform=transforms.Compose([transforms.Resize(256),
17 | # transforms.CenterCrop(224),
18 | # transforms.ToTensor()]))
19 |
20 | loader = torch.utils.data.DataLoader(dataset,
21 | batch_size=1000,
22 | num_workers=8,
23 | shuffle=False)
24 |
25 | # mean = 0.
26 | # meansq = 0.
27 | # i=0
28 | # for data,_ in loader:
29 | # print('{}/{}'.format(i,len(loader)))
30 | # i+=1
31 | # mean = data.mean()
32 | # meansq = (data ** 2).mean()
33 | #
34 | # std = torch.sqrt(meansq - mean ** 2)
35 | # print("mean: " + str(mean))
36 | # print("std: " + str(std))
37 | # print()
38 |
39 | # mean = 0.0
40 | # i=0
41 | # for images, _ in loader:
42 | # batch_samples = images.size(0)
43 | # images = images.view(batch_samples, images.size(1), -1)
44 | # mean += images.mean(2).sum(0)
45 | # print('{}/{}'.format(i, len(loader)))
46 | # i+=1
47 | # print(mean / i / 1000)
48 | # mean = mean / len(loader.dataset) / 1000
49 |
50 | # import ipdb
51 | # ipdb.set_trace(context=20)
52 | # mean = torch.FloatTensor([0.485, 0.456, 0.406])
53 | mean = torch.FloatTensor([0.4646, 0.4484, 0.4129])
54 | var = 0.0
55 | i=0
56 | for images, _ in loader:
57 | batch_samples = images.size(0)
58 | images = images.view(batch_samples, images.size(1), -1)
59 | var += ((images - mean.unsqueeze(1))**2).sum([0,2])
60 | print('{}/{}'.format(i, len(loader)))
61 | i += 1
62 | print(torch.sqrt(var / (i*224*224)))
63 | print(torch.sqrt(var / (i*1000*224*224)))
64 | std = torch.sqrt(var / (len(loader.dataset)*224*224))
65 |
66 | import ipdb
67 | ipdb.set_trace(context=20)
68 |
69 | a=1
70 |
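Note (illustrative, not part of the file above): a cleaner two-pass sketch of the same statistics computation; the dataset path, crop size, and batch size are assumptions. The first pass accumulates the per-channel mean over all pixels, the second the variance around that mean.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

ds = datasets.ImageFolder('/path/to/images',
                          transform=transforms.Compose([transforms.Resize(256),
                                                        transforms.CenterCrop(224),
                                                        transforms.ToTensor()]))
loader = DataLoader(ds, batch_size=256, num_workers=8, shuffle=False)

n_pixels = 0
mean = torch.zeros(3)
for images, _ in loader:
    b, c, h, w = images.shape
    n_pixels += b * h * w
    mean += images.sum(dim=[0, 2, 3])
mean /= n_pixels

var = torch.zeros(3)
for images, _ in loader:
    var += ((images - mean.view(1, 3, 1, 1)) ** 2).sum(dim=[0, 2, 3])
std = torch.sqrt(var / n_pixels)
print('mean:', mean, 'std:', std)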
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/constants.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from .Randaug import RandAugment
3 |
4 | prefix = "data"
5 |
6 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_100k/"
7 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_128k/"
8 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_200k/"
9 | IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train/"
10 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_plus15tasks/"
11 | # IMGNET_PATH = prefix + "/imagenet/imagenet.10_1000/"
12 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_traintrain_200cls_640shot_128.0k/"
13 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_traintrain_500cls_256shot_128.0k/"
14 | # IMGNET_PATH = prefix + "/place_128k"
15 | # IMGNET_PATH = prefix + "/pass/PASS_128k"
16 |
17 | # Planes dataset
18 | FGVC_PATH = prefix + "/fgvc-aircraft-2013b/"
19 |
20 | # Oxford Flowers dataset
21 | # FLOWERS_PATH = prefix + "/oxford_flowers_pytorch/"
22 | FLOWERS_PATH = prefix + "/flowers_new/"
23 |
24 | # DTD dataset
25 | DTD_PATH = prefix + "/dtd/"
26 |
27 | # Stanford Cars dataset
28 | CARS_PATH = prefix + "/cars_new"
29 |
30 | # SUN397 dataset
31 | SUN_PATH = prefix + "/SUN397/splits_01/"
32 |
33 | # FOOD dataset
34 | FOOD_PATH = prefix + "/food-101"
35 |
36 | # BIRDS dataset
37 | BIRDS_PATH = prefix + "/birdsnap"
38 |
39 | # CUB-200-2011 birds
40 | CUB_PATH = prefix + "/CUB_200_2011"
41 |
42 | # COCO
43 | COCO_PATH = prefix + "/coco_cls"
44 |
45 | # ade20k
46 | ADE20K_PATH = prefix + "/ade20k_cls"
47 |
48 | # Mix seg: cs voc ade
49 | MIX_SEG_PATH = prefix + "/mix_seg"
50 |
51 | # PETS dataset
52 | PETS_PATH = prefix + ""
53 |
54 | # Caltech datasets
55 | CALTECH101_PATH = prefix + ""
56 | CALTECH256_PATH = prefix + ""
57 |
58 | value_scale = 255
59 | mean = [0.485, 0.456, 0.406]
60 | mean = [0.48145466, 0.4578275, 0.40821073]
61 | mean = [item * value_scale for item in mean]
62 | std = [0.229, 0.224, 0.225]
63 | std = [0.26862954, 0.26130258, 0.27577711]
64 | std = [item * value_scale for item in std]
65 |
66 | # Data Augmentation defaults
67 | TRAIN_TRANSFORMS = transforms.Compose([
68 | # transforms.Resize(32),
69 | transforms.RandomResizedCrop(224),
70 | # transforms.RandomResizedCrop(224, scale=(0.08,1.0), ratio=(0.75,1.333333)),
71 | # transforms.RandomResizedCrop(224, scale=(0.08,1.0), ratio=(0.5,2.0)),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | # transforms.Normalize(mean=mean, std=std),
75 | ])
76 | # TRAIN_TRANSFORMS = transforms.Compose([
77 | # # transforms.Resize(32),
78 | # transforms.Resize(256),
79 | # transforms.CenterCrop(224),
80 | # transforms.ToTensor(),
81 | # # transforms.Normalize(mean=mean, std=std),
82 | # ])
83 |
84 | TEST_TRANSFORMS = transforms.Compose([
85 | # transforms.Resize(32),
86 | transforms.Resize(256),
87 | transforms.CenterCrop(224),
88 | transforms.ToTensor(),
89 | # transforms.Normalize(mean=mean, std=std),
90 | ])
91 |
92 | # from PIL import Image
93 | # BICUBIC = Image.BICUBIC
94 | # TEST_TRANSFORMS = transforms.Compose([
95 | # # transforms.Resize(32),
96 | # transforms.Resize(224,interpolation=BICUBIC),
97 | # transforms.CenterCrop(224),
98 | # transforms.ToTensor(),
99 | # # transforms.Normalize(mean=mean, std=std),
100 | # ])
101 |
102 | # Add RandAugment with N, M(hyperparameter)
103 | # N=3
104 | # M=9
105 | # TRAIN_TRANSFORMS.transforms.insert(0, RandAugment(N, M))
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class CUB(Dataset):
8 |
9 | def __init__(self, path, train=True, transform=None, target_transform=None):
10 |
11 | self.root = path
12 | self.is_train = train
13 | self.transform = transform
14 | self.target_transform = target_transform
15 | self.images_path = {}
16 | with open(os.path.join(self.root, 'images.txt')) as f:
17 | for line in f:
18 | image_id, path = line.split()
19 | self.images_path[image_id] = path
20 |
21 | self.class_ids = {}
22 | with open(os.path.join(self.root, 'image_class_labels.txt')) as f:
23 | for line in f:
24 | image_id, class_id = line.split()
25 | self.class_ids[image_id] = class_id
26 |
27 | self.data_id = []
28 | if self.is_train:
29 | with open(os.path.join(self.root, 'train_test_split.txt')) as f:
30 | for line in f:
31 | image_id, is_train = line.split()
32 | if int(is_train):
33 | self.data_id.append(image_id)
34 | if not self.is_train:
35 | with open(os.path.join(self.root, 'train_test_split.txt')) as f:
36 | for line in f:
37 | image_id, is_train = line.split()
38 | if not int(is_train):
39 | self.data_id.append(image_id)
40 |
41 | def __len__(self):
42 | return len(self.data_id)
43 |
44 | def __getitem__(self, index):
45 | """
46 | Args:
47 | index: index of training dataset
48 | Returns:
49 | image and its corresponding label
50 | """
51 | image_id = self.data_id[index]
52 | class_id = int(self._get_class_by_id(image_id)) - 1
53 | path = self._get_path_by_id(image_id)
54 | image = cv2.imread(os.path.join(self.root, 'images', path))
55 |
56 | if self.transform:
57 | image = self.transform(image)
58 |
59 | if self.target_transform:
60 | class_id = self.target_transform(class_id)
61 | return image, class_id
62 |
63 | def _get_path_by_id(self, image_id):
64 |
65 | return self.images_path[image_id]
66 |
67 | def _get_class_by_id(self, image_id):
68 |
69 | return self.class_ids[image_id]
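
Note (not part of the file above): a usage sketch. Since `cv2.imread` returns an HWC uint8 numpy array rather than a PIL image, the transform pipeline here starts with `ToPILImage`; the dataset path and sizes are assumptions.

from torch.utils.data import DataLoader
from torchvision import transforms

tfm = transforms.Compose([
    transforms.ToPILImage(),      # cv2 gives numpy arrays, not PIL images
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = CUB('/path/to/CUB_200_2011', train=True, transform=tfm)
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
images, labels = next(iter(loader))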
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/cub_for_robust_codebase.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | def main():
7 | path= "/opt/tiger/filter_transfer/data/CUB_200_2011"
8 | root = path
9 | images_path = {}
10 |
11 | os.system('mkdir train')
12 | os.system('cp -r images/* train/')
13 | os.system('mkdir val')
14 | os.system('cp -r images/* val/')
15 |
16 | with open(os.path.join(root, 'images.txt')) as f:
17 | for line in f:
18 | image_id, path = line.split()
19 | images_path[image_id] = path
20 |
21 | class_ids = {}
22 | with open(os.path.join(root, 'image_class_labels.txt')) as f:
23 | for line in f:
24 | image_id, class_id = line.split()
25 | class_ids[image_id] = class_id
26 |
27 | train_id = [] # train not val
28 | with open(os.path.join(root, 'train_test_split.txt')) as f:
29 | for line in f:
30 | image_id, is_train = line.split()
31 | if int(is_train):
32 | train_id.append(image_id)
33 |
34 | with open(os.path.join(root, 'images.txt')) as f:
35 | for line in f:
36 | image_id, path = line.split()
37 | if image_id in train_id:
38 | os.system('rm val/{}'.format(path))
39 | else:
40 | # import ipdb
41 | # ipdb.set_trace(context=20)
42 | os.system('rm train/{}'.format(path))
43 |
44 |
45 |
46 |
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/dtd.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | # from . import constants as cs
3 | from torch.utils.data.dataset import Dataset
4 | from torch.utils.data import DataLoader
5 | from os.path import join as osj
6 | from PIL import Image
7 | from torchvision import transforms
8 | import os
9 |
10 | TRAIN_TRANSFORMS = transforms.Compose([
11 | # transforms.Resize(32),
12 | transforms.RandomResizedCrop(224),
13 | transforms.RandomHorizontalFlip(),
14 | transforms.ToTensor(),
15 | # transforms.Normalize(mean=mean, std=std),
16 | ])
17 |
18 | TEST_TRANSFORMS = transforms.Compose([
19 | # transforms.Resize(32),
20 | transforms.Resize(256),
21 | transforms.CenterCrop(224),
22 | transforms.ToTensor(),
23 | # transforms.Normalize(mean=mean, std=std),
24 | ])
25 |
26 | class DTD(Dataset):
27 | def __init__(self, split="1", train=False, transform=TRAIN_TRANSFORMS):
28 | super().__init__()
29 | DTD_PATH='/opt/tiger/filter_transfer/data/dtd'
30 | train_path = osj(DTD_PATH, f"labels/train{split}.txt")
31 | val_path = osj(DTD_PATH, f"labels/val{split}.txt")
32 | test_path = osj(DTD_PATH, f"labels/test{split}.txt")
33 | if train:
34 | print(DTD_PATH)
35 | self.ims = open(train_path).readlines() + \
36 | open(val_path).readlines()
37 | else:
38 | self.ims = open(test_path).readlines()
39 |
40 | self.full_ims = [osj(DTD_PATH, "images", x) for x in self.ims]
41 |
42 | pth = osj(DTD_PATH, f"labels/classes.txt")
43 | self.c_to_t = {x.strip(): i for i, x in enumerate(open(pth).readlines())}
44 |
45 | # self.transform = TRAIN_TRANSFORMS if train else TEST_TRANSFORMS
46 | self.transform = transform
47 | self.labels = [self.c_to_t[x.split("/")[0]] for x in self.ims]
48 |
49 | def __getitem__(self, index):
50 | im = Image.open(self.full_ims[index].strip())
51 | im = self.transform(im)
52 | return im, self.labels[index]
53 |
54 | def __len__(self):
55 | return len(self.ims)
56 |
57 | if __name__ == "__main__":
58 | dtd = DTD(train=True)
59 | # import ipdb
60 | # ipdb.set_trace(context=20)
61 | target_folder = "/opt/tiger/filter_transfer/data/dtd/mix_dtd/"
62 | for im in dtd.full_ims:
63 | img = im[:-1]
64 | category = img.split('/')[-2]
65 | if not os.path.exists(target_folder+category):os.makedirs(target_folder+category)
66 | os.system('cp {} {}'.format(img, target_folder+category))
67 | a=1
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/fine_tunify.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from robustness.tools.custom_modules import SequentialWithArgs
3 |
4 | def ft(model_name, model_ft, num_classes, additional_hidden=0):
5 | if model_name in ['clip_resnest50d','resnet50_feat_pca_pre_relu_multi_pool','resnet18_feat_pre_relu_multi_pool','resnet50_feat_pca_pre_relu_multi','resnet18_feat_pre_relu_multi','resnet50_feat_interpolate_multi','resnet18_multi','resnet50_feat_pre_relu_multi','resnet50_overhaul','resnet18_feat_pre_relu_regressor','resnet18_custom','resnet18_feat_pre_relu',"resnet50_feat_interpolate","resnet50_feat_pca","resnet50_feat_nmf","resnet50_feat_lda","resnet50_feat_mag","resnet18_feat","resnet152_feat","resnet50_feat","resnet","resnet20_as_gift","resnet50_clean", "resnet18", "resnet34","resnet50", "wide_resnet50_2", "wide_resnet50_4", "resnext50_32x4d", 'shufflenet']:
6 | num_ftrs = model_ft.fc.in_features
7 | # The two cases are split just to allow loading
8 | # models trained prior to adding the additional_hidden argument
9 | # without errors
10 | if additional_hidden == 0:
11 | model_ft.fc = nn.Linear(num_ftrs, num_classes)
12 | else:
13 | model_ft.fc = SequentialWithArgs(
14 | *list(sum([[nn.Linear(num_ftrs, num_ftrs), nn.ReLU()] for i in range(additional_hidden)], [])),
15 | nn.Linear(num_ftrs, num_classes)
16 | )
17 | input_size = 224
18 |
19 | elif model_name == 'RN50':
20 | num_ftrs = 1024
21 | # model_ft.fc = nn.Linear(num_ftrs, num_classes)
22 | input_size = 224
23 | elif model_name == "alexnet":
24 | num_ftrs = model_ft.classifier[6].in_features
25 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
26 | input_size = 224
27 | elif "vgg" in model_name:
28 | num_ftrs = model_ft.classifier[6].in_features
29 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
30 | input_size = 224
31 | elif model_name == "squeezenet":
32 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
33 | model_ft.num_classes = num_classes
34 | input_size = 224
35 | elif model_name == "densenet":
36 | num_ftrs = model_ft.classifier.in_features
37 | model_ft.classifier = nn.Linear(num_ftrs, num_classes)
38 | input_size = 224
39 | elif model_name in ["mnasnet", "mobilenet"]:
40 | num_ftrs = model_ft.classifier.in_features
41 | model_ft.classifier = nn.Linear(num_ftrs, num_classes)
42 | input_size = 224
43 | else:
44 | pass
45 | # raise ValueError("Invalid model type, exiting...")
46 |
47 | return model_ft
48 |
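Note (illustrative, not part of the repository): a brief usage sketch of `ft`, swapping the final layer of a torchvision ResNet-50 for a 101-way head.

from torchvision import models

backbone = models.resnet50()                 # randomly initialised for the sketch
model = ft('resnet50', backbone, num_classes=101)
print(model.fc)                              # Linear(in_features=2048, out_features=101, bias=True)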
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/food_101.py:
--------------------------------------------------------------------------------
1 | # pytorch imports
2 | import torch
3 | from torchvision import models, transforms, datasets
4 | from torch.utils.data import DataLoader
5 | from robustness import data_augmentation as da
6 | from . import constants as cs
7 |
8 | class FOOD101():
9 | def __init__(self, transform=None):
10 | # self.TRAIN_PATH = cs.FOOD_PATH+"/train"
11 | self.TRAIN_PATH = "/opt/tiger/filter_transfer/data/food-101/train"
12 | # self.VALID_PATH = cs.FOOD_PATH+"/valid"
13 | self.VALID_PATH = "/opt/tiger/filter_transfer/data/food-101/valid"
14 |
15 | self.train_ds, self.valid_ds, self.train_cls, self.valid_cls = [None]*4
16 | self.transform = transform
17 |
18 | def _get_tfms(self):
19 | train_tfms = cs.TRAIN_TRANSFORMS
20 | valid_tfms = cs.TEST_TRANSFORMS
21 | return train_tfms, valid_tfms
22 |
23 | def get_dataset(self):
24 | # train_tfms, valid_tfms = self._get_tfms() # transformations
25 | train_tfms, valid_tfms = self.transform, self.transform
26 | self.train_ds = datasets.ImageFolder(root=self.TRAIN_PATH,
27 | transform=train_tfms)
28 | self.valid_ds = datasets.ImageFolder(root=self.VALID_PATH,
29 | transform=valid_tfms)
30 | self.train_classes = self.train_ds.classes
31 | self.valid_classes = self.valid_ds.classes
32 |
33 | # print(self.train_classes)
34 |
35 | assert self.train_classes==self.valid_classes
36 | return self.train_ds, self.valid_ds, self.train_classes
37 |
38 | def get_dls(self, train_ds, valid_ds, bs, **kwargs):
39 | return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
40 | DataLoader(valid_ds, batch_size=bs, shuffle=True, **kwargs))
41 |
42 |
43 |
44 |
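Note (not part of the file above): a usage sketch, assuming the hard-coded food-101 train/valid ImageFolder directories exist; the transform and batch size are illustrative.

from torchvision import transforms

tfm = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
food = FOOD101(transform=tfm)
train_ds, valid_ds, classes = food.get_dataset()
train_dl, valid_dl = food.get_dls(train_ds, valid_ds, bs=64, num_workers=4)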
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/imbalance_cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 |
6 |
7 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
8 | cls_num = 10
9 |
10 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
11 | transform=None, target_transform=None,
12 | download=False):
13 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
14 | np.random.seed(rand_number)
15 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
16 | self.gen_imbalanced_data(img_num_list)
17 |
18 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
19 | img_max = len(self.data) / cls_num
20 | img_num_per_cls = []
21 | if imb_type == 'exp':
22 | for cls_idx in range(cls_num):
23 | num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
24 | img_num_per_cls.append(int(num))
25 | elif imb_type == 'step':
26 | for cls_idx in range(cls_num // 2):
27 | img_num_per_cls.append(int(img_max))
28 | for cls_idx in range(cls_num // 2):
29 | img_num_per_cls.append(int(img_max * imb_factor))
30 | else:
31 | img_num_per_cls.extend([int(img_max)] * cls_num)
32 | return img_num_per_cls
33 |
34 | def gen_imbalanced_data(self, img_num_per_cls):
35 | new_data = []
36 | new_targets = []
37 | targets_np = np.array(self.targets, dtype=np.int64)
38 | classes = np.unique(targets_np)
39 | # np.random.shuffle(classes)
40 | self.num_per_cls_dict = dict()
41 | for the_class, the_img_num in zip(classes, img_num_per_cls):
42 | self.num_per_cls_dict[the_class] = the_img_num
43 | idx = np.where(targets_np == the_class)[0]
44 | np.random.shuffle(idx)
45 | selec_idx = idx[:the_img_num]
46 | new_data.append(self.data[selec_idx, ...])
47 | new_targets.extend([the_class, ] * the_img_num)
48 | new_data = np.vstack(new_data)
49 | self.data = new_data
50 | self.targets = new_targets
51 |
52 | def get_cls_num_list(self):
53 | cls_num_list = []
54 | for i in range(self.cls_num):
55 | cls_num_list.append(self.num_per_cls_dict[i])
56 | return cls_num_list
57 |
58 |
59 | class IMBALANCECIFAR100(IMBALANCECIFAR10):
60 |     """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
61 | This is a subclass of the `CIFAR10` Dataset.
62 | """
63 | base_folder = 'cifar-100-python'
64 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
65 | filename = "cifar-100-python.tar.gz"
66 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
67 | train_list = [
68 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
69 | ]
70 |
71 | test_list = [
72 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
73 | ]
74 | meta = {
75 | 'filename': 'meta',
76 | 'key': 'fine_label_names',
77 | 'md5': '7973b15100ade9c7d40fb424638fde48',
78 | }
79 | cls_num = 100
80 |
81 |
82 | if __name__ == '__main__':
83 | transform = transforms.Compose(
84 | [transforms.ToTensor(),
85 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
86 | trainset = IMBALANCECIFAR100(root='./data', train=True,
87 | download=True, transform=transform)
88 | trainloader = iter(trainset)
89 | data, label = next(trainloader)
90 | import pdb;
91 |
92 | pdb.set_trace()
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_aircraft.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import numpy as np
5 |
6 | # --- get class names
7 | # read_filepath = 'data/variants.txt'
8 | # names=[]
9 | # with open(read_filepath, 'r') as f:
10 | # for line in f.readlines():
11 | # print(line.strip())
12 | # names.append(line.strip())
13 | #
14 | # names = ["a "+n for n in names]
15 | # print(names)
16 |
17 |
18 | # --- find train images
19 | read_filepath = 'data/images_variant_trainval.txt'
20 | names=[]
21 | with open(read_filepath, 'r') as f:
22 | for line in f.readlines():
23 | line=line.strip().split(' ')
24 | img_name = line[0] + '.jpg'
25 | print(img_name)
26 | names.append(img_name)
27 |
28 | # names = ["a "+n for n in names]
29 | print(names)
30 |
31 | for name in names:
32 | os.system('cp data/images/{} ../orig_pool15/aircraft/'.format(name))
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_caltech101.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import numpy as np
5 |
6 |
7 | # --- find train images
8 | read_filepath = 'data/images_variant_trainval.txt'
9 | names=[]
10 | with open(read_filepath, 'r') as f:
11 | for line in f.readlines():
12 | line=line.strip().split(' ')
13 | img_name = line[0] + '.jpg'
14 | print(img_name)
15 | names.append(img_name)
16 |
17 | # names = ["a "+n for n in names]
18 | print(names)
19 |
20 | for name in names:
21 | os.system('cp data/images/{} ../orig_pool15/aircraft/'.format(name))
22 | os.system('pwd')
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | cat2name = {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"}
5 |
6 | cates = sorted(glob.glob('*'))[:-1]
7 |
8 |
9 |
10 | names = ["a "+ cat2name[cat] for cat in cates]
11 | import ipdb
12 |
13 | ipdb.set_trace(context=20)
14 | a=1
15 |
16 |
17 | b=[a[3:] for a in al]
18 | al = ["a "+ v for v in al]
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_imgnet.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | src_classes = sorted(glob.glob('ILSVRC2012_img_train/train/*'))
5 | os.system('mkdir -p ILSVRC2012_img_train_plus15tasks/train/img')
6 |
7 | i=0
8 | for src_cls in src_classes:
9 | print(i)
10 | i+=1
11 | os.system('cp {}/* ILSVRC2012_img_train_plus15tasks/train/img'.format(src_cls))
12 |
13 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_pool15.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | # pwd = /opt/tiger/filter_transfer/data
4 | new_ds = 'ds15img'
5 | all_ds = sorted(glob.glob('pool15/*'))
6 | os.system('mkdir {}'.format(new_ds))
7 |
8 | for ds in all_ds:
9 | os.system('mv {}/train/*/* {}'.format(ds, new_ds))
10 |
11 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/process_dataset/process_pets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | names=[]
5 |
6 | all = sorted(glob.glob('*'))
7 | for i in all:
8 | # os.system('mv {} train/'.format(i))
9 | # os.system('mkdir val/{}'.format(i))
10 |
11 | name = sorted(glob.glob('{}/*.jpg'.format(i)))[0]
12 |
13 | # name = 'a '+name.split('/')[-1].split('_')[0]
14 | name = 'a '+name.split('/')[-1].split('_1')[0]
15 | # name = sorted(glob.glob('train/{}/*.jpg'.format(i)))[0].split('_')[0]
16 |
17 | print(name)
18 | names.append(name)
19 | print(names)
20 | # os.system('cp {} val/{}/'.format(to_cp_name,i))
21 | import ipdb
22 | ipdb.set_trace(context=20)
23 | print(names)
24 |
25 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/transfer_ds/utils.py:
--------------------------------------------------------------------------------
1 | from . import generic_dataset
2 | from . import food_101
3 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/ytbb-robust_metadata/class_idx_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "7": 1,
3 | "8": 1,
4 | "9": 1,
5 | "10": 1,
6 | "11": 1,
7 | "12": 1,
8 | "13": 1,
9 | "14": 1,
10 | "15": 1,
11 | "16": 1,
12 | "17": 1,
13 | "18": 1,
14 | "19": 1,
15 | "20": 1,
16 | "21": 1,
17 | "22": 1,
18 | "23": 1,
19 | "24": 1,
20 | "80": 1,
21 | "81": 1,
22 | "82": 1,
23 | "83": 1,
24 | "84": 1,
25 | "85": 1,
26 | "86": 1,
27 | "87": 1,
28 | "88": 1,
29 | "89": 1,
30 | "90": 1,
31 | "91": 1,
32 | "92": 1,
33 | "93": 1,
34 | "94": 1,
35 | "95": 1,
36 | "96": 1,
37 | "97": 1,
38 | "98": 1,
39 | "99": 1,
40 | "100": 1,
41 | "127": 1,
42 | "128": 1,
43 | "129": 1,
44 | "130": 1,
45 | "131": 1,
46 | "132": 1,
47 | "133": 1,
48 | "134": 1,
49 | "135": 1,
50 | "136": 1,
51 | "137": 1,
52 | "138": 1,
53 | "139": 1,
54 | "140": 1,
55 | "141": 1,
56 | "142": 1,
57 | "143": 1,
58 | "144": 1,
59 | "145": 1,
60 | "146": 1,
61 | "151": 19,
62 | "152": 19,
63 | "153": 19,
64 | "154": 19,
65 | "155": 19,
66 | "156": 19,
67 | "157": 19,
68 | "158": 19,
69 | "159": 19,
70 | "160": 19,
71 | "161": 19,
72 | "162": 19,
73 | "163": 19,
74 | "164": 19,
75 | "165": 19,
76 | "166": 19,
77 | "167": 19,
78 | "168": 19,
79 | "169": 19,
80 | "170": 19,
81 | "171": 19,
82 | "172": 19,
83 | "173": 19,
84 | "174": 19,
85 | "175": 19,
86 | "176": 19,
87 | "177": 19,
88 | "178": 19,
89 | "179": 19,
90 | "180": 19,
91 | "181": 19,
92 | "182": 19,
93 | "183": 19,
94 | "184": 19,
95 | "185": 19,
96 | "186": 19,
97 | "187": 19,
98 | "188": 19,
99 | "189": 19,
100 | "190": 19,
101 | "191": 19,
102 | "192": 19,
103 | "193": 19,
104 | "194": 19,
105 | "195": 19,
106 | "196": 19,
107 | "197": 19,
108 | "198": 19,
109 | "199": 19,
110 | "200": 19,
111 | "201": 19,
112 | "202": 19,
113 | "203": 19,
114 | "204": 19,
115 | "205": 19,
116 | "206": 19,
117 | "207": 19,
118 | "208": 19,
119 | "209": 19,
120 | "210": 19,
121 | "211": 19,
122 | "212": 19,
123 | "213": 19,
124 | "214": 19,
125 | "215": 19,
126 | "216": 19,
127 | "217": 19,
128 | "218": 19,
129 | "219": 19,
130 | "220": 19,
131 | "221": 19,
132 | "222": 19,
133 | "223": 19,
134 | "224": 19,
135 | "225": 19,
136 | "226": 19,
137 | "227": 19,
138 | "228": 19,
139 | "229": 19,
140 | "230": 19,
141 | "231": 19,
142 | "232": 19,
143 | "233": 19,
144 | "234": 19,
145 | "235": 19,
146 | "236": 19,
147 | "237": 19,
148 | "238": 19,
149 | "239": 19,
150 | "240": 19,
151 | "241": 19,
152 | "242": 19,
153 | "243": 19,
154 | "244": 19,
155 | "245": 19,
156 | "246": 19,
157 | "247": 19,
158 | "248": 19,
159 | "249": 19,
160 | "250": 19,
161 | "251": 19,
162 | "252": 19,
163 | "253": 19,
164 | "254": 19,
165 | "255": 19,
166 | "256": 19,
167 | "257": 19,
168 | "258": 19,
169 | "259": 19,
170 | "260": 19,
171 | "261": 19,
172 | "262": 19,
173 | "263": 19,
174 | "264": 19,
175 | "265": 19,
176 | "266": 19,
177 | "267": 19,
178 | "268": 19,
179 | "281": 7,
180 | "282": 7,
181 | "283": 7,
182 | "284": 7,
183 | "285": 7,
184 | "286": 7,
185 | "287": 7,
186 | "294": 5,
187 | "295": 5,
188 | "296": 5,
189 | "297": 5,
190 | "339": 10,
191 | "340": 17,
192 | "385": 20,
193 | "386": 20,
194 | "404": 13,
195 | "407": 23,
196 | "436": 23,
197 | "444": 2,
198 | "466": 15,
199 | "468": 23,
200 | "472": 3,
201 | "499": 12,
202 | "511": 23,
203 | "554": 3,
204 | "555": 16,
205 | "569": 16,
206 | "576": 3,
207 | "609": 23,
208 | "623": 12,
209 | "625": 3,
210 | "627": 23,
211 | "654": 4,
212 | "656": 23,
213 | "661": 23,
214 | "665": 11,
215 | "671": 2,
216 | "675": 16,
217 | "717": 16,
218 | "734": 16,
219 | "751": 23,
220 | "779": 4,
221 | "814": 3,
222 | "817": 23,
223 | "864": 16,
224 | "867": 16,
225 | "874": 4,
226 | "879": 21,
227 | "914": 3,
228 | "981": 0,
229 | "982": 0,
230 | "983": 0,
231 | "985": 9,
232 | "986": 9
233 | }
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/ytbb-robust_metadata/rev_class_idx_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": [
3 | 981,
4 | 982,
5 | 983
6 | ],
7 | "1": [
8 | 7,
9 | 8,
10 | 9,
11 | 10,
12 | 11,
13 | 12,
14 | 13,
15 | 14,
16 | 15,
17 | 16,
18 | 17,
19 | 18,
20 | 19,
21 | 20,
22 | 21,
23 | 22,
24 | 23,
25 | 24,
26 | 80,
27 | 81,
28 | 82,
29 | 83,
30 | 84,
31 | 85,
32 | 86,
33 | 87,
34 | 88,
35 | 89,
36 | 90,
37 | 91,
38 | 92,
39 | 93,
40 | 94,
41 | 95,
42 | 96,
43 | 97,
44 | 98,
45 | 99,
46 | 100,
47 | 127,
48 | 128,
49 | 129,
50 | 130,
51 | 131,
52 | 132,
53 | 133,
54 | 134,
55 | 135,
56 | 136,
57 | 137,
58 | 138,
59 | 139,
60 | 140,
61 | 141,
62 | 142,
63 | 143,
64 | 144,
65 | 145,
66 | 146
67 | ],
68 | "2": [
69 | 444,
70 | 671
71 | ],
72 | "3": [
73 | 472,
74 | 554,
75 | 576,
76 | 625,
77 | 814,
78 | 914
79 | ],
80 | "4": [
81 | 654,
82 | 779,
83 | 874
84 | ],
85 | "5": [
86 | 294,
87 | 295,
88 | 296,
89 | 297
90 | ],
91 | "7": [
92 | 281,
93 | 282,
94 | 283,
95 | 284,
96 | 285,
97 | 286,
98 | 287
99 | ],
100 | "9": [
101 | 985,
102 | 986
103 | ],
104 | "10": [
105 | 339
106 | ],
107 | "11": [
108 | 665
109 | ],
110 | "12": [
111 | 499,
112 | 623
113 | ],
114 | "13": [
115 | 404
116 | ],
117 | "15": [
118 | 466
119 | ],
120 | "16": [
121 | 555,
122 | 569,
123 | 675,
124 | 717,
125 | 734,
126 | 864,
127 | 867
128 | ],
129 | "17": [
130 | 340
131 | ],
132 | "19": [
133 | 151,
134 | 152,
135 | 153,
136 | 154,
137 | 155,
138 | 156,
139 | 157,
140 | 158,
141 | 159,
142 | 160,
143 | 161,
144 | 162,
145 | 163,
146 | 164,
147 | 165,
148 | 166,
149 | 167,
150 | 168,
151 | 169,
152 | 170,
153 | 171,
154 | 172,
155 | 173,
156 | 174,
157 | 175,
158 | 176,
159 | 177,
160 | 178,
161 | 179,
162 | 180,
163 | 181,
164 | 182,
165 | 183,
166 | 184,
167 | 185,
168 | 186,
169 | 187,
170 | 188,
171 | 189,
172 | 190,
173 | 191,
174 | 192,
175 | 193,
176 | 194,
177 | 195,
178 | 196,
179 | 197,
180 | 198,
181 | 199,
182 | 200,
183 | 201,
184 | 202,
185 | 203,
186 | 204,
187 | 205,
188 | 206,
189 | 207,
190 | 208,
191 | 209,
192 | 210,
193 | 211,
194 | 212,
195 | 213,
196 | 214,
197 | 215,
198 | 216,
199 | 217,
200 | 218,
201 | 219,
202 | 220,
203 | 221,
204 | 222,
205 | 223,
206 | 224,
207 | 225,
208 | 226,
209 | 227,
210 | 228,
211 | 229,
212 | 230,
213 | 231,
214 | 232,
215 | 233,
216 | 234,
217 | 235,
218 | 236,
219 | 237,
220 | 238,
221 | 239,
222 | 240,
223 | 241,
224 | 242,
225 | 243,
226 | 244,
227 | 245,
228 | 246,
229 | 247,
230 | 248,
231 | 249,
232 | 250,
233 | 251,
234 | 252,
235 | 253,
236 | 254,
237 | 255,
238 | 256,
239 | 257,
240 | 258,
241 | 259,
242 | 260,
243 | 261,
244 | 262,
245 | 263,
246 | 264,
247 | 265,
248 | 266,
249 | 267,
250 | 268
251 | ],
252 | "20": [
253 | 385,
254 | 386
255 | ],
256 | "21": [
257 | 879
258 | ],
259 | "23": [
260 | 407,
261 | 436,
262 | 468,
263 | 511,
264 | 609,
265 | 627,
266 | 656,
267 | 661,
268 | 751,
269 | 817
270 | ]
271 | }
--------------------------------------------------------------------------------
/src/classfier_tuning/src/datasets/ytbb-robust_metadata/ytbb_class_index.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "person",
3 | "1": "bird",
4 | "2": "bicycle",
5 | "3": "boat",
6 | "4": "bus",
7 | "5": "bear",
8 | "6": "cow",
9 | "7": "cat",
10 | "8": "giraffe",
11 | "9": "potted plant",
12 | "10": "horse",
13 | "11": "motorcycle",
14 | "12": "knife",
15 | "13": "airplane",
16 | "14": "skateboard",
17 | "15": "train",
18 | "16": "truck",
19 | "17": "zebra",
20 | "18": "toilet",
21 | "19": "dog",
22 | "20": "elephant",
23 | "21": "umbrella",
24 | "22": "none",
25 | "23": "car"
26 | }
27 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/eurosat_text_feature.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/src/eurosat_text_feature.pt
--------------------------------------------------------------------------------
/src/classfier_tuning/src/get_classifier_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import torch
6 |
7 |
8 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
9 | # from ..src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
10 |
11 |
12 | # load_path = '/opt/tiger/filter_transfer/src/wise-ft/results/extest.r1/zeroshotEurosat.pt'
13 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/syn_init_tfda_55.68.pt'
14 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/syn_ref16.1_iters50_71.72.pt'
15 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/real_init_16.1_tfda_86.85.pt'
16 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_iters50_88.21.pt'
17 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_iters20_88.86.pt'
18 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_v4_20k_87.86.pt'
19 | load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_4.1_iter50_81.72.pt'
20 | image_classifier = ImageClassifier.load(load_path)
21 |
22 | # import ipdb
23 | # ipdb.set_trace(context=20)
24 |
25 | head = image_classifier.classification_head
26 | weights = head.weight.detach().numpy()
27 |
28 | torch.save(weights, '/opt/tiger/filter_transfer/src/wise-ft/cache/Eurosat/mix_4.1_iter50_81.72_weights.pt')
29 | a=0
30 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/imagenet_text_feature.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/src/imagenet_text_feature.pt
--------------------------------------------------------------------------------
/src/classfier_tuning/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/src/models/__init__.py
--------------------------------------------------------------------------------
/src/classfier_tuning/src/models/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from src.models import utils
8 | from src.datasets.common import get_dataloader, maybe_dictionarize
9 |
10 | import src.datasets as datasets
11 |
12 |
13 | def eval_single_dataset(image_classifier, dataset, args):
14 | if args.freeze_encoder and not args.data_aug_lp:
15 | model = image_classifier.classification_head
16 | input_key = 'features'
17 | image_enc = image_classifier.image_encoder
18 | else:
19 | model = image_classifier
20 | input_key = 'images'
21 | image_enc = None
22 |
23 | model.eval()
24 | dataloader = get_dataloader(
25 | dataset, is_train=False, args=args, image_encoder=image_enc)
26 | batched_data = enumerate(dataloader)
27 | device = args.device
28 |
29 | if hasattr(dataset, 'post_loop_metrics'):
30 | # keep track of labels, predictions and metadata
31 | all_labels, all_preds, all_metadata = [], [], []
32 |
33 | with torch.no_grad():
34 | top1, correct, n = 0., 0., 0.
35 | for i, data in batched_data:
36 | data = maybe_dictionarize(data)
37 | x = data[input_key].to(device)
38 | y = data['labels'].to(device)
39 |
40 | if 'image_paths' in data:
41 | image_paths = data['image_paths']
42 |
43 | logits = utils.get_logits(x, model)
44 | projection_fn = getattr(dataset, 'project_logits', None)
45 | if projection_fn is not None:
46 | logits = projection_fn(logits, device)
47 |
48 | if hasattr(dataset, 'project_labels'):
49 | y = dataset.project_labels(y, device)
50 | pred = logits.argmax(dim=1, keepdim=True).to(device)
51 | if hasattr(dataset, 'accuracy'):
52 | acc1, num_total = dataset.accuracy(logits, y, image_paths, args)
53 | correct += acc1
54 | n += num_total
55 | else:
56 | correct += pred.eq(y.view_as(pred)).sum().item()
57 | n += y.size(0)
58 |
59 | if hasattr(dataset, 'post_loop_metrics'):
60 | all_labels.append(y.cpu().clone().detach())
61 | all_preds.append(logits.cpu().clone().detach())
62 | metadata = data['metadata'] if 'metadata' in data else image_paths
63 | all_metadata.extend(metadata)
64 |
65 | top1 = correct / n
66 |
67 | if hasattr(dataset, 'post_loop_metrics'):
68 | all_labels = torch.cat(all_labels)
69 | all_preds = torch.cat(all_preds)
70 | metrics = dataset.post_loop_metrics(all_labels, all_preds, all_metadata, args)
71 | if 'acc' in metrics:
72 | metrics['top1'] = metrics['acc']
73 | else:
74 | metrics = {}
75 | if 'top1' not in metrics:
76 | metrics['top1'] = top1
77 |
78 | return metrics
79 |
80 | def evaluate(image_classifier, args):
81 | if args.eval_datasets is None:
82 | return
83 | info = vars(args)
84 | for i, dataset_name in enumerate(args.eval_datasets):
85 | print('Evaluating on', dataset_name)
86 | dataset_class = getattr(datasets, dataset_name)
87 | dataset = dataset_class(
88 | image_classifier.val_preprocess,
89 | location=args.data_location,
90 | batch_size=args.batch_size
91 | )
92 |
93 | results = eval_single_dataset(image_classifier, dataset, args)
94 |
95 | if 'top1' in results:
96 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}")
97 | for key, val in results.items():
98 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key:
99 | print(f"{dataset_name} {key}: {val:.4f}")
100 | info[dataset_name + ':' + key] = val
101 |
102 | if args.results_db is not None:
103 | dirname = os.path.dirname(args.results_db)
104 | if dirname:
105 | os.makedirs(dirname, exist_ok=True)
106 | with open(args.results_db, 'a+') as f:
107 | f.write(json.dumps(info) + '\n')
108 | print(f'Results saved to {args.results_db}.')
109 | else:
110 | print('Results not saved (to do so, use --results_db to specify a path).')
111 |
112 | return info
--------------------------------------------------------------------------------
/src/classfier_tuning/src/models/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 |
4 | import clip.clip as clip
5 |
6 | from src.models import utils
7 |
8 |
9 | class ImageEncoder(torch.nn.Module):
10 | def __init__(self, args, keep_lang=False):
11 | super().__init__()
12 |
13 | self.model, self.train_preprocess = clip.load(
14 | args.model, args.device, jit=False)
15 | self.val_preprocess = copy.deepcopy(self.train_preprocess)
16 |
17 | self.cache_dir = args.cache_dir
18 |
19 | if not keep_lang and hasattr(self.model, 'transformer'):
20 | delattr(self.model, 'transformer')
21 |
22 | def forward(self, images):
23 | assert self.model is not None
24 | return self.model.encode_image(images)
25 |
26 | def save(self, filename):
27 | print(f'Saving image encoder to {filename}')
28 | utils.torch_save(self, filename)
29 |
30 | @classmethod
31 | def load(cls, filename):
32 | print(f'Loading image encoder from {filename}')
33 | return utils.torch_load(filename)
34 |
35 |
36 | class ClassificationHead(torch.nn.Linear):
37 | def __init__(self, normalize, weights, biases=None):
38 | output_size, input_size = weights.shape
39 | super().__init__(input_size, output_size)
40 | self.normalize = normalize
41 | if weights is not None:
42 | self.weight = torch.nn.Parameter(weights.clone())
43 | if biases is not None:
44 | self.bias = torch.nn.Parameter(biases.clone())
45 | else:
46 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
47 |
48 | def forward(self, inputs):
49 | if self.normalize:
50 | inputs = inputs / inputs.norm(dim=-1, keepdim=True)
51 | return super().forward(inputs)
52 |
53 | def save(self, filename):
54 | print(f'Saving classification head to {filename}')
55 | utils.torch_save(self, filename)
56 |
57 | @classmethod
58 | def load(cls, filename):
59 | print(f'Loading classification head from {filename}')
60 | return utils.torch_load(filename)
61 |
62 |
63 | class ImageClassifier(torch.nn.Module):
64 | def __init__(self, image_encoder, classification_head, process_images=True):
65 | super().__init__()
66 | self.image_encoder = image_encoder
67 | self.classification_head = classification_head
68 | self.process_images = process_images
69 | if self.image_encoder is not None:
70 | self.train_preprocess = self.image_encoder.train_preprocess
71 | self.val_preprocess = self.image_encoder.val_preprocess
72 |
73 | def forward(self, inputs):
74 | if self.process_images:
75 | inputs = self.image_encoder(inputs)
76 | outputs = self.classification_head(inputs)
77 | return outputs
78 |
79 | def save(self, filename):
80 | print(f'Saving image classifier to {filename}')
81 | utils.torch_save(self, filename)
82 |
83 | @classmethod
84 | def load(cls, filename):
85 | print(f'Loading image classifier from {filename}')
86 | return utils.torch_load(filename)
87 |
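Note (illustrative, not part of the repository): a sketch of how these pieces fit together. The args fields mirror what the training scripts pass in; the 512-dim random head assumes the ViT-B/32 embedding size, and the 10-class head is a toy choice.

import types
import torch

args = types.SimpleNamespace(
    model='ViT-B/32',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    cache_dir=None,
)
encoder = ImageEncoder(args, keep_lang=False)
head = ClassificationHead(normalize=True, weights=torch.randn(10, 512))  # 10 classes, 512-dim features
classifier = ImageClassifier(encoder, head)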
--------------------------------------------------------------------------------
/src/classfier_tuning/src/models/zeroshot.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | import numpy as np
7 |
8 | import clip.clip as clip
9 |
10 | import src.templates as templates
11 | import src.datasets as datasets
12 |
13 | from src.args import parse_arguments
14 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
15 | from src.models.eval import evaluate
16 |
17 |
18 | def get_classnames_zs(args, clip_model):
19 | assert args.template is not None
20 | assert args.train_dataset is not None
21 | template = getattr(templates, args.template)
22 | logit_scale = clip_model.logit_scale
23 | dataset_class = getattr(datasets, args.train_dataset)
24 | dataset = dataset_class(
25 | None,
26 | location=args.data_location,
27 | batch_size=args.batch_size,
28 | classnames=args.classnames
29 | )
30 | classnames = dataset.classnames
31 | return classnames
32 |
33 |
34 |
35 | def get_zeroshot_classifier(args, clip_model):
36 | assert args.template is not None
37 | assert args.train_dataset is not None
38 | template = getattr(templates, args.template)
39 | print('Template', template)
40 | logit_scale = clip_model.logit_scale
41 | dataset_class = getattr(datasets, args.train_dataset)
42 | dataset = dataset_class(
43 | None,
44 | location=args.data_location,
45 | batch_size=args.batch_size,
46 | classnames=args.classnames
47 | )
48 | device = args.device
49 | clip_model.eval()
50 | clip_model.to(device)
51 |
52 | print('Getting zeroshot weights.')
53 | with torch.no_grad():
54 | zeroshot_weights = []
55 |
56 | for classname in tqdm(dataset.classnames):
57 | texts = []
58 | for t in template:
59 | texts.append(t(classname))
60 | texts = clip.tokenize(texts).to(device) # tokenize
61 | embeddings = clip_model.encode_text(texts) # embed with text encoder
62 | embeddings /= embeddings.norm(dim=-1, keepdim=True)
63 |
64 | embeddings = embeddings.mean(dim=0, keepdim=True)
65 | embeddings /= embeddings.norm()
66 |
67 | zeroshot_weights.append(embeddings)
68 |
69 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
70 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
71 |
72 | zeroshot_weights *= logit_scale.exp()
73 |
74 | zeroshot_weights = zeroshot_weights.squeeze().float()
75 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
76 | # import ipdb
77 | # ipdb.set_trace(context=20)
78 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
79 |
80 | return classification_head
81 |
82 |
83 | def eval(args):
84 | args.freeze_encoder = True
85 | if args.load is not None:
86 | classifier = ImageClassifier.load(args.load)
87 | else:
88 | image_encoder = ImageEncoder(args, keep_lang=True)
89 | classification_head = get_zeroshot_classifier(args, image_encoder.model)
90 | delattr(image_encoder.model, 'transformer')
91 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
92 |
93 | evaluate(classifier, args)
94 |
95 | if args.save is not None:
96 | classifier.save(args.save)
97 |
98 |
99 | if __name__ == '__main__':
100 | args = parse_arguments()
101 | eval(args)
--------------------------------------------------------------------------------
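Usage sketch (not part of the repository): how the pieces above are typically wired together, mirroring eval(); the command-line values behind parse_arguments() (model, template, train dataset, data location, device) are assumptions for illustration.

    import torch
    from src.args import parse_arguments
    from src.models.modeling import ImageEncoder, ImageClassifier
    from src.models.zeroshot import get_zeroshot_classifier

    args = parse_arguments()                                    # template, train dataset, model, device set via CLI
    image_encoder = ImageEncoder(args, keep_lang=True)          # CLIP backbone with the text tower kept
    head = get_zeroshot_classifier(args, image_encoder.model)   # prompt-ensembled text embeddings as weights
    classifier = ImageClassifier(image_encoder, head, process_images=True).to(args.device)

    images = torch.randn(4, 3, 224, 224).to(args.device)        # stand-in for a preprocessed image batch
    logits = classifier(images)                                 # shape (4, num_classes)
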
/src/classfier_tuning/src/models/zeroshot_retina.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | import numpy as np
7 |
8 | import clip.clip as clip
9 |
10 | import src.templates as templates
11 | import src.datasets as datasets
12 |
13 | from src.args import parse_arguments
14 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier
15 | from src.models.eval import evaluate
16 |
17 |
18 | def get_zeroshot_classifier(args, clip_model, classnames=None):
19 | assert args.template is not None
20 | assert args.train_dataset is not None
21 |     assert classnames is not None, 'classnames (a list of retina disease names) must be provided'
22 | template = [lambda c: f"a photo of retina with disease {c}."]
23 | logit_scale = clip_model.logit_scale
24 |
25 | device = args.device
26 | clip_model.eval()
27 | clip_model.to(device)
28 |
29 | print('Getting zeroshot weights.')
30 | with torch.no_grad():
31 | zeroshot_weights = []
32 |
33 | for classname in tqdm(classnames):
34 | texts = []
35 | for t in template:
36 | texts.append(t(classname))
37 | texts = clip.tokenize(texts).to(device) # tokenize
38 | embeddings = clip_model.encode_text(texts) # embed with text encoder
39 | embeddings /= embeddings.norm(dim=-1, keepdim=True)
40 |
41 | embeddings = embeddings.mean(dim=0, keepdim=True)
42 | embeddings /= embeddings.norm()
43 |
44 | zeroshot_weights.append(embeddings)
45 |
46 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
47 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
48 |
49 | zeroshot_weights *= logit_scale.exp()
50 |
51 | zeroshot_weights = zeroshot_weights.squeeze().float()
52 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
53 | # import ipdb
54 | # ipdb.set_trace(context=20)
55 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
56 |
57 | return classification_head
58 |
59 |
60 | def eval(args):
61 | args.freeze_encoder = True
62 | if args.load is not None:
63 | classifier = ImageClassifier.load(args.load)
64 | else:
65 | image_encoder = ImageEncoder(args, keep_lang=True)
66 |         classification_head = get_zeroshot_classifier(args, image_encoder.model)  # NOTE: also pass classnames here; see the assert above
67 | delattr(image_encoder.model, 'transformer')
68 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False)
69 |
70 | evaluate(classifier, args)
71 |
72 | if args.save is not None:
73 | classifier.save(args.save)
74 |
75 |
76 | if __name__ == '__main__':
77 | args = parse_arguments()
78 | eval(args)
--------------------------------------------------------------------------------
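Usage sketch (not part of the repository): unlike zeroshot.py, the retina variant takes the class names directly; the names below are hypothetical placeholders, not the actual dataset labels.

    from src.args import parse_arguments
    from src.models.modeling import ImageEncoder
    from src.models.zeroshot_retina import get_zeroshot_classifier

    args = parse_arguments()
    retina_classes = ['class_a', 'class_b', 'class_c']   # placeholder label names
    image_encoder = ImageEncoder(args, keep_lang=True)
    head = get_zeroshot_classifier(args, image_encoder.model, classnames=retina_classes)
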
/src/classfier_tuning/src/templates/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai_imagenet_template import openai_imagenet_template
2 | from .simple_template import simple_template
3 | from .fmow_template import fmow_template
4 | from .iwildcam_template import iwildcam_template
5 | from .transfer_ds_template import *
--------------------------------------------------------------------------------
/src/classfier_tuning/src/templates/fmow_template.py:
--------------------------------------------------------------------------------
1 | from .utils import append_proper_article, get_plural
2 |
3 | fmow_template = [
4 | lambda c : f"satellite photo of a {c}.",
5 | lambda c : f"aerial photo of a {c}.",
6 | lambda c : f"satellite photo of {append_proper_article(c)}.",
7 | lambda c : f"aerial photo of {append_proper_article(c)}.",
8 | lambda c : f"satellite photo of a {c} in asia.",
9 | lambda c : f"aerial photo of a {c} in asia.",
10 | lambda c : f"satellite photo of a {c} in africa.",
11 | lambda c : f"aerial photo of a {c} in africa.",
12 | lambda c : f"satellite photo of a {c} in the americas.",
13 | lambda c : f"aerial photo of a {c} in the americas.",
14 | lambda c : f"satellite photo of a {c} in europe.",
15 | lambda c : f"aerial photo of a {c} in europe.",
16 | lambda c : f"satellite photo of a {c} in oceania.",
17 | lambda c : f"aerial photo of a {c} in oceania.",
18 | lambda c: f"a photo of a {c}.",
19 | lambda c: f"{c}.",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/classfier_tuning/src/templates/iwildcam_template.py:
--------------------------------------------------------------------------------
1 | from .utils import append_proper_article, get_plural
2 |
3 | iwildcam_template = [
4 | lambda c: f"a photo of {c}.",
5 | lambda c: f"{c} in the wild.",
6 | ]
--------------------------------------------------------------------------------
/src/classfier_tuning/src/templates/openai_imagenet_template.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | openai_imagenet_template = [
5 | lambda c: f'a bad photo of a {c}.',
6 | lambda c: f'a photo of many {c}.',
7 | lambda c: f'a sculpture of a {c}.',
8 | lambda c: f'a photo of the hard to see {c}.',
9 | lambda c: f'a low resolution photo of the {c}.',
10 | lambda c: f'a rendering of a {c}.',
11 | lambda c: f'graffiti of a {c}.',
12 | lambda c: f'a bad photo of the {c}.',
13 | lambda c: f'a cropped photo of the {c}.',
14 | lambda c: f'a tattoo of a {c}.',
15 | lambda c: f'the embroidered {c}.',
16 | lambda c: f'a photo of a hard to see {c}.',
17 | lambda c: f'a bright photo of a {c}.',
18 | lambda c: f'a photo of a clean {c}.',
19 | lambda c: f'a photo of a dirty {c}.',
20 | lambda c: f'a dark photo of the {c}.',
21 | lambda c: f'a drawing of a {c}.',
22 | lambda c: f'a photo of my {c}.',
23 | lambda c: f'the plastic {c}.',
24 | lambda c: f'a photo of the cool {c}.',
25 | lambda c: f'a close-up photo of a {c}.',
26 | lambda c: f'a black and white photo of the {c}.',
27 | lambda c: f'a painting of the {c}.',
28 | lambda c: f'a painting of a {c}.',
29 | lambda c: f'a pixelated photo of the {c}.',
30 | lambda c: f'a sculpture of the {c}.',
31 | lambda c: f'a bright photo of the {c}.',
32 | lambda c: f'a cropped photo of a {c}.',
33 | lambda c: f'a plastic {c}.',
34 | lambda c: f'a photo of the dirty {c}.',
35 | lambda c: f'a jpeg corrupted photo of a {c}.',
36 | lambda c: f'a blurry photo of the {c}.',
37 | lambda c: f'a photo of the {c}.',
38 | lambda c: f'a good photo of the {c}.',
39 | lambda c: f'a rendering of the {c}.',
40 | lambda c: f'a {c} in a video game.',
41 | lambda c: f'a photo of one {c}.',
42 | lambda c: f'a doodle of a {c}.',
43 | lambda c: f'a close-up photo of the {c}.',
44 | lambda c: f'a photo of a {c}.',
45 | lambda c: f'the origami {c}.',
46 | lambda c: f'the {c} in a video game.',
47 | lambda c: f'a sketch of a {c}.',
48 | lambda c: f'a doodle of the {c}.',
49 | lambda c: f'a origami {c}.',
50 | lambda c: f'a low resolution photo of a {c}.',
51 | lambda c: f'the toy {c}.',
52 | lambda c: f'a rendition of the {c}.',
53 | lambda c: f'a photo of the clean {c}.',
54 | lambda c: f'a photo of a large {c}.',
55 | lambda c: f'a rendition of a {c}.',
56 | lambda c: f'a photo of a nice {c}.',
57 | lambda c: f'a photo of a weird {c}.',
58 | lambda c: f'a blurry photo of a {c}.',
59 | lambda c: f'a cartoon {c}.',
60 | lambda c: f'art of a {c}.',
61 | lambda c: f'a sketch of the {c}.',
62 | lambda c: f'a embroidered {c}.',
63 | lambda c: f'a pixelated photo of a {c}.',
64 | lambda c: f'itap of the {c}.',
65 | lambda c: f'a jpeg corrupted photo of the {c}.',
66 | lambda c: f'a good photo of a {c}.',
67 | lambda c: f'a plushie {c}.',
68 | lambda c: f'a photo of the nice {c}.',
69 | lambda c: f'a photo of the small {c}.',
70 | lambda c: f'a photo of the weird {c}.',
71 | lambda c: f'the cartoon {c}.',
72 | lambda c: f'art of the {c}.',
73 | lambda c: f'a drawing of the {c}.',
74 | lambda c: f'a photo of the large {c}.',
75 | lambda c: f'a black and white photo of a {c}.',
76 | lambda c: f'the plushie {c}.',
77 | lambda c: f'a dark photo of a {c}.',
78 | lambda c: f'itap of a {c}.',
79 | lambda c: f'graffiti of the {c}.',
80 | lambda c: f'a toy {c}.',
81 | lambda c: f'itap of my {c}.',
82 | lambda c: f'a photo of a cool {c}.',
83 | lambda c: f'a photo of a small {c}.',
84 | lambda c: f'a tattoo of the {c}.',
85 | ]
--------------------------------------------------------------------------------
/src/classfier_tuning/src/templates/simple_template.py:
--------------------------------------------------------------------------------
1 | from src.templates.utils import append_proper_article
2 |
3 | simple_template = [
4 | lambda c: f"a photo of a {c}."
5 | # lambda c: f"a sketch of a {c}."
6 | ]
--------------------------------------------------------------------------------
/src/classfier_tuning/src/templates/transfer_ds_template.py:
--------------------------------------------------------------------------------
1 | from src.templates.utils import append_proper_article
2 |
3 | aircraft_template = [
4 | lambda c: f"a photo of a {c}, a type of aircraft.",
5 | # lambda c: f"a photo of the {c}, a type of aircraft."
6 | ]
7 |
8 | birds_template = [
9 | lambda c: f"a photo of a {c}, a type of bird."
10 | ]
11 |
12 | eurosat_template = [
13 | lambda c: f"a centered satellite photo of {c}."
14 | ]
15 |
16 | # eurosat_template = [
17 | # lambda c: f"a centered satellite photo of {c}.",
18 | # lambda c: f'a centered satellite photo of a {c}.',
19 | # lambda c: f'a centered satellite photo of the {c}.',
20 | # ]
21 |
22 |
23 | flowers_template = [
24 | lambda c: f"a photo of a {c}, a type of flower."
25 | ]
26 |
27 | food_template = [
28 | lambda c: f"a photo of a {c}, a type of food."
29 | ]
30 |
31 | pets_template = [
32 | lambda c: f"a photo of a {c}, a type of pet."
33 | ]
34 |
35 | imagenet_template = [
36 | lambda c: f"itap of a {c}.",
37 | lambda c: f"a bad photo of the {c}.",
38 | lambda c: f"a origami {c}.",
39 | lambda c: f"a photo of the large {c}.",
40 | lambda c: f"a {c} in a video game.",
41 | lambda c: f"art of the {c}.",
42 | lambda c: f"a photo of the small {c}."]
43 |
44 | cifar100_template = [
45 | lambda c: f'a photo of a {c}.',
46 | lambda c: f'a blurry photo of a {c}.',
47 | lambda c: f'a black and white photo of a {c}.',
48 | lambda c: f'a low contrast photo of a {c}.',
49 | lambda c: f'a high contrast photo of a {c}.',
50 | lambda c: f'a bad photo of a {c}.',
51 | lambda c: f'a good photo of a {c}.',
52 | lambda c: f'a photo of a small {c}.',
53 | lambda c: f'a photo of a big {c}.',
54 | lambda c: f'a photo of the {c}.',
55 | lambda c: f'a blurry photo of the {c}.',
56 | lambda c: f'a black and white photo of the {c}.',
57 | lambda c: f'a low contrast photo of the {c}.',
58 | lambda c: f'a high contrast photo of the {c}.',
59 | lambda c: f'a bad photo of the {c}.',
60 | lambda c: f'a good photo of the {c}.',
61 | lambda c: f'a photo of the small {c}.',
62 | lambda c: f'a photo of the big {c}.',
63 | ]
64 |
65 | cifar10_templates = [
66 | lambda c: f'a photo of a {c}.',
67 | lambda c: f'a blurry photo of a {c}.',
68 | lambda c: f'a black and white photo of a {c}.',
69 | lambda c: f'a low contrast photo of a {c}.',
70 | lambda c: f'a high contrast photo of a {c}.',
71 | lambda c: f'a bad photo of a {c}.',
72 | lambda c: f'a good photo of a {c}.',
73 | lambda c: f'a photo of a small {c}.',
74 | lambda c: f'a photo of a big {c}.',
75 | lambda c: f'a photo of the {c}.',
76 | lambda c: f'a blurry photo of the {c}.',
77 | lambda c: f'a black and white photo of the {c}.',
78 | lambda c: f'a low contrast photo of the {c}.',
79 | lambda c: f'a high contrast photo of the {c}.',
80 | lambda c: f'a bad photo of the {c}.',
81 | lambda c: f'a good photo of the {c}.',
82 | lambda c: f'a photo of the small {c}.',
83 | lambda c: f'a photo of the big {c}.',
84 | ]
85 |
86 | sun_template = [
87 | lambda c: f'a photo of a {c}.',
88 | lambda c: f'a photo of the {c}.',
89 | ]
90 |
91 | cars_template = [
92 | lambda c: f'a photo of a {c}.',
93 | lambda c: f'a photo of the {c}.',
94 | lambda c: f'a photo of my {c}.',
95 | lambda c: f'i love my {c}!',
96 | lambda c: f'a photo of my dirty {c}.',
97 | lambda c: f'a photo of my clean {c}.',
98 | lambda c: f'a photo of my new {c}.',
99 | lambda c: f'a photo of my old {c}.',
100 | ]
101 |
102 | dtd_template = [
103 | lambda c: f"{c} texture."
104 | ]
105 |
106 | # dtd_template = [
107 | # lambda c: f'a photo of a {c} texture.',
108 | # lambda c: f'a photo of a {c} pattern.',
109 | # lambda c: f'a photo of a {c} thing.',
110 | # lambda c: f'a photo of a {c} object.',
111 | # lambda c: f'a photo of the {c} texture.',
112 | # lambda c: f'a photo of the {c} pattern.',
113 | # lambda c: f'a photo of the {c} thing.',
114 | # lambda c: f'a photo of the {c} object.',
115 | # ]
--------------------------------------------------------------------------------
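Note (illustration, not part of the repository): every template above is a list of callables mapping a class name to a prompt string; the zero-shot builders apply each one and average the resulting text embeddings. For example:

    from src.templates.transfer_ds_template import eurosat_template, cifar10_templates

    [t('forest') for t in eurosat_template]
    # -> ['a centered satellite photo of forest.']

    [t('airplane') for t in cifar10_templates[:3]]
    # -> ['a photo of a airplane.', 'a blurry photo of a airplane.',
    #     'a black and white photo of a airplane.']
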
/src/classfier_tuning/src/templates/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def get_plural(name):
3 | name = name.replace('_', ' ')
4 | if name[-2:] == 'sh':
5 | name = name + 'es'
6 | elif name[-2:] == 'ch':
7 | name = name + 'es'
8 | elif name[-1:] == 'y':
9 | name = name[:-1] + 'ies'
10 | elif name[-1:] == 's':
11 | name = name + 'es'
12 | elif name[-1:] == 'x':
13 | name = name + 'es'
14 | elif name[-3:] == 'man':
15 | name = name[:-3] + 'men'
16 | elif name == 'mouse':
17 | name = 'mice'
18 | elif name[-1:] == 'f':
19 | name = name[:-1] + 'ves'
20 | else:
21 | name = name + 's'
22 | return name
23 |
24 |
25 | def append_proper_article(name):
26 | name = name.replace('_', ' ')
27 | if name[0] in 'aeiou':
28 | return 'an ' + name
29 | return 'a ' + name
30 |
--------------------------------------------------------------------------------
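Note (illustration, not part of the repository): both helpers are simple string heuristics used when composing prompts. A few representative outputs:

    from src.templates.utils import get_plural, append_proper_article

    get_plural('bus')               # 'buses'     (ends in 's')
    get_plural('daisy')             # 'daisies'   (ends in 'y')
    get_plural('fisherman')         # 'fishermen' (ends in 'man')
    get_plural('dog')               # 'dogs'      (default: append 's')
    append_proper_article('apple')  # 'an apple'
    append_proper_article('dog')    # 'a dog'
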
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from absl import logging
4 |
5 | import tqdm
6 | import numpy as np
7 | import tensorflow as tf
8 | import tensorflow_datasets as tfds
9 |
10 |
11 | def center_crop(x, resolution):
12 | shape = tf.shape(x)
13 | h, w = shape[0], shape[1]
14 | size = tf.minimum(h, w)
15 | begin = tf.cast([h - size, w - size], tf.float32) / 2.0
16 | begin = tf.cast(begin, tf.int32)
17 | begin = tf.concat([begin, [0]], axis=0) # Add channel dimension.
18 | x = tf.slice(x, begin, [size, size, 3])
19 | x = tf.image.resize_with_pad(
20 | x, resolution, resolution, method='area', antialias=True)
21 | return x
22 |
23 |
24 | def load_data(ds, img_shape, resolution=32):
25 | if resolution <= 64:
26 | batch_size = 5000
27 | else:
28 | batch_size = 1000
29 |
30 | size = len(ds)
31 | logging.info('Dataset size: {}'.format(size))
32 | if None in img_shape:
33 | x = np.zeros(shape=(size, resolution, resolution, 3), dtype=np.uint8)
34 | else:
35 | x = np.zeros(
36 | shape=(size, img_shape[0], img_shape[1], img_shape[2]), dtype=np.uint8)
37 |
38 | if None in img_shape:
39 | ds = ds.map(lambda x, y: (center_crop(
40 | x, resolution), y), tf.data.AUTOTUNE)
41 | ds = ds.batch(batch_size=batch_size)
42 | ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
43 |
44 | y_list = []
45 | count = 0
46 | for x_batch, y_batch in tqdm.tqdm(tfds.as_numpy(ds), desc='Process the data'):
47 | num = x_batch.shape[0]
48 | x_processed = np.array(x_batch)
49 | x[count:count + num] = x_processed
50 | y_list.append(y_batch)
51 | count += num
52 |
53 | return x, np.concatenate(y_list, axis=0)
54 |
55 |
56 | def configure_dataloader(ds, batch_size, x_transform=None, y_transform=None, train=False, shuffle=False, seed=0, resolution=None):
57 |     if y_transform is None:
58 |         # Default: leave the labels unchanged.
59 |         def y_transform(y):
60 |             return y
61 |
62 | ds = ds.cache()
63 | if train:
64 | ds = ds.repeat()
65 | if shuffle:
66 | ds = ds.shuffle(16 * batch_size, seed=seed)
67 |
68 | if resolution is not None:
69 | ds = ds.map(lambda x, y: (tf.clip_by_value(
70 | tf.image.resize(x, [resolution, resolution], 'bilinear'),
71 | 0, 255), y), tf.data.AUTOTUNE)
72 |
73 | ds = ds.map(lambda x, y: (tf.cast(x, tf.float32), y), tf.data.AUTOTUNE)
74 |
75 | if x_transform:
76 | ds = ds.map(lambda x, y: (x_transform(
77 | x), y_transform(y)), tf.data.AUTOTUNE)
78 | else:
79 |         ds = ds.map(lambda x, y: (x, y_transform(y)), tf.data.AUTOTUNE)
80 |
81 | ds = ds.batch(batch_size=batch_size)
82 | ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
83 | return ds
84 |
85 |
86 | def get_dataset(config, return_raw=False, resolution=None, train_only=True):
87 | dataset_name = config.name
88 | data_dir = config.data_dir
89 |
90 | if dataset_name in ['imagenette']:
91 | split = ['train', 'validation']
92 | else:
93 | split = ['train', 'test']
94 |
95 | if resolution is None:
96 | if dataset_name in ['cifar10', 'cifar100']:
97 | resolution = 32
98 | elif dataset_name in ['stl10']:
99 | resolution = 96
100 | elif dataset_name in ['imagenette']:
101 | resolution = 256
102 |
103 | ds_builder = tfds.builder(dataset_name, data_dir=data_dir)
104 |
105 | ds_builder.download_and_prepare()
106 |
107 | img_shape = ds_builder.info.features['image'].shape
108 | num_train, num_test = ds_builder.info.splits[split[0]
109 | ].num_examples, ds_builder.info.splits[split[1]].num_examples
110 | num_classes, class_names = ds_builder.info.features[
111 | 'label'].num_classes, ds_builder.info.features['label'].names
112 |
113 | ds_train, ds_test = ds_builder.as_dataset(split=split, as_supervised=True)
114 |
115 | print('Number of training samples: {}'.format(num_train))
116 | print('Number of test samples: {}'.format(num_test))
117 | sys.stdout.flush()
118 |
119 | with config.unlocked():
120 | config.img_shape = (resolution, resolution,
121 | 3) if None in img_shape else img_shape
122 | config.num_classes = num_classes
123 | config.class_names = class_names
124 | config.train_size = num_train
125 | config.test_size = num_test
126 |
127 | x_train, y_train = load_data(ds_train, img_shape, resolution)
128 |
129 | if train_only:
130 | return x_train, y_train
131 |
132 | x_test, y_test = load_data(ds_test, img_shape, resolution)
133 |
134 | if return_raw:
135 | return x_train, y_train, x_test, y_test
136 | else:
137 | ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
138 | ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
139 | return ds_train, ds_test
--------------------------------------------------------------------------------
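Usage sketch (not part of the repository): get_dataset() writes the resolved shapes and class names back into the config and returns raw numpy arrays, which configure_dataloader() turns into a batched tf.data pipeline. The config library (ml_collections) and the paths below are assumptions suggested by config.unlocked() above.

    import ml_collections
    import tensorflow as tf
    from dataset import get_dataset, configure_dataloader   # assuming src/ is the working directory

    config = ml_collections.ConfigDict({'name': 'cifar10', 'data_dir': '/tmp/tfds'})
    x_train, y_train = get_dataset(config, train_only=True)      # uint8 images, integer labels

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_ds = configure_dataloader(
        ds, batch_size=128,
        x_transform=lambda x: x / 255.0,                         # example pixel scaling
        train=True, shuffle=True, seed=0)                        # train=True repeats indefinitely
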
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .dpn import *
3 | from .lenet import *
4 | from .senet import *
5 | from .pnasnet import *
6 | from .densenet import *
7 | from .googlenet import *
8 | from .shufflenet import *
9 | from .shufflenetv2 import *
10 | from .resnet import *
11 | from .resnext import *
12 | from .preact_resnet import *
13 | from .mobilenet import *
14 | from .mobilenetv2 import *
15 | from .efficientnet import *
16 | from .regnet import *
17 | from .dla_simple import *
18 | from .dla import *
19 |
--------------------------------------------------------------------------------
/src/models/densenet.py:
--------------------------------------------------------------------------------
1 | '''DenseNet in PyTorch.'''
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | def __init__(self, in_planes, growth_rate):
11 | super(Bottleneck, self).__init__()
12 | self.bn1 = nn.BatchNorm2d(in_planes)
13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(4*growth_rate)
15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
16 |
17 | def forward(self, x):
18 | out = self.conv1(F.relu(self.bn1(x)))
19 | out = self.conv2(F.relu(self.bn2(out)))
20 | out = torch.cat([out,x], 1)
21 | return out
22 |
23 |
24 | class Transition(nn.Module):
25 | def __init__(self, in_planes, out_planes):
26 | super(Transition, self).__init__()
27 | self.bn = nn.BatchNorm2d(in_planes)
28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
29 |
30 | def forward(self, x):
31 | out = self.conv(F.relu(self.bn(x)))
32 | out = F.avg_pool2d(out, 2)
33 | return out
34 |
35 |
36 | class DenseNet(nn.Module):
37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
38 | super(DenseNet, self).__init__()
39 | self.growth_rate = growth_rate
40 |
41 | num_planes = 2*growth_rate
42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
43 |
44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
45 | num_planes += nblocks[0]*growth_rate
46 | out_planes = int(math.floor(num_planes*reduction))
47 | self.trans1 = Transition(num_planes, out_planes)
48 | num_planes = out_planes
49 |
50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
51 | num_planes += nblocks[1]*growth_rate
52 | out_planes = int(math.floor(num_planes*reduction))
53 | self.trans2 = Transition(num_planes, out_planes)
54 | num_planes = out_planes
55 |
56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
57 | num_planes += nblocks[2]*growth_rate
58 | out_planes = int(math.floor(num_planes*reduction))
59 | self.trans3 = Transition(num_planes, out_planes)
60 | num_planes = out_planes
61 |
62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
63 | num_planes += nblocks[3]*growth_rate
64 |
65 | self.bn = nn.BatchNorm2d(num_planes)
66 | self.linear = nn.Linear(num_planes, num_classes)
67 |
68 | def _make_dense_layers(self, block, in_planes, nblock):
69 | layers = []
70 | for i in range(nblock):
71 | layers.append(block(in_planes, self.growth_rate))
72 | in_planes += self.growth_rate
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | out = self.conv1(x)
77 | out = self.trans1(self.dense1(out))
78 | out = self.trans2(self.dense2(out))
79 | out = self.trans3(self.dense3(out))
80 | out = self.dense4(out)
81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4)
82 | out = out.view(out.size(0), -1)
83 | out = self.linear(out)
84 | return out
85 |
86 | def DenseNet121():
87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
88 |
89 | def DenseNet169():
90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
91 |
92 | def DenseNet201():
93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
94 |
95 | def DenseNet161():
96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
97 |
98 | def densenet_cifar():
99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
100 |
101 | def test():
102 | net = densenet_cifar()
103 | x = torch.randn(1,3,32,32)
104 | y = net(x)
105 | print(y)
106 |
107 | # test()
108 |
--------------------------------------------------------------------------------
/src/models/dla.py:
--------------------------------------------------------------------------------
1 | '''DLA in PyTorch.
2 |
3 | Reference:
4 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, in_planes, planes, stride=1):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = nn.Conv2d(
17 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
20 | stride=1, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_planes != self.expansion*planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, self.expansion*planes,
27 | kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | out = F.relu(out)
36 | return out
37 |
38 |
39 | class Root(nn.Module):
40 | def __init__(self, in_channels, out_channels, kernel_size=1):
41 | super(Root, self).__init__()
42 | self.conv = nn.Conv2d(
43 | in_channels, out_channels, kernel_size,
44 | stride=1, padding=(kernel_size - 1) // 2, bias=False)
45 | self.bn = nn.BatchNorm2d(out_channels)
46 |
47 | def forward(self, xs):
48 | x = torch.cat(xs, 1)
49 | out = F.relu(self.bn(self.conv(x)))
50 | return out
51 |
52 |
53 | class Tree(nn.Module):
54 | def __init__(self, block, in_channels, out_channels, level=1, stride=1):
55 | super(Tree, self).__init__()
56 | self.level = level
57 | if level == 1:
58 | self.root = Root(2*out_channels, out_channels)
59 | self.left_node = block(in_channels, out_channels, stride=stride)
60 | self.right_node = block(out_channels, out_channels, stride=1)
61 | else:
62 | self.root = Root((level+2)*out_channels, out_channels)
63 | for i in reversed(range(1, level)):
64 | subtree = Tree(block, in_channels, out_channels,
65 | level=i, stride=stride)
66 | self.__setattr__('level_%d' % i, subtree)
67 | self.prev_root = block(in_channels, out_channels, stride=stride)
68 | self.left_node = block(out_channels, out_channels, stride=1)
69 | self.right_node = block(out_channels, out_channels, stride=1)
70 |
71 | def forward(self, x):
72 | xs = [self.prev_root(x)] if self.level > 1 else []
73 | for i in reversed(range(1, self.level)):
74 | level_i = self.__getattr__('level_%d' % i)
75 | x = level_i(x)
76 | xs.append(x)
77 | x = self.left_node(x)
78 | xs.append(x)
79 | x = self.right_node(x)
80 | xs.append(x)
81 | out = self.root(xs)
82 | return out
83 |
84 |
85 | class DLA(nn.Module):
86 | def __init__(self, block=BasicBlock, num_classes=10):
87 | super(DLA, self).__init__()
88 | self.base = nn.Sequential(
89 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
90 | nn.BatchNorm2d(16),
91 | nn.ReLU(True)
92 | )
93 |
94 | self.layer1 = nn.Sequential(
95 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
96 | nn.BatchNorm2d(16),
97 | nn.ReLU(True)
98 | )
99 |
100 | self.layer2 = nn.Sequential(
101 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
102 | nn.BatchNorm2d(32),
103 | nn.ReLU(True)
104 | )
105 |
106 | self.layer3 = Tree(block, 32, 64, level=1, stride=1)
107 | self.layer4 = Tree(block, 64, 128, level=2, stride=2)
108 | self.layer5 = Tree(block, 128, 256, level=2, stride=2)
109 | self.layer6 = Tree(block, 256, 512, level=1, stride=2)
110 | self.linear = nn.Linear(512, num_classes)
111 |
112 | def forward(self, x):
113 | out = self.base(x)
114 | out = self.layer1(out)
115 | out = self.layer2(out)
116 | out = self.layer3(out)
117 | out = self.layer4(out)
118 | out = self.layer5(out)
119 | out = self.layer6(out)
120 | out = F.avg_pool2d(out, 4)
121 | out = out.view(out.size(0), -1)
122 | out = self.linear(out)
123 | return out
124 |
125 |
126 | def test():
127 | net = DLA()
128 | print(net)
129 | x = torch.randn(1, 3, 32, 32)
130 | y = net(x)
131 | print(y.size())
132 |
133 |
134 | if __name__ == '__main__':
135 | test()
136 |
--------------------------------------------------------------------------------
/src/models/dla_simple.py:
--------------------------------------------------------------------------------
1 | '''Simplified version of DLA in PyTorch.
2 |
3 | Note this implementation is not identical to the original paper version,
4 | but it seems to work fine.
5 |
6 | See dla.py for the original paper version.
7 |
8 | Reference:
9 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
10 | '''
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, in_planes, planes, stride=1):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = nn.Conv2d(
22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
25 | stride=1, padding=1, bias=False)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 |
28 | self.shortcut = nn.Sequential()
29 | if stride != 1 or in_planes != self.expansion*planes:
30 | self.shortcut = nn.Sequential(
31 | nn.Conv2d(in_planes, self.expansion*planes,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(self.expansion*planes)
34 | )
35 |
36 | def forward(self, x):
37 | out = F.relu(self.bn1(self.conv1(x)))
38 | out = self.bn2(self.conv2(out))
39 | out += self.shortcut(x)
40 | out = F.relu(out)
41 | return out
42 |
43 |
44 | class Root(nn.Module):
45 | def __init__(self, in_channels, out_channels, kernel_size=1):
46 | super(Root, self).__init__()
47 | self.conv = nn.Conv2d(
48 | in_channels, out_channels, kernel_size,
49 | stride=1, padding=(kernel_size - 1) // 2, bias=False)
50 | self.bn = nn.BatchNorm2d(out_channels)
51 |
52 | def forward(self, xs):
53 | x = torch.cat(xs, 1)
54 | out = F.relu(self.bn(self.conv(x)))
55 | return out
56 |
57 |
58 | class Tree(nn.Module):
59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1):
60 | super(Tree, self).__init__()
61 | self.root = Root(2*out_channels, out_channels)
62 | if level == 1:
63 | self.left_tree = block(in_channels, out_channels, stride=stride)
64 | self.right_tree = block(out_channels, out_channels, stride=1)
65 | else:
66 | self.left_tree = Tree(block, in_channels,
67 | out_channels, level=level-1, stride=stride)
68 | self.right_tree = Tree(block, out_channels,
69 | out_channels, level=level-1, stride=1)
70 |
71 | def forward(self, x):
72 | out1 = self.left_tree(x)
73 | out2 = self.right_tree(out1)
74 | out = self.root([out1, out2])
75 | return out
76 |
77 |
78 | class SimpleDLA(nn.Module):
79 | def __init__(self, block=BasicBlock, num_classes=10):
80 | super(SimpleDLA, self).__init__()
81 | self.base = nn.Sequential(
82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
83 | nn.BatchNorm2d(16),
84 | nn.ReLU(True)
85 | )
86 |
87 | self.layer1 = nn.Sequential(
88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
89 | nn.BatchNorm2d(16),
90 | nn.ReLU(True)
91 | )
92 |
93 | self.layer2 = nn.Sequential(
94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
95 | nn.BatchNorm2d(32),
96 | nn.ReLU(True)
97 | )
98 |
99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1)
100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2)
101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2)
102 | self.layer6 = Tree(block, 256, 512, level=1, stride=2)
103 | self.linear = nn.Linear(512, num_classes)
104 |
105 | def forward(self, x):
106 | out = self.base(x)
107 | out = self.layer1(out)
108 | out = self.layer2(out)
109 | out = self.layer3(out)
110 | out = self.layer4(out)
111 | out = self.layer5(out)
112 | out = self.layer6(out)
113 | out = F.avg_pool2d(out, 4)
114 | out = out.view(out.size(0), -1)
115 | out = self.linear(out)
116 | return out
117 |
118 |
119 | def test():
120 | net = SimpleDLA()
121 | print(net)
122 | x = torch.randn(1, 3, 32, 32)
123 | y = net(x)
124 | print(y.size())
125 |
126 |
127 | if __name__ == '__main__':
128 | test()
129 |
--------------------------------------------------------------------------------
/src/models/dpn.py:
--------------------------------------------------------------------------------
1 | '''Dual Path Networks in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Bottleneck(nn.Module):
8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
9 | super(Bottleneck, self).__init__()
10 | self.out_planes = out_planes
11 | self.dense_depth = dense_depth
12 |
13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(in_planes)
15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
16 | self.bn2 = nn.BatchNorm2d(in_planes)
17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)
19 |
20 | self.shortcut = nn.Sequential()
21 | if first_layer:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(out_planes+dense_depth)
25 | )
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(self.conv1(x)))
29 | out = F.relu(self.bn2(self.conv2(out)))
30 | out = self.bn3(self.conv3(out))
31 | x = self.shortcut(x)
32 | d = self.out_planes
33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
34 | out = F.relu(out)
35 | return out
36 |
37 |
38 | class DPN(nn.Module):
39 | def __init__(self, cfg):
40 | super(DPN, self).__init__()
41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
43 |
44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
45 | self.bn1 = nn.BatchNorm2d(64)
46 | self.last_planes = 64
47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10)
52 |
53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
54 | strides = [stride] + [1]*(num_blocks-1)
55 | layers = []
56 | for i,stride in enumerate(strides):
57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
58 | self.last_planes = out_planes + (i+2) * dense_depth
59 | return nn.Sequential(*layers)
60 |
61 | def forward(self, x):
62 | out = F.relu(self.bn1(self.conv1(x)))
63 | out = self.layer1(out)
64 | out = self.layer2(out)
65 | out = self.layer3(out)
66 | out = self.layer4(out)
67 | out = F.avg_pool2d(out, 4)
68 | out = out.view(out.size(0), -1)
69 | out = self.linear(out)
70 | return out
71 |
72 |
73 | def DPN26():
74 | cfg = {
75 | 'in_planes': (96,192,384,768),
76 | 'out_planes': (256,512,1024,2048),
77 | 'num_blocks': (2,2,2,2),
78 | 'dense_depth': (16,32,24,128)
79 | }
80 | return DPN(cfg)
81 |
82 | def DPN92():
83 | cfg = {
84 | 'in_planes': (96,192,384,768),
85 | 'out_planes': (256,512,1024,2048),
86 | 'num_blocks': (3,4,20,3),
87 | 'dense_depth': (16,32,24,128)
88 | }
89 | return DPN(cfg)
90 |
91 |
92 | def test():
93 | net = DPN92()
94 | x = torch.randn(1,3,32,32)
95 | y = net(x)
96 | print(y)
97 |
98 | # test()
99 |
--------------------------------------------------------------------------------
/src/models/efficientnet.py:
--------------------------------------------------------------------------------
1 | '''EfficientNet in PyTorch.
2 |
3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks".
4 |
5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | def swish(x):
13 | return x * x.sigmoid()
14 |
15 |
16 | def drop_connect(x, drop_ratio):
17 | keep_ratio = 1.0 - drop_ratio
18 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
19 | mask.bernoulli_(keep_ratio)
20 |     x.div_(keep_ratio)  # rescale so the expected activation magnitude is preserved
21 |     x.mul_(mask)        # zero out entire examples (drop-connect / stochastic depth)
22 | return x
23 |
24 |
25 | class SE(nn.Module):
26 | '''Squeeze-and-Excitation block with Swish.'''
27 |
28 | def __init__(self, in_channels, se_channels):
29 | super(SE, self).__init__()
30 | self.se1 = nn.Conv2d(in_channels, se_channels,
31 | kernel_size=1, bias=True)
32 | self.se2 = nn.Conv2d(se_channels, in_channels,
33 | kernel_size=1, bias=True)
34 |
35 | def forward(self, x):
36 | out = F.adaptive_avg_pool2d(x, (1, 1))
37 | out = swish(self.se1(out))
38 | out = self.se2(out).sigmoid()
39 | out = x * out
40 | return out
41 |
42 |
43 | class Block(nn.Module):
44 | '''expansion + depthwise + pointwise + squeeze-excitation'''
45 |
46 | def __init__(self,
47 | in_channels,
48 | out_channels,
49 | kernel_size,
50 | stride,
51 | expand_ratio=1,
52 | se_ratio=0.,
53 | drop_rate=0.):
54 | super(Block, self).__init__()
55 | self.stride = stride
56 | self.drop_rate = drop_rate
57 | self.expand_ratio = expand_ratio
58 |
59 | # Expansion
60 | channels = expand_ratio * in_channels
61 | self.conv1 = nn.Conv2d(in_channels,
62 | channels,
63 | kernel_size=1,
64 | stride=1,
65 | padding=0,
66 | bias=False)
67 | self.bn1 = nn.BatchNorm2d(channels)
68 |
69 | # Depthwise conv
70 | self.conv2 = nn.Conv2d(channels,
71 | channels,
72 | kernel_size=kernel_size,
73 | stride=stride,
74 | padding=(1 if kernel_size == 3 else 2),
75 | groups=channels,
76 | bias=False)
77 | self.bn2 = nn.BatchNorm2d(channels)
78 |
79 | # SE layers
80 | se_channels = int(in_channels * se_ratio)
81 | self.se = SE(channels, se_channels)
82 |
83 | # Output
84 | self.conv3 = nn.Conv2d(channels,
85 | out_channels,
86 | kernel_size=1,
87 | stride=1,
88 | padding=0,
89 | bias=False)
90 | self.bn3 = nn.BatchNorm2d(out_channels)
91 |
92 |         # Skip connection if in and out shapes are the same (MobileNetV2 style)
93 | self.has_skip = (stride == 1) and (in_channels == out_channels)
94 |
95 | def forward(self, x):
96 | out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x)))
97 | out = swish(self.bn2(self.conv2(out)))
98 | out = self.se(out)
99 | out = self.bn3(self.conv3(out))
100 | if self.has_skip:
101 | if self.training and self.drop_rate > 0:
102 | out = drop_connect(out, self.drop_rate)
103 | out = out + x
104 | return out
105 |
106 |
107 | class EfficientNet(nn.Module):
108 | def __init__(self, cfg, num_classes=10):
109 | super(EfficientNet, self).__init__()
110 | self.cfg = cfg
111 | self.conv1 = nn.Conv2d(3,
112 | 32,
113 | kernel_size=3,
114 | stride=1,
115 | padding=1,
116 | bias=False)
117 | self.bn1 = nn.BatchNorm2d(32)
118 | self.layers = self._make_layers(in_channels=32)
119 | self.linear = nn.Linear(cfg['out_channels'][-1], num_classes)
120 |
121 | def _make_layers(self, in_channels):
122 | layers = []
123 | cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size',
124 | 'stride']]
125 | b = 0
126 | blocks = sum(self.cfg['num_blocks'])
127 | for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg):
128 | strides = [stride] + [1] * (num_blocks - 1)
129 | for stride in strides:
130 | drop_rate = self.cfg['drop_connect_rate'] * b / blocks
131 | layers.append(
132 | Block(in_channels,
133 | out_channels,
134 | kernel_size,
135 | stride,
136 | expansion,
137 | se_ratio=0.25,
138 | drop_rate=drop_rate))
139 | in_channels = out_channels
140 | return nn.Sequential(*layers)
141 |
142 | def forward(self, x):
143 | out = swish(self.bn1(self.conv1(x)))
144 | out = self.layers(out)
145 | out = F.adaptive_avg_pool2d(out, 1)
146 | out = out.view(out.size(0), -1)
147 | dropout_rate = self.cfg['dropout_rate']
148 | if self.training and dropout_rate > 0:
149 | out = F.dropout(out, p=dropout_rate)
150 | out = self.linear(out)
151 | return out
152 |
153 |
154 | def EfficientNetB0(num_classes=10):
155 | cfg = {
156 | 'num_blocks': [1, 2, 2, 3, 3, 4, 1],
157 | 'expansion': [1, 6, 6, 6, 6, 6, 6],
158 | 'out_channels': [16, 24, 40, 80, 112, 192, 320],
159 | 'kernel_size': [3, 3, 5, 3, 5, 5, 3],
160 | 'stride': [1, 2, 2, 2, 1, 2, 1],
161 | 'dropout_rate': 0.2,
162 | 'drop_connect_rate': 0.2,
163 | }
164 | return EfficientNet(cfg, num_classes=num_classes)
165 |
166 |
167 | def test():
168 | net = EfficientNetB0()
169 | x = torch.randn(2, 3, 32, 32)
170 | y = net(x)
171 | print(y.shape)
172 |
173 |
174 | if __name__ == '__main__':
175 | test()
176 |
--------------------------------------------------------------------------------
/src/models/googlenet.py:
--------------------------------------------------------------------------------
1 | '''GoogLeNet with PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Inception(nn.Module):
8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
9 | super(Inception, self).__init__()
10 | # 1x1 conv branch
11 | self.b1 = nn.Sequential(
12 | nn.Conv2d(in_planes, n1x1, kernel_size=1),
13 | nn.BatchNorm2d(n1x1),
14 | nn.ReLU(True),
15 | )
16 |
17 | # 1x1 conv -> 3x3 conv branch
18 | self.b2 = nn.Sequential(
19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1),
20 | nn.BatchNorm2d(n3x3red),
21 | nn.ReLU(True),
22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
23 | nn.BatchNorm2d(n3x3),
24 | nn.ReLU(True),
25 | )
26 |
27 | # 1x1 conv -> 5x5 conv branch
28 | self.b3 = nn.Sequential(
29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1),
30 | nn.BatchNorm2d(n5x5red),
31 | nn.ReLU(True),
32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
33 | nn.BatchNorm2d(n5x5),
34 | nn.ReLU(True),
35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(n5x5),
37 | nn.ReLU(True),
38 | )
39 |
40 | # 3x3 pool -> 1x1 conv branch
41 | self.b4 = nn.Sequential(
42 | nn.MaxPool2d(3, stride=1, padding=1),
43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1),
44 | nn.BatchNorm2d(pool_planes),
45 | nn.ReLU(True),
46 | )
47 |
48 | def forward(self, x):
49 | y1 = self.b1(x)
50 | y2 = self.b2(x)
51 | y3 = self.b3(x)
52 | y4 = self.b4(x)
53 | return torch.cat([y1,y2,y3,y4], 1)
54 |
55 |
56 | class GoogLeNet(nn.Module):
57 | def __init__(self):
58 | super(GoogLeNet, self).__init__()
59 | self.pre_layers = nn.Sequential(
60 | nn.Conv2d(3, 192, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(192),
62 | nn.ReLU(True),
63 | )
64 |
65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
67 |
68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
69 |
70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
75 |
76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
78 |
79 | self.avgpool = nn.AvgPool2d(8, stride=1)
80 | self.linear = nn.Linear(1024, 10)
81 |
82 | def forward(self, x):
83 | out = self.pre_layers(x)
84 | out = self.a3(out)
85 | out = self.b3(out)
86 | out = self.maxpool(out)
87 | out = self.a4(out)
88 | out = self.b4(out)
89 | out = self.c4(out)
90 | out = self.d4(out)
91 | out = self.e4(out)
92 | out = self.maxpool(out)
93 | out = self.a5(out)
94 | out = self.b5(out)
95 | out = self.avgpool(out)
96 | out = out.view(out.size(0), -1)
97 | out = self.linear(out)
98 | return out
99 |
100 |
101 | def test():
102 | net = GoogLeNet()
103 | x = torch.randn(1,3,32,32)
104 | y = net(x)
105 | print(y.size())
106 |
107 | # test()
108 |
--------------------------------------------------------------------------------
/src/models/lenet.py:
--------------------------------------------------------------------------------
1 | '''LeNet in PyTorch.'''
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class LeNet(nn.Module):
6 | def __init__(self):
7 | super(LeNet, self).__init__()
8 | self.conv1 = nn.Conv2d(3, 6, 5)
9 | self.conv2 = nn.Conv2d(6, 16, 5)
10 | self.fc1 = nn.Linear(16*5*5, 120)
11 | self.fc2 = nn.Linear(120, 84)
12 | self.fc3 = nn.Linear(84, 10)
13 |
14 | def forward(self, x):
15 | out = F.relu(self.conv1(x))
16 | out = F.max_pool2d(out, 2)
17 | out = F.relu(self.conv2(out))
18 | out = F.max_pool2d(out, 2)
19 | out = out.view(out.size(0), -1)
20 | out = F.relu(self.fc1(out))
21 | out = F.relu(self.fc2(out))
22 | out = self.fc3(out)
23 | return out
24 |
--------------------------------------------------------------------------------
/src/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | '''MobileNet in PyTorch.
2 |
3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
4 | for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class Block(nn.Module):
12 | '''Depthwise conv + Pointwise conv'''
13 | def __init__(self, in_planes, out_planes, stride=1):
14 | super(Block, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
18 | self.bn2 = nn.BatchNorm2d(out_planes)
19 |
20 | def forward(self, x):
21 | out = F.relu(self.bn1(self.conv1(x)))
22 | out = F.relu(self.bn2(self.conv2(out)))
23 | return out
24 |
25 |
26 | class MobileNet(nn.Module):
27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1
28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024]
29 |
30 | def __init__(self, num_classes=10):
31 | super(MobileNet, self).__init__()
32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
33 | self.bn1 = nn.BatchNorm2d(32)
34 | self.layers = self._make_layers(in_planes=32)
35 | self.linear = nn.Linear(1024, num_classes)
36 |
37 | def _make_layers(self, in_planes):
38 | layers = []
39 | for x in self.cfg:
40 | out_planes = x if isinstance(x, int) else x[0]
41 | stride = 1 if isinstance(x, int) else x[1]
42 | layers.append(Block(in_planes, out_planes, stride))
43 | in_planes = out_planes
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | out = F.relu(self.bn1(self.conv1(x)))
48 | out = self.layers(out)
49 | out = F.avg_pool2d(out, 2)
50 | out = out.view(out.size(0), -1)
51 | out = self.linear(out)
52 | return out
53 |
54 |
55 | def test():
56 | net = MobileNet()
57 | x = torch.randn(1,3,32,32)
58 | y = net(x)
59 | print(y.size())
60 |
61 | # test()
62 |
--------------------------------------------------------------------------------
/src/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | '''MobileNetV2 in PyTorch.
2 |
3 | See the paper "Inverted Residuals and Linear Bottlenecks:
4 | Mobile Networks for Classification, Detection and Segmentation" for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class Block(nn.Module):
12 | '''expand + depthwise + pointwise'''
13 | def __init__(self, in_planes, out_planes, expansion, stride):
14 | super(Block, self).__init__()
15 | self.stride = stride
16 |
17 | planes = expansion * in_planes
18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
23 | self.bn3 = nn.BatchNorm2d(out_planes)
24 |
25 | self.shortcut = nn.Sequential()
26 | if stride == 1 and in_planes != out_planes:
27 | self.shortcut = nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
29 | nn.BatchNorm2d(out_planes),
30 | )
31 |
32 | def forward(self, x):
33 | out = F.relu(self.bn1(self.conv1(x)))
34 | out = F.relu(self.bn2(self.conv2(out)))
35 | out = self.bn3(self.conv3(out))
36 | out = out + self.shortcut(x) if self.stride==1 else out
37 | return out
38 |
39 |
40 | class MobileNetV2(nn.Module):
41 | # (expansion, out_planes, num_blocks, stride)
42 | cfg = [(1, 16, 1, 1),
43 | # (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
44 | (6, 24, 2, 2), # NOTE: change for STL10
45 | (6, 32, 3, 2),
46 | (6, 64, 4, 2),
47 | (6, 96, 3, 1),
48 | (6, 160, 3, 2),
49 | (6, 320, 1, 1)]
50 |
51 | def __init__(self, num_classes=10):
52 | super(MobileNetV2, self).__init__()
53 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
54 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
55 | self.bn1 = nn.BatchNorm2d(32)
56 | self.layers = self._make_layers(in_planes=32)
57 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
58 | self.bn2 = nn.BatchNorm2d(1280)
59 | self.linear = nn.Linear(1280, num_classes)
60 |
61 | def _make_layers(self, in_planes):
62 | layers = []
63 | for expansion, out_planes, num_blocks, stride in self.cfg:
64 | strides = [stride] + [1]*(num_blocks-1)
65 | for stride in strides:
66 | layers.append(Block(in_planes, out_planes, expansion, stride))
67 | in_planes = out_planes
68 | return nn.Sequential(*layers)
69 |
70 | def forward(self, x):
71 | out = F.relu(self.bn1(self.conv1(x)))
72 | out = self.layers(out)
73 | out = F.relu(self.bn2(self.conv2(out)))
74 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
75 | # out = F.avg_pool2d(out, 4)
76 | out = F.adaptive_avg_pool2d(out, 1) # NOTE: change for STL10
77 | out = out.view(out.size(0), -1)
78 | out = self.linear(out)
79 | return out
80 |
81 |
82 | def test():
83 | net = MobileNetV2()
84 | x = torch.randn(2,3,32,32)
85 | y = net(x)
86 | print(y.size())
87 |
88 | # test()
89 |
--------------------------------------------------------------------------------
/src/models/pnasnet.py:
--------------------------------------------------------------------------------
1 | '''PNASNet in PyTorch.
2 |
3 | Paper: Progressive Neural Architecture Search
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class SepConv(nn.Module):
11 | '''Separable Convolution.'''
12 | def __init__(self, in_planes, out_planes, kernel_size, stride):
13 | super(SepConv, self).__init__()
14 | self.conv1 = nn.Conv2d(in_planes, out_planes,
15 | kernel_size, stride,
16 | padding=(kernel_size-1)//2,
17 | bias=False, groups=in_planes)
18 | self.bn1 = nn.BatchNorm2d(out_planes)
19 |
20 | def forward(self, x):
21 | return self.bn1(self.conv1(x))
22 |
23 |
24 | class CellA(nn.Module):
25 | def __init__(self, in_planes, out_planes, stride=1):
26 | super(CellA, self).__init__()
27 | self.stride = stride
28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
29 | if stride==2:
30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
31 | self.bn1 = nn.BatchNorm2d(out_planes)
32 |
33 | def forward(self, x):
34 | y1 = self.sep_conv1(x)
35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
36 | if self.stride==2:
37 | y2 = self.bn1(self.conv1(y2))
38 | return F.relu(y1+y2)
39 |
40 | class CellB(nn.Module):
41 | def __init__(self, in_planes, out_planes, stride=1):
42 | super(CellB, self).__init__()
43 | self.stride = stride
44 | # Left branch
45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
47 | # Right branch
48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
49 | if stride==2:
50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
51 | self.bn1 = nn.BatchNorm2d(out_planes)
52 | # Reduce channels
53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
54 | self.bn2 = nn.BatchNorm2d(out_planes)
55 |
56 | def forward(self, x):
57 | # Left branch
58 | y1 = self.sep_conv1(x)
59 | y2 = self.sep_conv2(x)
60 | # Right branch
61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
62 | if self.stride==2:
63 | y3 = self.bn1(self.conv1(y3))
64 | y4 = self.sep_conv3(x)
65 | # Concat & reduce channels
66 | b1 = F.relu(y1+y2)
67 | b2 = F.relu(y3+y4)
68 | y = torch.cat([b1,b2], 1)
69 | return F.relu(self.bn2(self.conv2(y)))
70 |
71 | class PNASNet(nn.Module):
72 | def __init__(self, cell_type, num_cells, num_planes):
73 | super(PNASNet, self).__init__()
74 | self.in_planes = num_planes
75 | self.cell_type = cell_type
76 |
77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
78 | self.bn1 = nn.BatchNorm2d(num_planes)
79 |
80 | self.layer1 = self._make_layer(num_planes, num_cells=6)
81 | self.layer2 = self._downsample(num_planes*2)
82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6)
83 | self.layer4 = self._downsample(num_planes*4)
84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6)
85 |
86 | self.linear = nn.Linear(num_planes*4, 10)
87 |
88 | def _make_layer(self, planes, num_cells):
89 | layers = []
90 | for _ in range(num_cells):
91 | layers.append(self.cell_type(self.in_planes, planes, stride=1))
92 | self.in_planes = planes
93 | return nn.Sequential(*layers)
94 |
95 | def _downsample(self, planes):
96 | layer = self.cell_type(self.in_planes, planes, stride=2)
97 | self.in_planes = planes
98 | return layer
99 |
100 | def forward(self, x):
101 | out = F.relu(self.bn1(self.conv1(x)))
102 | out = self.layer1(out)
103 | out = self.layer2(out)
104 | out = self.layer3(out)
105 | out = self.layer4(out)
106 | out = self.layer5(out)
107 | out = F.avg_pool2d(out, 8)
108 | out = self.linear(out.view(out.size(0), -1))
109 | return out
110 |
111 |
112 | def PNASNetA():
113 | return PNASNet(CellA, num_cells=6, num_planes=44)
114 |
115 | def PNASNetB():
116 | return PNASNet(CellB, num_cells=6, num_planes=32)
117 |
118 |
119 | def test():
120 | net = PNASNetB()
121 | x = torch.randn(1,3,32,32)
122 | y = net(x)
123 | print(y)
124 |
125 | # test()
126 |
--------------------------------------------------------------------------------
/src/models/preact_resnet.py:
--------------------------------------------------------------------------------
1 | '''Pre-activation ResNet in PyTorch.
2 |
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class PreActBlock(nn.Module):
13 | '''Pre-activation version of the BasicBlock.'''
14 | expansion = 1
15 |
16 | def __init__(self, in_planes, planes, stride=1):
17 | super(PreActBlock, self).__init__()
18 | self.bn1 = nn.BatchNorm2d(in_planes)
19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22 |
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
26 | )
27 |
28 | def forward(self, x):
29 | out = F.relu(self.bn1(x))
30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
31 | out = self.conv1(out)
32 | out = self.conv2(F.relu(self.bn2(out)))
33 | out += shortcut
34 | return out
35 |
36 |
37 | class PreActBottleneck(nn.Module):
38 | '''Pre-activation version of the original Bottleneck module.'''
39 | expansion = 4
40 |
41 | def __init__(self, in_planes, planes, stride=1):
42 | super(PreActBottleneck, self).__init__()
43 | self.bn1 = nn.BatchNorm2d(in_planes)
44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(planes)
48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
49 |
50 | if stride != 1 or in_planes != self.expansion*planes:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
53 | )
54 |
55 | def forward(self, x):
56 | out = F.relu(self.bn1(x))
57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
58 | out = self.conv1(out)
59 | out = self.conv2(F.relu(self.bn2(out)))
60 | out = self.conv3(F.relu(self.bn3(out)))
61 | out += shortcut
62 | return out
63 |
64 |
65 | class PreActResNet(nn.Module):
66 | def __init__(self, block, num_blocks, num_classes=10):
67 | super(PreActResNet, self).__init__()
68 | self.in_planes = 64
69 |
70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
75 | self.linear = nn.Linear(512*block.expansion, num_classes)
76 |
77 | def _make_layer(self, block, planes, num_blocks, stride):
78 | strides = [stride] + [1]*(num_blocks-1)
79 | layers = []
80 | for stride in strides:
81 | layers.append(block(self.in_planes, planes, stride))
82 | self.in_planes = planes * block.expansion
83 | return nn.Sequential(*layers)
84 |
85 | def forward(self, x):
86 | out = self.conv1(x)
87 | out = self.layer1(out)
88 | out = self.layer2(out)
89 | out = self.layer3(out)
90 | out = self.layer4(out)
91 | out = F.avg_pool2d(out, 4)
92 | out = out.view(out.size(0), -1)
93 | out = self.linear(out)
94 | return out
95 |
96 |
97 | def PreActResNet18():
98 | return PreActResNet(PreActBlock, [2,2,2,2])
99 |
100 | def PreActResNet34():
101 | return PreActResNet(PreActBlock, [3,4,6,3])
102 |
103 | def PreActResNet50():
104 | return PreActResNet(PreActBottleneck, [3,4,6,3])
105 |
106 | def PreActResNet101():
107 | return PreActResNet(PreActBottleneck, [3,4,23,3])
108 |
109 | def PreActResNet152():
110 | return PreActResNet(PreActBottleneck, [3,8,36,3])
111 |
112 |
113 | def test():
114 | net = PreActResNet18()
115 | y = net((torch.randn(1,3,32,32)))
116 | print(y.size())
117 |
118 | # test()
119 |
--------------------------------------------------------------------------------
/src/models/regnet.py:
--------------------------------------------------------------------------------
1 | '''RegNet in PyTorch.
2 |
3 | Paper: "Designing Network Design Spaces".
4 |
5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class SE(nn.Module):
13 | '''Squeeze-and-Excitation block.'''
14 |
15 | def __init__(self, in_planes, se_planes):
16 | super(SE, self).__init__()
17 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True)
18 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True)
19 |
20 | def forward(self, x):
21 | out = F.adaptive_avg_pool2d(x, (1, 1))
22 | out = F.relu(self.se1(out))
23 | out = self.se2(out).sigmoid()
24 | out = x * out
25 | return out
26 |
27 |
28 | class Block(nn.Module):
29 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio):
30 | super(Block, self).__init__()
31 | # 1x1
32 | w_b = int(round(w_out * bottleneck_ratio))
33 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False)
34 | self.bn1 = nn.BatchNorm2d(w_b)
35 | # 3x3
36 | num_groups = w_b // group_width
37 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3,
38 | stride=stride, padding=1, groups=num_groups, bias=False)
39 | self.bn2 = nn.BatchNorm2d(w_b)
40 | # se
41 | self.with_se = se_ratio > 0
42 | if self.with_se:
43 | w_se = int(round(w_in * se_ratio))
44 | self.se = SE(w_b, w_se)
45 | # 1x1
46 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(w_out)
48 |
49 | self.shortcut = nn.Sequential()
50 | if stride != 1 or w_in != w_out:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(w_in, w_out,
53 | kernel_size=1, stride=stride, bias=False),
54 | nn.BatchNorm2d(w_out)
55 | )
56 |
57 | def forward(self, x):
58 | out = F.relu(self.bn1(self.conv1(x)))
59 | out = F.relu(self.bn2(self.conv2(out)))
60 | if self.with_se:
61 | out = self.se(out)
62 | out = self.bn3(self.conv3(out))
63 | out += self.shortcut(x)
64 | out = F.relu(out)
65 | return out
66 |
67 |
68 | class RegNet(nn.Module):
69 | def __init__(self, cfg, num_classes=10):
70 | super(RegNet, self).__init__()
71 | self.cfg = cfg
72 | self.in_planes = 64
73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
74 | stride=1, padding=1, bias=False)
75 | self.bn1 = nn.BatchNorm2d(64)
76 | self.layer1 = self._make_layer(0)
77 | self.layer2 = self._make_layer(1)
78 | self.layer3 = self._make_layer(2)
79 | self.layer4 = self._make_layer(3)
80 | self.linear = nn.Linear(self.cfg['widths'][-1], num_classes)
81 |
82 | def _make_layer(self, idx):
83 | depth = self.cfg['depths'][idx]
84 | width = self.cfg['widths'][idx]
85 | stride = self.cfg['strides'][idx]
86 | group_width = self.cfg['group_width']
87 | bottleneck_ratio = self.cfg['bottleneck_ratio']
88 | se_ratio = self.cfg['se_ratio']
89 |
90 | layers = []
91 | for i in range(depth):
92 | s = stride if i == 0 else 1
93 | layers.append(Block(self.in_planes, width,
94 | s, group_width, bottleneck_ratio, se_ratio))
95 | self.in_planes = width
96 | return nn.Sequential(*layers)
97 |
98 | def forward(self, x):
99 | out = F.relu(self.bn1(self.conv1(x)))
100 | out = self.layer1(out)
101 | out = self.layer2(out)
102 | out = self.layer3(out)
103 | out = self.layer4(out)
104 | out = F.adaptive_avg_pool2d(out, (1, 1))
105 | out = out.view(out.size(0), -1)
106 | out = self.linear(out)
107 | return out
108 |
109 |
110 | def RegNetX_200MF():
111 | cfg = {
112 | 'depths': [1, 1, 4, 7],
113 | 'widths': [24, 56, 152, 368],
114 | 'strides': [1, 1, 2, 2],
115 | 'group_width': 8,
116 | 'bottleneck_ratio': 1,
117 | 'se_ratio': 0,
118 | }
119 | return RegNet(cfg)
120 |
121 |
122 | def RegNetX_400MF():
123 | cfg = {
124 | 'depths': [1, 2, 7, 12],
125 | 'widths': [32, 64, 160, 384],
126 | 'strides': [1, 1, 2, 2],
127 | 'group_width': 16,
128 | 'bottleneck_ratio': 1,
129 | 'se_ratio': 0,
130 | }
131 | return RegNet(cfg)
132 |
133 |
134 | def RegNetY_400MF():
135 | cfg = {
136 | 'depths': [1, 2, 7, 12],
137 | 'widths': [32, 64, 160, 384],
138 | 'strides': [1, 1, 2, 2],
139 | 'group_width': 16,
140 | 'bottleneck_ratio': 1,
141 | 'se_ratio': 0.25,
142 | }
143 | return RegNet(cfg)
144 |
145 |
146 | def test():
147 | net = RegNetX_200MF()
148 | print(net)
149 | x = torch.randn(2, 3, 32, 32)
150 | y = net(x)
151 | print(y.shape)
152 |
153 |
154 | if __name__ == '__main__':
155 | test()
156 |
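
For illustration, a hypothetical factory (not shipped in this file) showing how the cfg dict drives the design space: an SE-equipped counterpart of RegNetX_200MF only needs a non-zero se_ratio. The depths and widths below simply reuse the X-200MF values and are not the official RegNetY-200MF parameters.

def RegNetY_200MF_sketch():
    cfg = {
        'depths': [1, 1, 4, 7],
        'widths': [24, 56, 152, 368],
        'strides': [1, 1, 2, 2],
        'group_width': 8,
        'bottleneck_ratio': 1,
        'se_ratio': 0.25,  # the only change vs. RegNetX_200MF: enable the SE branch
    }
    return RegNet(cfg)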
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, in_planes, planes, stride=1):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = nn.Conv2d(
20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
23 | stride=1, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 |
26 | self.shortcut = nn.Sequential()
27 | if stride != 1 or in_planes != self.expansion*planes:
28 | self.shortcut = nn.Sequential(
29 | nn.Conv2d(in_planes, self.expansion*planes,
30 | kernel_size=1, stride=stride, bias=False),
31 | nn.BatchNorm2d(self.expansion*planes)
32 | )
33 |
34 | def forward(self, x):
35 | out = F.relu(self.bn1(self.conv1(x)))
36 | out = self.bn2(self.conv2(out))
37 | out += self.shortcut(x)
38 | out = F.relu(out)
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1):
46 | super(Bottleneck, self).__init__()
47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
48 | self.bn1 = nn.BatchNorm2d(planes)
49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
50 | stride=stride, padding=1, bias=False)
51 | self.bn2 = nn.BatchNorm2d(planes)
52 | self.conv3 = nn.Conv2d(planes, self.expansion *
53 | planes, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
55 |
56 | self.shortcut = nn.Sequential()
57 | if stride != 1 or in_planes != self.expansion*planes:
58 | self.shortcut = nn.Sequential(
59 | nn.Conv2d(in_planes, self.expansion*planes,
60 | kernel_size=1, stride=stride, bias=False),
61 | nn.BatchNorm2d(self.expansion*planes)
62 | )
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(self.conv1(x)))
66 | out = F.relu(self.bn2(self.conv2(out)))
67 | out = self.bn3(self.conv3(out))
68 | out += self.shortcut(x)
69 | out = F.relu(out)
70 | return out
71 |
72 |
73 | class ResNet(nn.Module):
74 | def __init__(self, block, num_blocks, num_classes=10, resolution=32):
75 | super(ResNet, self).__init__()
76 | self.in_planes = 64
77 | self.resolution = resolution
78 |
79 | print(">>>>> ResNet resolution: ", resolution, " <<<<<")
80 | if resolution == 32:
81 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
82 | stride=1, padding=1, bias=False)
83 | elif resolution == 28:
84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
85 | stride=1, padding=3, bias=False)
86 | elif resolution == 96:
87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
88 | stride=2, padding=3, bias=False)
89 | elif resolution == 256:
90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
91 | stride=2, padding=3, bias=False)
92 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
93 |
94 | self.bn1 = nn.BatchNorm2d(64)
95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 | self.linear = nn.Linear(512*block.expansion, num_classes)
101 |
102 | def _make_layer(self, block, planes, num_blocks, stride):
103 | strides = [stride] + [1]*(num_blocks-1)
104 | layers = []
105 | for stride in strides:
106 | layers.append(block(self.in_planes, planes, stride))
107 | self.in_planes = planes * block.expansion
108 | return nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | out = F.relu(self.bn1(self.conv1(x)))
112 | if self.resolution==256:
113 | out = self.maxpool(out)
114 | out = self.layer1(out)
115 | out = self.layer2(out)
116 | out = self.layer3(out)
117 | out = self.layer4(out)
118 | out = self.avgpool(out)
119 | out = out.view(out.size(0), -1)
120 | out = self.linear(out)
121 | return out
122 |
123 | def ResNet18(num_classes=10, resolution=32):
124 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, resolution=resolution)
125 |
126 |
127 | def ResNet34(num_classes=10, resolution=32):
128 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution)
129 |
130 |
131 | def ResNet50(num_classes=10, resolution=32):
132 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution)
133 |
134 |
135 | def ResNet101(num_classes=10, resolution=32):
136 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, resolution=resolution)
137 |
138 |
139 | def ResNet152(num_classes=10, resolution=32):
140 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, resolution=resolution)
141 |
142 |
143 | def test():
144 | net = ResNet18()
145 | y = net(torch.randn(1, 3, 32, 32))
146 | print(y.size())
147 |
148 | # test()
149 |
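
A minimal usage sketch (assuming the definitions above): the resolution argument only switches the stem, and the AdaptiveAvgPool2d before the classifier makes the rest of the network resolution-agnostic. Note that only 28, 32, 96 and 256 are handled; any other value leaves conv1 undefined.

net96 = ResNet18(num_classes=10, resolution=96)    # 7x7 stride-2 stem, e.g. for STL-10-sized inputs
net256 = ResNet18(num_classes=10, resolution=256)  # 7x7 stride-2 stem followed by 3x3 max pooling
print(net96(torch.randn(2, 3, 96, 96)).shape)      # torch.Size([2, 10])
print(net256(torch.randn(2, 3, 256, 256)).shape)   # torch.Size([2, 10])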
--------------------------------------------------------------------------------
/src/models/resnext.py:
--------------------------------------------------------------------------------
1 | '''ResNeXt in PyTorch.
2 |
3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class Block(nn.Module):
11 | '''Grouped convolution block.'''
12 | expansion = 2
13 |
14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
15 | super(Block, self).__init__()
16 | group_width = cardinality * bottleneck_width
17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(group_width)
19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
20 | self.bn2 = nn.BatchNorm2d(group_width)
21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width)
23 |
24 | self.shortcut = nn.Sequential()
25 | if stride != 1 or in_planes != self.expansion*group_width:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*group_width)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = F.relu(self.bn2(self.conv2(out)))
34 | out = self.bn3(self.conv3(out))
35 | out += self.shortcut(x)
36 | out = F.relu(out)
37 | return out
38 |
39 |
40 | class ResNeXt(nn.Module):
41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
42 | super(ResNeXt, self).__init__()
43 | self.cardinality = cardinality
44 | self.bottleneck_width = bottleneck_width
45 | self.in_planes = 64
46 |
47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
48 | self.bn1 = nn.BatchNorm2d(64)
49 | self.layer1 = self._make_layer(num_blocks[0], 1)
50 | self.layer2 = self._make_layer(num_blocks[1], 2)
51 | self.layer3 = self._make_layer(num_blocks[2], 2)
52 | # self.layer4 = self._make_layer(num_blocks[3], 2)
53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)
54 |
55 | def _make_layer(self, num_blocks, stride):
56 | strides = [stride] + [1]*(num_blocks-1)
57 | layers = []
58 | for stride in strides:
59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
61 | # Increase bottleneck_width by 2 after each stage.
62 | self.bottleneck_width *= 2
63 | return nn.Sequential(*layers)
64 |
65 | def forward(self, x):
66 | out = F.relu(self.bn1(self.conv1(x)))
67 | out = self.layer1(out)
68 | out = self.layer2(out)
69 | out = self.layer3(out)
70 | # out = self.layer4(out)
71 | out = F.avg_pool2d(out, 8)
72 | out = out.view(out.size(0), -1)
73 | out = self.linear(out)
74 | return out
75 |
76 |
77 | def ResNeXt29_2x64d():
78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)
79 |
80 | def ResNeXt29_4x64d():
81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)
82 |
83 | def ResNeXt29_8x64d():
84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)
85 |
86 | def ResNeXt29_32x4d():
87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)
88 |
89 | def test_resnext():
90 | net = ResNeXt29_2x64d()
91 | x = torch.randn(1,3,32,32)
92 | y = net(x)
93 | print(y.size())
94 |
95 | # test_resnext()
96 |
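
A short check (illustration only, assuming the definitions above) of where the factor 8 in the classifier comes from: bottleneck_width doubles after each stage, so the third stage runs at 4x the initial width, and its output has expansion * cardinality * (bottleneck_width * 4) = cardinality * bottleneck_width * 8 channels.

net = ResNeXt29_32x4d()
feat_dim = 2 * 32 * (4 * 2**2)  # expansion * cardinality * stage-3 bottleneck width
assert net.linear.in_features == feat_dim == 32 * 4 * 8  # 1024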
--------------------------------------------------------------------------------
/src/models/senet.py:
--------------------------------------------------------------------------------
1 | '''SENet in PyTorch.
2 |
3 | SENet won the ILSVRC 2017 image classification challenge. Reference: "Squeeze-and-Excitation Networks", arXiv:1709.01507.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | def __init__(self, in_planes, planes, stride=1):
12 | super(BasicBlock, self).__init__()
13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(planes)
15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
16 | self.bn2 = nn.BatchNorm2d(planes)
17 |
18 | self.shortcut = nn.Sequential()
19 | if stride != 1 or in_planes != planes:
20 | self.shortcut = nn.Sequential(
21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
22 | nn.BatchNorm2d(planes)
23 | )
24 |
25 | # SE layers
26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear
27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 |
33 |         # Squeeze: global average pooling to a (N, C, 1, 1) channel descriptor
34 |         w = F.avg_pool2d(out, out.size(2))
35 |         # Excitation: two 1x1 convs acting as FC layers, followed by a sigmoid gate
36 |         w = F.relu(self.fc1(w))
37 |         w = torch.sigmoid(self.fc2(w))
38 |         out = out * w  # reweight channels; w broadcasts over H and W
39 |
40 | out += self.shortcut(x)
41 | out = F.relu(out)
42 | return out
43 |
44 |
45 | class PreActBlock(nn.Module):
46 | def __init__(self, in_planes, planes, stride=1):
47 | super(PreActBlock, self).__init__()
48 | self.bn1 = nn.BatchNorm2d(in_planes)
49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
50 | self.bn2 = nn.BatchNorm2d(planes)
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
52 |
53 | if stride != 1 or in_planes != planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
56 | )
57 |
58 | # SE layers
59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
61 |
62 | def forward(self, x):
63 | out = F.relu(self.bn1(x))
64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
65 | out = self.conv1(out)
66 | out = self.conv2(F.relu(self.bn2(out)))
67 |
68 |         # Squeeze: global average pooling to a (N, C, 1, 1) channel descriptor
69 |         w = F.avg_pool2d(out, out.size(2))
70 |         # Excitation: two 1x1 convs acting as FC layers, followed by a sigmoid gate
71 |         w = F.relu(self.fc1(w))
72 |         w = torch.sigmoid(self.fc2(w))
73 |         out = out * w  # reweight channels; w broadcasts over H and W
74 |
75 | out += shortcut
76 | return out
77 |
78 |
79 | class SENet(nn.Module):
80 | def __init__(self, block, num_blocks, num_classes=10):
81 | super(SENet, self).__init__()
82 | self.in_planes = 64
83 |
84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
85 | self.bn1 = nn.BatchNorm2d(64)
86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
90 | self.linear = nn.Linear(512, num_classes)
91 |
92 | def _make_layer(self, block, planes, num_blocks, stride):
93 | strides = [stride] + [1]*(num_blocks-1)
94 | layers = []
95 | for stride in strides:
96 | layers.append(block(self.in_planes, planes, stride))
97 | self.in_planes = planes
98 | return nn.Sequential(*layers)
99 |
100 | def forward(self, x):
101 | out = F.relu(self.bn1(self.conv1(x)))
102 | out = self.layer1(out)
103 | out = self.layer2(out)
104 | out = self.layer3(out)
105 | out = self.layer4(out)
106 | out = F.avg_pool2d(out, 4)
107 | out = out.view(out.size(0), -1)
108 | out = self.linear(out)
109 | return out
110 |
111 |
112 | def SENet18():
113 | return SENet(PreActBlock, [2,2,2,2])
114 |
115 |
116 | def test():
117 | net = SENet18()
118 | y = net(torch.randn(1,3,32,32))
119 | print(y.size())
120 |
121 | # test()
122 |
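
A small sanity check (illustration only) for the comment above about using nn.Conv2d instead of nn.Linear in the SE branch: on a (N, C, 1, 1) tensor a 1x1 convolution computes exactly the same thing as a fully connected layer.

import torch
import torch.nn as nn

fc_conv = nn.Conv2d(64, 4, kernel_size=1)
fc_lin = nn.Linear(64, 4)
fc_lin.weight.data.copy_(fc_conv.weight.data.view(4, 64))
fc_lin.bias.data.copy_(fc_conv.bias.data)

x = torch.randn(2, 64, 1, 1)
assert torch.allclose(fc_conv(x).view(2, 4), fc_lin(x.view(2, 64)), atol=1e-6)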
--------------------------------------------------------------------------------
/src/models/shufflenet.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 |
3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ShuffleBlock(nn.Module):
11 | def __init__(self, groups):
12 | super(ShuffleBlock, self).__init__()
13 | self.groups = groups
14 |
15 | def forward(self, x):
16 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 | N,C,H,W = x.size()
18 | g = self.groups
19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W)
20 |
21 |
22 | class Bottleneck(nn.Module):
23 | def __init__(self, in_planes, out_planes, stride, groups):
24 | super(Bottleneck, self).__init__()
25 | self.stride = stride
26 |
27 |         mid_planes = out_planes // 4  # integer division: a float here would break the Conv2d channel arguments
28 | g = 1 if in_planes==24 else groups
29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
30 | self.bn1 = nn.BatchNorm2d(mid_planes)
31 | self.shuffle1 = ShuffleBlock(groups=g)
32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
33 | self.bn2 = nn.BatchNorm2d(mid_planes)
34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
35 | self.bn3 = nn.BatchNorm2d(out_planes)
36 |
37 | self.shortcut = nn.Sequential()
38 | if stride == 2:
39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
40 |
41 | def forward(self, x):
42 | out = F.relu(self.bn1(self.conv1(x)))
43 | out = self.shuffle1(out)
44 | out = F.relu(self.bn2(self.conv2(out)))
45 | out = self.bn3(self.conv3(out))
46 | res = self.shortcut(x)
47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
48 | return out
49 |
50 |
51 | class ShuffleNet(nn.Module):
52 | def __init__(self, cfg):
53 | super(ShuffleNet, self).__init__()
54 | out_planes = cfg['out_planes']
55 | num_blocks = cfg['num_blocks']
56 | groups = cfg['groups']
57 |
58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
59 | self.bn1 = nn.BatchNorm2d(24)
60 | self.in_planes = 24
61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
64 | self.linear = nn.Linear(out_planes[2], 10)
65 |
66 | def _make_layer(self, out_planes, num_blocks, groups):
67 | layers = []
68 | for i in range(num_blocks):
69 | stride = 2 if i == 0 else 1
70 | cat_planes = self.in_planes if i == 0 else 0
71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
72 | self.in_planes = out_planes
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | out = F.relu(self.bn1(self.conv1(x)))
77 | out = self.layer1(out)
78 | out = self.layer2(out)
79 | out = self.layer3(out)
80 | out = F.avg_pool2d(out, 4)
81 | out = out.view(out.size(0), -1)
82 | out = self.linear(out)
83 | return out
84 |
85 |
86 | def ShuffleNetG2():
87 | cfg = {
88 | 'out_planes': [200,400,800],
89 | 'num_blocks': [4,8,4],
90 | 'groups': 2
91 | }
92 | return ShuffleNet(cfg)
93 |
94 | def ShuffleNetG3():
95 | cfg = {
96 | 'out_planes': [240,480,960],
97 | 'num_blocks': [4,8,4],
98 | 'groups': 3
99 | }
100 | return ShuffleNet(cfg)
101 |
102 |
103 | def test():
104 | net = ShuffleNetG2()
105 | x = torch.randn(1,3,32,32)
106 | y = net(x)
107 | print(y)
108 |
109 | # test()
110 |
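
A worked illustration (not part of the file) of the channel shuffle in ShuffleBlock with 6 channels and 2 groups: the view/permute/reshape interleaves the channels of the two groups so that later group convolutions mix information across groups.

import torch

x = torch.arange(6).view(1, 6, 1, 1)  # channels [0, 1, 2, 3, 4, 5]; groups: [0, 1, 2] and [3, 4, 5]
g = 2
N, C, H, W = x.size()
y = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
print(y.view(-1).tolist())  # [0, 3, 1, 4, 2, 5] -- channels from the two groups are interleaved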
--------------------------------------------------------------------------------
/src/models/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 |
3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ShuffleBlock(nn.Module):
11 | def __init__(self, groups=2):
12 | super(ShuffleBlock, self).__init__()
13 | self.groups = groups
14 |
15 | def forward(self, x):
16 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 | N, C, H, W = x.size()
18 | g = self.groups
19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
20 |
21 |
22 | class SplitBlock(nn.Module):
23 | def __init__(self, ratio):
24 | super(SplitBlock, self).__init__()
25 | self.ratio = ratio
26 |
27 | def forward(self, x):
28 | c = int(x.size(1) * self.ratio)
29 | return x[:, :c, :, :], x[:, c:, :, :]
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | def __init__(self, in_channels, split_ratio=0.5):
34 | super(BasicBlock, self).__init__()
35 | self.split = SplitBlock(split_ratio)
36 | in_channels = int(in_channels * split_ratio)
37 | self.conv1 = nn.Conv2d(in_channels, in_channels,
38 | kernel_size=1, bias=False)
39 | self.bn1 = nn.BatchNorm2d(in_channels)
40 | self.conv2 = nn.Conv2d(in_channels, in_channels,
41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
42 | self.bn2 = nn.BatchNorm2d(in_channels)
43 | self.conv3 = nn.Conv2d(in_channels, in_channels,
44 | kernel_size=1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(in_channels)
46 | self.shuffle = ShuffleBlock()
47 |
48 | def forward(self, x):
49 | x1, x2 = self.split(x)
50 | out = F.relu(self.bn1(self.conv1(x2)))
51 | out = self.bn2(self.conv2(out))
52 | out = F.relu(self.bn3(self.conv3(out)))
53 | out = torch.cat([x1, out], 1)
54 | out = self.shuffle(out)
55 | return out
56 |
57 |
58 | class DownBlock(nn.Module):
59 | def __init__(self, in_channels, out_channels):
60 | super(DownBlock, self).__init__()
61 | mid_channels = out_channels // 2
62 | # left
63 | self.conv1 = nn.Conv2d(in_channels, in_channels,
64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
65 | self.bn1 = nn.BatchNorm2d(in_channels)
66 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
67 | kernel_size=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(mid_channels)
69 | # right
70 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
71 | kernel_size=1, bias=False)
72 | self.bn3 = nn.BatchNorm2d(mid_channels)
73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
75 | self.bn4 = nn.BatchNorm2d(mid_channels)
76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
77 | kernel_size=1, bias=False)
78 | self.bn5 = nn.BatchNorm2d(mid_channels)
79 |
80 | self.shuffle = ShuffleBlock()
81 |
82 | def forward(self, x):
83 | # left
84 | out1 = self.bn1(self.conv1(x))
85 | out1 = F.relu(self.bn2(self.conv2(out1)))
86 | # right
87 | out2 = F.relu(self.bn3(self.conv3(x)))
88 | out2 = self.bn4(self.conv4(out2))
89 | out2 = F.relu(self.bn5(self.conv5(out2)))
90 | # concat
91 | out = torch.cat([out1, out2], 1)
92 | out = self.shuffle(out)
93 | return out
94 |
95 |
96 | class ShuffleNetV2(nn.Module):
97 | def __init__(self, net_size, num_classes=10):
98 | super(ShuffleNetV2, self).__init__()
99 | out_channels = configs[net_size]['out_channels']
100 | num_blocks = configs[net_size]['num_blocks']
101 |
102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
103 | stride=1, padding=1, bias=False)
104 | self.bn1 = nn.BatchNorm2d(24)
105 | self.in_channels = 24
106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
110 | kernel_size=1, stride=1, padding=0, bias=False)
111 | self.bn2 = nn.BatchNorm2d(out_channels[3])
112 | self.linear = nn.Linear(out_channels[3], num_classes)
113 |
114 | def _make_layer(self, out_channels, num_blocks):
115 | layers = [DownBlock(self.in_channels, out_channels)]
116 | for i in range(num_blocks):
117 | layers.append(BasicBlock(out_channels))
118 | self.in_channels = out_channels
119 | return nn.Sequential(*layers)
120 |
121 | def forward(self, x):
122 | out = F.relu(self.bn1(self.conv1(x)))
123 | out = F.max_pool2d(out, 3, stride=2, padding=1)
124 | out = self.layer1(out)
125 | out = self.layer2(out)
126 | out = self.layer3(out)
127 | out = F.relu(self.bn2(self.conv2(out)))
128 | out = F.adaptive_avg_pool2d(out, 1) # NOTE: change for STL10
129 | out = out.view(out.size(0), -1)
130 | out = self.linear(out)
131 | return out
132 |
133 |
134 | configs = {
135 | 0.5: {
136 | 'out_channels': (48, 96, 192, 1024),
137 | 'num_blocks': (3, 7, 3)
138 | },
139 |
140 | 1: {
141 | 'out_channels': (116, 232, 464, 1024),
142 | 'num_blocks': (3, 7, 3)
143 | },
144 | 1.5: {
145 | 'out_channels': (176, 352, 704, 1024),
146 | 'num_blocks': (3, 7, 3)
147 | },
148 | 2: {
149 | 'out_channels': (224, 488, 976, 2048),
150 | 'num_blocks': (3, 7, 3)
151 | }
152 | }
153 |
154 |
155 | def test():
156 | net = ShuffleNetV2(net_size=0.5)
157 | x = torch.randn(3, 3, 32, 32)
158 | y = net(x)
159 | print(y.shape)
160 |
161 |
162 | # test()
163 |
--------------------------------------------------------------------------------
/src/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG11/13/16/19 in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | cfg = {
7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
11 | }
12 |
13 |
14 | class VGG(nn.Module):
15 | def __init__(self, vgg_name, num_classes=10):
16 | super(VGG, self).__init__()
17 | self.features = self._make_layers(cfg[vgg_name])
18 | self.classifier = nn.Linear(512, num_classes)
19 |
20 | def forward(self, x):
21 | out = self.features(x)
22 | out = out.view(out.size(0), -1)
23 | out = self.classifier(out)
24 | return out
25 |
26 | def _make_layers(self, cfg):
27 | layers = []
28 | in_channels = 3
29 | for x in cfg:
30 | if x == 'M':
31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
32 | else:
33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
34 | nn.BatchNorm2d(x),
35 | nn.ReLU(inplace=True)]
36 | in_channels = x
37 | layers += [nn.AdaptiveAvgPool2d((1, 1))]
38 | return nn.Sequential(*layers)
39 |
40 |
41 | def test():
42 | net = VGG('VGG11')
43 | x = torch.randn(2,3,32,32)
44 | y = net(x)
45 | print(y.size())
46 |
47 | # test()
48 |
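
A hypothetical config (not shipped in this file) showing how _make_layers walks the cfg list: integers become Conv-BN-ReLU blocks and 'M' becomes 2x2 max pooling. The final conv width must stay 512 because the classifier head is hard-coded to nn.Linear(512, num_classes).

cfg['VGG_tiny'] = [32, 'M', 64, 'M', 128, 'M', 512, 'M']  # hypothetical entry for quick debugging
net = VGG('VGG_tiny')
print(net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])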
--------------------------------------------------------------------------------
/src/script/clip_retrieval_stl10.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import json
5 | import fire
6 | import numpy as np
7 | import pandas as pd
8 | import tqdm
9 | from clip_retrieval.clip_client import ClipClient
10 |
11 | def get_urls(img_folder, class_idx, k=45):
12 |     # for each image in the class folder, retrieve the top-k most similar LAION URLs
13 | print(f'Number of images per query: {k}')
14 | client = ClipClient(url="https://knn.laion.ai/knn-service",
15 | indice_name="laion5B-L-14", num_images=k,
16 | use_safety_model=False, use_violence_detector=False, deduplicate=False)
17 |
18 | url_map = {}
19 | for name in tqdm.tqdm(os.listdir(f'{img_folder}/class_{class_idx:03d}')):
20 | img_path = f'{img_folder}/class_{class_idx:03d}/{name}'
21 | assert os.path.exists(img_path)
22 | try:
23 | results = client.query(image=img_path)
24 | url_map[name] = {'url': [ele['url'] for ele in results], 'id': [ele['id'] for ele in results]}
25 |         except Exception as e:
26 |             print('Error: class {} {}: {}'.format(class_idx, name, e))
27 | # print(url_map[name])
28 | return url_map
29 |
30 |
31 | def url_analysis(url_map, key='id'):
32 | # count the unique urls
33 | unique_urls = set()
34 |
35 | for name, ele in url_map.items():
36 | for url in ele[key]:
37 | unique_urls.add(url)
38 |
39 | print('Unique urls:')
40 | print(len(unique_urls))
41 |
42 | # count occurrences of each url
43 | occurrences = {}
44 | for url in unique_urls:
45 | count = 0
46 | for name, ele in url_map.items():
47 | if url in ele[key]:
48 | count += 1
49 | occurrences[url] = count
50 | # print('Occurrences:')
51 | # print(occurrences)
52 |
53 | # compute the average number of occurrences
54 | mean = sum(occurrences.values()) / len(occurrences)
55 | print('Mean occurrences:')
56 | print(mean)
57 |     # The underlying reason for the heavy overlap in retrieved URLs across queries is that the CLIP embedding does not capture the discriminative information of the images in the target domain.
58 |
59 |
60 | def main(class_idx, k=45):
61 | img_folder = "/ssd005/projects/diffusion_inversion/real_data/stl10/scaling/res96_bicubic"
62 | output_dir = f'clip_retrieval/stl10'
63 | if not os.path.exists(output_dir):
64 | os.makedirs(output_dir)
65 |
66 | # url_map = get_urls(img_folder, class_idx=class_idx, k=k)
67 |
68 | # with open(f"{output_dir}/cls{class_idx}_k{k}.json", "w") as f:
69 | # json.dump(url_map, f)
70 |
71 | with open(f"{output_dir}/cls{class_idx}_k{k}.json") as f:
72 | url_map = json.load(f)
73 |
74 | url_analysis(url_map)
75 |
76 | unique_urls = set()
77 | res = []
78 |
79 | for name, ele in url_map.items():
80 | for url in ele['url']:
81 | if url not in unique_urls:
82 | res.append({'url': url})
83 | unique_urls.add(url)
84 |
85 | print('Unique urls:')
86 | print(len(unique_urls))
87 |
88 | with open(f"{output_dir}/cls{class_idx}_k{k}_unique.json", mode="w") as f:
89 | json.dump(res, f)
90 |
91 | # image_folder = f"/ssd005/projects/diffusion_inversion/inversion_data/stl10/knn_retrieval/k{k}/class_{class_idx:03d}"
92 | # os.system(
93 | # f'img2dataset --input_format=json --url_list={output_dir}/cls{class_idx}_k{k}_unique.json --output_folder={image_folder} --thread_count=64 --image_size=96 --output_format=files')
94 |
95 | image_folder = f"/ssd005/projects/diffusion_inversion/inversion_data/stl10/knn_retrieval/k{k}_unresize/class_{class_idx:03d}"
96 | os.system(
97 | f'img2dataset --input_format=json --url_list={output_dir}/cls{class_idx}_k{k}_unique.json --output_folder={image_folder} --thread_count=64 --output_format=files')
98 |
99 | if __name__ == "__main__":
100 | fire.Fire(main)
101 | # python clip_retrieval_stl10.py --class_idx=0 --k=10
102 | # img2dataset --input_format=json --url_list=urls_10.json --output_folder=/ssd005/projects/diffusion_inversion/inversion_data/stl10/knn_retrieval/ --thread_count=64 --image_size=96 --output_format=files
103 |
104 | # https://github.com/rom1504/img2dataset
105 |
--------------------------------------------------------------------------------
/src/script/compute_statistics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tqdm
4 | import fire
5 | import functools
6 | import ml_collections
7 | from absl import logging
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | import tensorflow_datasets as tfds
12 |
13 |
14 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
15 | from dataset import get_dataset # NOQA
16 |
17 | tf.config.experimental.set_visible_devices([], "GPU")
18 |
19 |
20 | def D(**kwargs):
21 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
22 |
23 |
24 | def center_crop(x, resolution):
25 | shape = tf.shape(x)
26 | h, w = shape[0], shape[1]
27 | size = tf.minimum(h, w)
28 | begin = tf.cast([h - size, w - size], tf.float32) / 2.0
29 | begin = tf.cast(begin, tf.int32)
30 |     begin = tf.concat([begin, [0]], axis=0)  # Append a 0 start index for the channel axis.
31 | x = tf.slice(x, begin, [size, size, 3])
32 | x = tf.image.resize_with_pad(
33 | x, resolution, resolution, method='area', antialias=True)
34 | return x
35 |
36 |
37 | def compute_channel_mean_std_ds(ds, img_shape, resolution=32, batch_size=1000):
38 | if None in img_shape:
39 | dim = resolution * resolution
40 | else:
41 | dim = functools.reduce(lambda x, y: x * y, img_shape[:-1], 1)
42 |
43 | # ds = ds.map(lambda x, y: tf.cast(
44 | # x, dtype='float32') / 255.0, tf.data.AUTOTUNE)
45 | ds = ds.map(lambda x, y: tf.cast(
46 | x, dtype='float32'), tf.data.AUTOTUNE)
47 | if None in img_shape:
48 | ds = ds.map(lambda x: center_crop(x, resolution), tf.data.AUTOTUNE)
49 | ds = ds.map(lambda x: tf.reshape(
50 | x, shape=(dim, img_shape[-1])), tf.data.AUTOTUNE)
51 | ds = ds.batch(batch_size=batch_size)
52 | ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
53 |
54 | mean = np.zeros(shape=(img_shape[-1],))
55 | var = np.zeros(shape=(img_shape[-1],))
56 | count = 0
57 |
58 | for x_batch in tqdm.tqdm(tfds.as_numpy(ds), desc='Compute mean with batch size: {}'.format(batch_size)):
59 | mean = mean + np.sum(x_batch, axis=(0, 1))
60 | count += x_batch.shape[0]
61 |
62 | mean = 1.0 / (count * dim) * mean
63 |
64 | for x_batch in tqdm.tqdm(tfds.as_numpy(ds), desc='Compute variance with batch size: {}'.format(batch_size)):
65 | var = var + np.sum(np.square(x_batch - mean), axis=(0, 1))
66 |
67 | std = np.sqrt(1.0 / (count * dim) * var)
68 |
69 | logging.info(
70 | 'Total number of data: {}, mean: {}, std: {}'.format(count, mean, std))
71 |
72 | return mean, std
73 |
74 |
75 | def main(dataset_name, data_dir, resolution, batch_size):
76 | dataset_config = D(
77 | name=dataset_name,
78 | data_dir=data_dir)
79 |
80 | x_train, y_train = get_dataset(
81 | dataset_config, return_raw=True, resolution=resolution, train_only=True)
82 |
83 | ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
84 | mean, std = compute_channel_mean_std_ds(
85 | ds, (resolution, resolution, 3), resolution=resolution, batch_size=batch_size)
86 | print('Mean: {}, Std: {}'.format(mean, std))
87 |
88 |
89 | if __name__ == "__main__":
90 | fire.Fire(main)
91 | # python compute_statistics.py --dataset_name=imagenette --data_dir=$HOME/tensorflow_datasets --resolution=256 --batch_size=1000
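
For reference, a minimal NumPy sketch (illustration only) of what compute_channel_mean_std_ds computes: the per-channel mean and population standard deviation over every pixel of every image, on raw 0-255 values since the /255 rescaling is left commented out above.

import numpy as np

x = np.random.randint(0, 256, size=(100, 32, 32, 3)).astype(np.float32)  # hypothetical image array
mean = x.reshape(-1, 3).mean(axis=0)  # per-channel mean over all pixels
std = x.reshape(-1, 3).std(axis=0)    # per-channel std (divides by N, like the script)
print(mean, std)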
--------------------------------------------------------------------------------
/src/script/convert_vae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import fire
4 | import ml_collections
5 | import torch
6 | import tensorflow as tf
7 | import numpy as np
8 | import PIL
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | from torch.utils.data import Dataset
13 |
14 | from diffusers import AutoencoderKL
15 |
16 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
17 | from dataset import get_dataset # NOQA
18 |
19 | tf.config.experimental.set_visible_devices([], "GPU")
20 |
21 |
22 | PIL_INTERPOLATION = {
23 | "linear": PIL.Image.Resampling.BILINEAR,
24 | "bilinear": PIL.Image.Resampling.BILINEAR,
25 | "bicubic": PIL.Image.Resampling.BICUBIC,
26 | "lanczos": PIL.Image.Resampling.LANCZOS,
27 | "nearest": PIL.Image.Resampling.NEAREST,
28 | }
29 |
30 |
31 | def D(**kwargs):
32 | return ml_collections.ConfigDict(initial_dictionary=kwargs)
33 |
34 |
35 | class EmbDataset(Dataset):
36 | def __init__(self, x, y, size=32, interpolation=Image.BICUBIC):
37 | self.x = x
38 | self.y = y
39 | self.size = size
40 | self.interpolation = interpolation
41 |
42 | def __len__(self):
43 | return self.x.shape[0]
44 |
45 | def __getitem__(self, idx):
46 | img = self.x[idx].astype(np.uint8)
47 | image = Image.fromarray(img)
48 | image = image.resize((self.size, self.size),
49 | resample=self.interpolation)
50 | image = np.array(image).astype(np.float32)
51 | image = (image / 127.5 - 1.0)
52 | image = torch.from_numpy(image).permute(2, 0, 1)
53 |
54 | return image, self.y[idx]
55 |
56 |
57 | def main(dataset_name, root_name, split='train', group_size=100, batch_size=50, sampling_resolution=128,
58 | interpolation='bicubic', save_resolution=32, num_classes=10):
59 |
60 | device = "cuda"
61 | dataset_config = D(
62 | name=dataset_name,
63 | data_dir='~/tensorflow_datasets'
64 | )
65 | x_train, y_train, x_test, y_test = get_dataset(
66 | dataset_config, return_raw=True, train_only=False)
67 |
68 | root_name = os.path.join(root_name, f'{dataset_name}_{split}')
69 | model_id = "CompVis/stable-diffusion-v1-4"
70 | ae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
71 |
72 | config_name = f'{model_id.replace("/", "-")}_sample{sampling_resolution}_{interpolation}_save{save_resolution}'
73 |
74 | for i in range(num_classes):
75 | os.makedirs(f'{root_name}/class_{i:03d}/{config_name}', exist_ok=True)
76 |
77 | if split == 'train':
78 | dataset = EmbDataset(x_train, y_train, size=sampling_resolution,
79 | interpolation=PIL_INTERPOLATION[interpolation])
80 | elif split == 'test':
81 | dataset = EmbDataset(x_test, y_test, size=sampling_resolution,
82 | interpolation=PIL_INTERPOLATION[interpolation])
83 | else:
84 | raise ValueError(f'Unknown split {split}')
85 |
86 | dataloader = torch.utils.data.DataLoader(
87 | dataset, batch_size=batch_size, shuffle=False, num_workers=4)
88 |
89 | for i, (img, lb) in tqdm(enumerate(dataloader)):
90 | with torch.no_grad():
91 | latents = ae.encode(img.to(device)).latent_dist.mode()
92 | image = ae.decode(latents).sample
93 | image = (image / 2 + 0.5).clamp(0, 1)
94 | images = image.cpu().permute(0, 2, 3, 1).float().numpy()
95 | images = (images * 255).round().astype("uint8")
96 | images = [Image.fromarray(image) for image in images]
97 |
98 | for idx, img in enumerate(images):
99 | if sampling_resolution != save_resolution:
100 | img = img.resize((save_resolution, save_resolution),
101 | resample=PIL_INTERPOLATION[interpolation])
102 | img.save(
103 | f'{root_name}/class_{lb[idx]:03d}/{config_name}/group{(i*batch_size+idx)//group_size:02d}_sample{(i*batch_size+idx):05d}.png')
104 |
105 |
106 | if __name__ == '__main__':
107 | fire.Fire(main)
108 | # python convert_vae.py --dataset_name=cifar10 --root_name=$HOME/inversion_data --batch_size=100 --interpolation=bicubic --save_resolution=128 --num_classes=10 --split=train --sampling_resolution=128
109 |
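
The heart of the script is a VAE round trip; a minimal sketch (illustration only) using the same diffusers calls as above:

import torch
from diffusers import AutoencoderKL

ae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
img = torch.rand(1, 3, 128, 128) * 2 - 1  # the VAE expects images scaled to [-1, 1]
with torch.no_grad():
    latents = ae.encode(img).latent_dist.mode()  # (1, 4, 16, 16): 8x spatial downsampling, 4 latent channels
    recon = ae.decode(latents).sample            # back to (1, 3, 128, 128), approximately in [-1, 1]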
--------------------------------------------------------------------------------
/src/script/split_euro_dataset.py:
--------------------------------------------------------------------------------
1 | import fire
2 |
3 | import os
4 | import sys
5 | import json
6 | import shutil
7 |
8 | def main():
9 |     input_dir = os.path.expandvars("$HOME/tensorflow_datasets/downloads/manual/eurosat/2750")
10 |     output_dir = os.path.expandvars("$HOME/tensorflow_datasets/eurosat")
11 | 
12 |     with open(os.path.expandvars("$HOME/Project/diffusion_inversion/src/split_zhou_EuroSAT.json")) as f:
13 | data = json.load(f)
14 |
15 | for k, v in data.items():
16 | print(k, len(v))
17 |         # Create the per-split output directory (checking output_dir alone would skip missing subfolders).
18 |         os.makedirs('{}/{}'.format(output_dir, k), exist_ok=True)
19 |
20 | for k, v in data.items():
21 | for i in v:
22 | if not os.path.exists(os.path.join(output_dir, k, i[0].split('/')[0])):
23 | os.makedirs(os.path.join(output_dir, k, i[0].split('/')[0]))
24 | # print(os.path.join(input_dir, i[0]), os.path.join(output_dir, k, i[0]))
25 | shutil.copy(os.path.join(input_dir, i[0]), os.path.join(output_dir, k, i[0]))
26 |
27 | if __name__ == '__main__':
28 | fire.Fire(main)
--------------------------------------------------------------------------------