├── .gitignore ├── LICENSE ├── README.md ├── asset ├── bloodmnist_vis.png ├── cifar100_vis.png ├── cifar10_vis.png ├── dermamnist_vis.png ├── eurosat_vis.png ├── imagenette_vis.png ├── pathmnist_vis.png ├── stl10_vis.png └── workflow_box.png ├── environment.yml └── src ├── classfier_tuning ├── clip │ ├── README.md │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── tokenizer.py └── src │ ├── args.py │ ├── clip_filtering.py │ ├── clip_filtering_eurosat.py │ ├── ct_fsl.py │ ├── ct_zsl.py │ ├── datasets │ ├── __init__.py │ ├── cifar10.py │ ├── cifar100.py │ ├── common.py │ ├── eurosat.py │ ├── fmow.py │ ├── imagenet.py │ ├── imagenet_a.py │ ├── imagenet_classnames.py │ ├── imagenet_r.py │ ├── imagenet_sketch.py │ ├── imagenet_vid_robust.py │ ├── imagenetv2.py │ ├── iwildcam.py │ ├── iwildcam_metadata │ │ └── labels.csv │ ├── objectnet.py │ ├── objectnet_metadata │ │ ├── folder_to_objectnet_label.json │ │ ├── imagenet_to_label_2012_v2 │ │ ├── objectnet_to_imagenet_1k.json │ │ └── pytorch_to_imagenet_2012_id.json │ ├── transfer_ds │ │ ├── Randaug.py │ │ ├── __init__.py │ │ ├── aircraft.py │ │ ├── cal_mean_std.py │ │ ├── caltech.py │ │ ├── constants.py │ │ ├── cub.py │ │ ├── cub_for_robust_codebase.py │ │ ├── cub_transform.py │ │ ├── dtd.py │ │ ├── fine_tunify.py │ │ ├── food_101.py │ │ ├── imbalance_cifar.py │ │ ├── process_dataset │ │ │ ├── pro_aircraft.py │ │ │ ├── pro_caltech101.py │ │ │ ├── pro_cars.py │ │ │ ├── pro_flowers.py │ │ │ ├── pro_imgnet.py │ │ │ ├── pro_pool15.py │ │ │ └── process_pets.py │ │ ├── transfer_datasets.py │ │ ├── transform_ckpt.py │ │ └── utils.py │ ├── ytbb-robust_metadata │ │ ├── anchor_labels.json │ │ ├── class_idx_map.json │ │ ├── pmk_labels.json │ │ ├── rev_class_idx_map.json │ │ ├── ytbb_class_index.json │ │ └── ytbb_robustness_test_anchors_full.csv │ └── ytbb_robust.py │ ├── eurosat_text_feature.pt │ ├── get_classifier_weights.py │ ├── imagenet_text_feature.pt │ ├── models │ ├── __init__.py │ ├── eval.py │ ├── finetune.py │ ├── finetune_retina.py │ ├── modeling.py │ ├── utils.py │ ├── zeroshot.py │ └── zeroshot_retina.py │ ├── select_glide_ims_by_clip.py │ └── templates │ ├── __init__.py │ ├── fmow_template.py │ ├── iwildcam_template.py │ ├── openai_imagenet_template.py │ ├── simple_template.py │ ├── transfer_ds_template.py │ └── utils.py ├── dataset.py ├── diffuser_inversion.py ├── models ├── __init__.py ├── densenet.py ├── dla.py ├── dla_simple.py ├── dpn.py ├── efficientnet.py ├── googlenet.py ├── lenet.py ├── mobilenet.py ├── mobilenetv2.py ├── pnasnet.py ├── preact_resnet.py ├── regnet.py ├── resnet.py ├── resnext.py ├── senet.py ├── shufflenet.py ├── shufflenetv2.py └── vgg.py ├── pipeline_emb.py ├── sample_dataset.py ├── script ├── clip_retrieval_stl10.py ├── compute_statistics.py ├── convert_vae.py ├── split_euro_dataset.py └── split_zhou_EuroSAT.json └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | .DS_Store 30 | *.pyc 31 | *.log 32 | .vscode/ 33 | src/wandb/ 34 | 35 | checkpoint/ -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yongchao Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Inversion 2 | 3 | [Project Page](https://sites.google.com/view/diffusion-inversion) 4 | | [ArXiv](https://arxiv.org/abs/2305.15316) 5 | 6 | This repo contains code for steering the Stable Diffusion model to generate data for downstream classifier training. Please see our paper and project page for more results. 7 | 8 |

9 | ![cifar10_vis](asset/cifar10_vis.png) 10 | ![cifar100_vis](asset/cifar100_vis.png) 11 | ![stl10_vis](asset/stl10_vis.png) 12 | ![imagenette_vis](asset/imagenette_vis.png) 13 | ![pathmnist_vis](asset/pathmnist_vis.png) 14 | ![bloodmnist_vis](asset/bloodmnist_vis.png) 15 | ![dermamnist_vis](asset/dermamnist_vis.png) 16 | ![eurosat_vis](asset/eurosat_vis.png) 17 |

18 | 19 | ## Abstract 20 | Acquiring high-quality data for training discriminative models is a crucial yet challenging aspect of building effective predictive systems. In this paper, we present Diffusion Inversion, a simple yet effective method that leverages the pre-trained generative model, Stable Diffusion, to generate diverse, high-quality training data for image classification. Our approach captures the original data distribution and ensures data coverage by inverting images to the latent space of Stable Diffusion, and generates diverse novel training images by conditioning the generative model on noisy versions of these vectors. We identify three key components that allow our generated images to successfully supplant the original dataset, leading to a 2-3x enhancement in sample complexity and a 6.5x decrease in sampling time. Moreover, our approach consistently outperforms generic prompt-based steering methods and KNN retrieval baseline across a wide range of datasets. Additionally, we demonstrate the compatibility of our approach with widely-used data augmentation techniques, as well as the reliability of the generated data in supporting various neural architectures and enhancing few-shot learning. 21 | 22 | ## Method 23 | Stable Diffusion, a model trained on billions of image-text pairs, boasts a wealth of generalizable knowledge. To harness this knowledge for specific classification tasks, we propose a two-stage method that guides a pre-trained generator, $G$, towards the target domain dataset. In the first stage, we map each image to the model's latent space, generating a dataset of latent embedding vectors. Then, we produce novel image variants by running the inverse diffusion process conditioned on perturbed versions of these vectors. We illustrate our approach in Figure below. 24 |

25 | ![method](asset/workflow_box.png) 26 |
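
The full implementation of both stages lives in `src/diffuser_inversion.py` (stage 1, learning the latent embeddings) and `src/sample_dataset.py` (stage 2, generating new samples from them). The snippet below is only a minimal, self-contained sketch of the idea, not the repo's actual code: it assumes a recent `diffusers` release whose `StableDiffusionPipeline` exposes `vae`, `unet`, `scheduler`, and a `prompt_embeds` argument, and it omits the grouping, checkpointing, and interpolation options that the real scripts provide.

```python
# Illustrative sketch only (hypothetical helpers, not part of this repo).
# Stage 1: invert one image into a learnable conditioning embedding with SD frozen.
# Stage 2: decode novel variants from noise-perturbed copies of that embedding.
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline

device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
pipe.vae.requires_grad_(False)
pipe.unet.requires_grad_(False)

def invert_image(image, num_tokens=5, steps=8000, lr=3e-2):
    """Optimize a (1, num_tokens, 768) embedding so the frozen SD model reconstructs `image`."""
    # image: (1, 3, H, W) tensor scaled to [-1, 1]
    latents = pipe.vae.encode(image.to(device)).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor
    emb = torch.randn(1, num_tokens, 768, device=device, requires_grad=True)
    opt = torch.optim.Adam([emb], lr=lr)
    for _ in range(steps):
        noise = torch.randn_like(latents)
        t = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (1,), device=device)
        noisy = pipe.scheduler.add_noise(latents, noise, t)
        pred = pipe.unet(noisy, t, encoder_hidden_states=emb).sample
        loss = F.mse_loss(pred, noise)  # standard denoising loss; only `emb` receives gradients
        opt.zero_grad()
        loss.backward()
        opt.step()
    return emb.detach()

@torch.no_grad()
def sample_variants(emb, emb_noise=0.1, num_samples=5):
    """Condition the frozen pipeline on perturbed embeddings to obtain novel training images."""
    images = []
    for _ in range(num_samples):
        perturbed = emb + emb_noise * torch.randn_like(emb)
        images.append(pipe(prompt_embeds=perturbed, num_inference_steps=100).images[0])
    return images
```

The default values above mirror the flags used in the commands below (`--num_tokens`, `--learning_rate`, `--max_train_steps`, `--emb_noise`, `--num_samples`, `--num_inference_steps`).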

27 | 28 | ## Reproducing 29 | ### Environment 30 | 31 | - You can set up the environment using the command below. 32 | 33 | ```bash 34 | conda env create -f environment.yml 35 | conda activate di 36 | ``` 37 | 38 | ### Training 39 | 40 | ```bash 41 | path="--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4 --output_dir=$PROJDIR/diffusion_inversion/logs/stl10 --dataset_name=stl10 --data_dir=~/tensorflow_datasets" 42 | args="--gradient_accumulation_steps=1 --num_tokens=5 --resolution=256 --train_batch_size=50 --num_emb=100 --max_train_steps=8000" 43 | lr="--lr_warmup_steps=0 --interpolation=bicubic --lr_scheduler=constant --learning_rate=3e-02" 44 | log="--checkpointing_steps=1000 --save_steps=1000 --save_image_steps=400 --resume_from_checkpoint=latest" 45 | 46 | accelerate launch src/diffuser_inversion.py $path $args $lr $log --group_id=0 47 | ... 48 | accelerate launch src/diffuser_inversion.py $path $args $lr $log --group_id=50 49 | ``` 50 | 51 | ### Sampling 52 | 53 | ```bash 54 | path="--dataset_name=stl10 --model_root_dir=$PROJDIR/diffusion_inversion/logs/stl10/res256_bicubic/emb100_token5_lr0.03_constant --dm_name=CompVis/stable-diffusion-v1-4" 55 | train_config="--emb_ch=768 --num_tokens=5 --num_classes=10 --num_emb=100 --sampling_resolution=256 --save_resolution=96 --outdir=$PROJDIR/inversion_data/stl10/scaling" 56 | sampling_config="--num_inference_steps=100 --batch_size=100 --interpolation_strength=0.1 --num_samples=5 --emb_noise=0.1 --train_steps=3000 --seed=42" 57 | 58 | python sample_dataset.py $path $train_config $sampling_config --group_id=0 59 | ... 60 | python sample_dataset.py $path $train_config $sampling_config --group_id=50 61 | ``` 62 | 63 | ### Evaluation 64 | ```bash 65 | path="--output=$PROJDIR/project/diffusion_inversion/arch" 66 | stl10="--dataset-name=stl10 --group-size=100 --num-steps=50000" 67 | pstl10="--syn-data-dir=$PROJDIR/inversion_data/stl10/scaling/res96_bicubic --syn-pattern=tstep[0-9]*_infstep100_gs[0-9]*_noise0.[0-9]*_itep0.[0-9]*_seed[0-9]*" 68 | args="--batch-size=128 --warmup-steps=1000 --num-data=5000 --num-steps=50000 --optimizer=sgd --weight-decay=5e-4 --real-bs=0 --syn-bs=128" 69 | log="--num-evals=20 --seed=42 --wandb-name=DI-stl10 --log-wandb" 70 | 71 | # Train on real data 72 | python train_net.py $path $stl10 $args $log --model=resnet18 --lr=1e-1 73 | 74 | # Train on synthetic data 75 | python train_net.py $path $stl10 $pstl10 $args $log --model=resnet18 --lr=1e-1 76 | ``` 77 | 78 | -------------------------------------------------------------------------------- /asset/bloodmnist_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/bloodmnist_vis.png -------------------------------------------------------------------------------- /asset/cifar100_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/cifar100_vis.png -------------------------------------------------------------------------------- /asset/cifar10_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/cifar10_vis.png -------------------------------------------------------------------------------- /asset/dermamnist_vis.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/dermamnist_vis.png -------------------------------------------------------------------------------- /asset/eurosat_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/eurosat_vis.png -------------------------------------------------------------------------------- /asset/imagenette_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/imagenette_vis.png -------------------------------------------------------------------------------- /asset/pathmnist_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/pathmnist_vis.png -------------------------------------------------------------------------------- /asset/stl10_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/stl10_vis.png -------------------------------------------------------------------------------- /asset/workflow_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/asset/workflow_box.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: di 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.10.6 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - fire=0.4.0 12 | - imageio=2.22.4 13 | - einops=0.6.0 14 | - pip 15 | - pip: 16 | - tensorflow~=2.10 17 | - tensorflow-datasets~=4.7.0 18 | - ml_collections~=0.1.1 19 | - seaborn~=0.12.1 20 | - kornia~=0.6.8 21 | - timm~=0.6.12 22 | - transformers~=4.25.1 23 | - safetensors~=0.2.5 24 | - accelerate~=0.15 25 | - datasets~=2.7.1 26 | - diffusers 27 | - clu 28 | - gsutil 29 | - medmnist 30 | -------------------------------------------------------------------------------- /src/classfier_tuning/clip/README.md: -------------------------------------------------------------------------------- 1 | This folder is a lightly modified version of https://github.com/openai/CLIP. 
2 | -------------------------------------------------------------------------------- /src/classfier_tuning/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/classfier_tuning/clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | if not special_tokens: 74 | special_tokens = ['', ''] 75 | else: 76 | special_tokens = ['', ''] + special_tokens 77 | vocab.extend(special_tokens) 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {t:t for t in special_tokens} 82 | special = "|".join(special_tokens) 83 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 84 | 85 | self.vocab_size = len(self.encoder) 86 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 87 | 88 | def bpe(self, token): 89 | if token in self.cache: 90 | return self.cache[token] 91 | word = tuple(token[:-1]) + ( token[-1] + '',) 92 | pairs = get_pairs(word) 93 | 94 | if not pairs: 95 | return token+'' 96 | 97 | while True: 98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 99 | if bigram not in self.bpe_ranks: 100 | break 101 | first, second = bigram 102 | new_word = [] 103 | i = 0 104 | while i < len(word): 105 | try: 106 | j = word.index(first, i) 107 | new_word.extend(word[i:j]) 108 | i = j 109 | except: 110 | new_word.extend(word[i:]) 111 | break 112 | 113 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 114 | new_word.append(first+second) 115 | i += 2 116 | else: 117 | new_word.append(word[i]) 118 | i += 1 119 | new_word = tuple(new_word) 120 | word = new_word 121 | if len(word) == 1: 122 | break 123 | else: 124 | pairs = get_pairs(word) 125 | word = ' '.join(word) 126 | self.cache[token] = word 127 | return word 128 | 129 | def encode(self, text): 130 | bpe_tokens = [] 131 | text = whitespace_clean(basic_clean(text)).lower() 132 | for token in re.findall(self.pat, text): 133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 135 | return bpe_tokens 136 | 137 | def decode(self, tokens): 138 | text = ''.join([self.decoder[token] for token in tokens]) 139 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 140 | return text -------------------------------------------------------------------------------- /src/classfier_tuning/src/ct_fsl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 4 | 5 | import wandb 6 | from src.models.finetune 
import finetune_fsl 7 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier 8 | from src.models.zeroshot import get_zeroshot_classifier 9 | from src.args import parse_arguments 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip 11 | from PIL import Image 12 | 13 | def _convert_to_rgb(image): 14 | return image.convert('RGB') 15 | 16 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | def classifier_tuning(args): 19 | assert args.save is not None, 'Please provide a path to store models' 20 | print('import success') 21 | 22 | classification_head_path = os.path.join(args.save, f'head.pt') 23 | args.data_location_real = os.path.join(args.data_location_real, f'shot_{args.shot}') 24 | args.data_location_syn = os.path.join(args.data_location_syn, f'shot_{args.shot}') 25 | args.save = os.path.join(args.save, f'shot_{args.shot}', args.cache_name_syn) 26 | args.cache_dir = os.path.join(args.cache_dir, f'shot_{args.shot}') 27 | args.results_db = os.path.join(args.results_db, f'shot_{args.shot}', args.cache_name_syn) 28 | 29 | # Build and save zero-shot model 30 | image_encoder = ImageEncoder(args, keep_lang=True) 31 | if not os.path.exists(classification_head_path): 32 | classification_head = get_zeroshot_classifier(args, image_encoder.model) 33 | classification_head.save(classification_head_path) 34 | else: 35 | classification_head = ClassificationHead.load(classification_head_path) 36 | delattr(image_encoder.model, 'transformer') 37 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False) 38 | 39 | zeroshot_checkpoint = os.path.join(args.save, 'zeroshot'+args.train_dataset+'.pt') 40 | classifier.save(zeroshot_checkpoint) 41 | 42 | # Standard fine-tuning 43 | args.load = zeroshot_checkpoint 44 | args.save = os.path.join(args.save, 'finetuned') 45 | 46 | # Mimic eurosat low-res images, val data aug 47 | train_data_aug = Compose([ 48 | # Resize(64), # resize to 32/64 for Cifar / Eurosat 49 | Resize(224, interpolation=Image.BICUBIC), 50 | CenterCrop(224), 51 | _convert_to_rgb, 52 | ToTensor(), 53 | normalize, 54 | ]) 55 | 56 | wandb.init(project=f"eurosat-fewshot", config=args) 57 | finetuned_checkpoint = finetune_fsl(args, train_data_aug) 58 | 59 | 60 | if __name__ == '__main__': 61 | args = parse_arguments() 62 | classifier_tuning(args) 63 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/ct_zsl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append( 4 | '/h/zycluke/Project/diffusion_g/src/SyntheticData/src/classifier-tuning') 5 | 6 | from src.models.finetune import finetune_zsl 7 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier 8 | from src.models.zeroshot import get_zeroshot_classifier 9 | from src.args import parse_arguments 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, RandomHorizontalFlip 11 | from PIL import Image 12 | 13 | def _convert_to_rgb(image): 14 | return image.convert('RGB') 15 | 16 | normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | def classifier_tuning(args): 19 | assert args.save is not None, 'Please provide a path to store models' 20 | print('import success') 21 | 22 | # Build and save zero-shot 
model 23 | image_encoder = ImageEncoder(args, keep_lang=True) 24 | classification_head = get_zeroshot_classifier(args, image_encoder.model) 25 | delattr(image_encoder.model, 'transformer') 26 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False) 27 | 28 | zeroshot_checkpoint = os.path.join(args.save, 'zeroshot'+args.train_dataset+'.pt') 29 | classifier.save(zeroshot_checkpoint) 30 | 31 | # Standard fine-tuning 32 | args.load = zeroshot_checkpoint 33 | args.save = os.path.join(args.save, 'finetuned') 34 | 35 | # Mimic eurosat low-res images, val data aug 36 | train_data_aug = Compose([ 37 | # Resize(64), # resize to 32/64 for Cifar / Eurosat 38 | Resize(224, interpolation=Image.BICUBIC), 39 | CenterCrop(224), 40 | _convert_to_rgb, 41 | ToTensor(), 42 | normalize, 43 | ]) 44 | 45 | finetuned_checkpoint = finetune_zsl(args, train_data_aug) 46 | 47 | 48 | 49 | if __name__ == '__main__': 50 | args = parse_arguments() 51 | classifier_tuning(args) 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import * 2 | from .cifar100 import * 3 | # from .fmow import FMOWID, FMOWOOD, FMOW 4 | from .imagenet import ImageNet as ImageNet_ft 5 | # from .imagenetv2 import ImageNetV2 6 | # from .imagenet_a import ImageNetAValClasses, ImageNetA 7 | # from .imagenet_r import ImageNetRValClasses, ImageNetR 8 | # from .imagenet_sketch import ImageNetSketch 9 | # from .imagenet_vid_robust import ImageNetVidRobustValClasses, ImageNetVidRobust 10 | # from .iwildcam import IWildCamID, IWildCamOOD, IWildCamIDNonEmpty, IWildCamOODNonEmpty, IWildCam 11 | # from .objectnet import ObjectNetValClasses, ObjectNet 12 | # from .ytbb_robust import YTBBRobustValClasses, YTBBRobust 13 | # from .transfer_ds.transfer_datasets import * 14 | from .eurosat import EuroSAT 15 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import torch 4 | import numpy as np 5 | import torchvision 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 8 | from torchvision.datasets import VisionDataset 9 | 10 | cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 11 | 12 | class CIFAR10_theirs: 13 | def __init__(self, preprocess, 14 | location=os.path.expanduser('~/data'), 15 | batch_size=128, 16 | num_workers=16, 17 | classnames=None): 18 | 19 | 20 | self.train_dataset = PyTorchCIFAR10( 21 | root=location, download=True, train=True, transform=preprocess 22 | ) 23 | 24 | self.train_loader = torch.utils.data.DataLoader( 25 | self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers 26 | ) 27 | 28 | self.test_dataset = PyTorchCIFAR10( 29 | root=location, download=True, train=False, transform=preprocess 30 | ) 31 | 32 | self.test_loader = torch.utils.data.DataLoader( 33 | self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers 34 | ) 35 | 36 | self.classnames = self.test_dataset.classes 37 | 38 | def convert(x): 39 | if isinstance(x, np.ndarray): 40 | return torchvision.transforms.functional.to_pil_image(x) 41 | return x 42 | 43 | class BasicVisionDataset(VisionDataset): 44 | def 
__init__(self, images, targets, transform=None, target_transform=None): 45 | if transform is not None: 46 | transform.transforms.insert(0, convert) 47 | super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) 48 | assert len(images) == len(targets) 49 | 50 | self.images = images 51 | self.targets = targets 52 | 53 | def __getitem__(self, index): 54 | return self.transform(self.images[index]), self.targets[index] 55 | 56 | def __len__(self): 57 | return len(self.targets) 58 | 59 | class CIFAR101: 60 | def __init__(self, 61 | preprocess, 62 | location=os.path.expanduser('~/data'), 63 | batch_size=128, 64 | num_workers=16, 65 | classnames=None): 66 | 67 | data_root = os.path.join(location, "CIFAR-10.1") 68 | data = np.load(os.path.join(data_root, 'cifar10.1_v6_data.npy'), allow_pickle=True) 69 | labels = np.load(os.path.join(data_root, 'cifar10.1_v6_labels.npy'), allow_pickle=True) 70 | 71 | use_cuda = torch.cuda.is_available() 72 | 73 | # Data loading code 74 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} 75 | 76 | self.train_loader = None 77 | 78 | self.test_dataset = BasicVisionDataset( 79 | images=data, targets=torch.Tensor(labels).long(), 80 | transform=preprocess, 81 | ) 82 | 83 | self.test_loader = torch.utils.data.DataLoader( 84 | self.test_dataset, batch_size=batch_size, shuffle=False, **kwargs 85 | ) 86 | 87 | self.classnames = cifar_classnames 88 | 89 | 90 | class CIFAR102: 91 | def __init__(self, 92 | preprocess, 93 | location=os.path.expanduser('~/data'), 94 | batch_size=128, 95 | num_workers=16, 96 | classnames=None): 97 | 98 | train_data = np.load(os.path.join(location, "CIFAR-10.2", 'cifar102_train.npy'), allow_pickle=True).item() 99 | test_data = np.load(os.path.join(location, "CIFAR-10.2", 'cifar102_test.npy'), allow_pickle=True).item() 100 | 101 | 102 | train_data_images = train_data['images'] 103 | train_data_labels = train_data['labels'] 104 | 105 | test_data_images = test_data['images'] 106 | test_data_labels = test_data['labels'] 107 | 108 | use_cuda = torch.cuda.is_available() 109 | 110 | # Data loading code 111 | kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} 112 | 113 | self.test_dataset = BasicVisionDataset( 114 | images=test_data_images, targets=torch.Tensor(test_data_labels).long(), 115 | transform=preprocess, 116 | ) 117 | 118 | self.test_loader = torch.utils.data.DataLoader( 119 | self.test_dataset, batch_size=batch_size, shuffle=False, **kwargs 120 | ) 121 | 122 | self.classnames = cifar_classnames 123 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 4 | 5 | class CIFAR100_theirs: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=16, 11 | classnames=None): 12 | 13 | self.train_dataset = PyTorchCIFAR100( 14 | root=location, download=True, train=True, transform=preprocess 15 | ) 16 | 17 | self.train_loader = torch.utils.data.DataLoader( 18 | self.train_dataset, batch_size=batch_size, num_workers=num_workers 19 | ) 20 | 21 | self.test_dataset = PyTorchCIFAR100( 22 | root=location, download=True, train=False, transform=preprocess 23 | ) 24 | 25 | self.test_loader = torch.utils.data.DataLoader( 26 | self.test_dataset, 
batch_size=batch_size, shuffle=False, num_workers=num_workers 27 | ) 28 | 29 | self.classnames = self.test_dataset.classes 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | class ImageFolderWithPaths(datasets.ImageFolder): 27 | def __init__(self, path, transform, flip_label_prob=0.0): 28 | super().__init__(path, transform) 29 | self.flip_label_prob = flip_label_prob 30 | if self.flip_label_prob > 0: 31 | print(f'Flipping labels with probability {self.flip_label_prob}') 32 | num_classes = len(self.classes) 33 | for i in range(len(self.samples)): 34 | if random.random() < self.flip_label_prob: 35 | new_label = random.randint(0, num_classes-1) 36 | self.samples[i] = ( 37 | self.samples[i][0], 38 | new_label 39 | ) 40 | 41 | def __getitem__(self, index): 42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 43 | return { 44 | 'images': image, 45 | 'labels': label, 46 | 'image_paths': self.samples[index][0] 47 | } 48 | 49 | 50 | def maybe_dictionarize(batch): 51 | if isinstance(batch, dict): 52 | return batch 53 | 54 | if len(batch) == 2: 55 | batch = {'images': batch[0], 'labels': batch[1]} 56 | elif len(batch) == 3: 57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]} 58 | else: 59 | raise ValueError(f'Unexpected number of elements: {len(batch)}') 60 | 61 | return batch 62 | 63 | 64 | def get_features_helper(image_encoder, dataloader, device): 65 | all_data = collections.defaultdict(list) 66 | 67 | image_encoder = image_encoder.to(device) 68 | # image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())]) 69 | image_encoder.eval() 70 | 71 | with torch.no_grad(): 72 | for batch in tqdm(dataloader): 73 | batch = maybe_dictionarize(batch) 74 | features = image_encoder(batch['images'].cuda()) 75 | 76 | all_data['features'].append(features.cpu()) 77 | 78 | for key, val in batch.items(): 79 | if key == 'images': 80 | continue 81 | if hasattr(val, 'cpu'): 82 | val = val.cpu() 83 | all_data[key].append(val) 84 | else: 85 | all_data[key].extend(val) 86 | del batch, features 87 | 88 | for key, val in all_data.items(): 89 | if torch.is_tensor(val[0]): 90 | all_data[key] = torch.cat(val).numpy() 91 | 92 | return all_data 93 | 94 | 95 | def get_features(is_train, image_encoder, dataset, device, is_real=True, args=None): 96 | split = 'train' if is_train else 'val' 97 | if is_real is False: 98 | split= args.cache_name_syn 99 | dname = type(dataset).__name__ 100 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' 101 | if image_encoder.cache_dir is not None: 102 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' 103 | cached_files = glob.glob(f'{cache_dir}/*') 104 | if image_encoder.cache_dir is not None and len(cached_files) > 0: 105 | print(f'Getting features from {cache_dir}') 106 | data = {} 107 | for cached_file in 
cached_files: 108 | name = os.path.splitext(os.path.basename(cached_file))[0] 109 | data[name] = torch.load(cached_file) 110 | else: 111 | # import ipdb 112 | # ipdb.set_trace(context=20) 113 | print(f'Did not find cached features at {cache_dir}. Building from scratch.') 114 | loader = dataset.train_loader if is_train else dataset.test_loader 115 | data = get_features_helper(image_encoder, loader, device) 116 | if image_encoder.cache_dir is None: 117 | print('Not caching because no cache directory was passed.') 118 | else: 119 | os.makedirs(cache_dir, exist_ok=True) 120 | print(f'Caching data at {cache_dir}') 121 | for name, val in data.items(): 122 | torch.save(val, f'{cache_dir}/{name}.pt', pickle_protocol=4) 123 | return data 124 | 125 | 126 | class FeatureDataset(Dataset): 127 | def __init__(self, is_train, image_encoder, dataset, device, is_real=True, args=None): 128 | self.data = get_features(is_train, image_encoder, dataset, device, is_real, args=args) 129 | 130 | def __len__(self): 131 | return len(self.data['features']) 132 | 133 | def __getitem__(self, idx): 134 | data = {k: v[idx] for k, v in self.data.items()} 135 | data['features'] = torch.from_numpy(data['features']).float() 136 | return data 137 | 138 | 139 | def get_dataloader(dataset, is_train, args, image_encoder=None, is_real=True): 140 | if image_encoder is not None: 141 | feature_dataset = FeatureDataset( 142 | is_train, image_encoder, dataset, args.device, is_real, args=args) 143 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) 144 | else: 145 | dataloader = dataset.train_loader if is_train else dataset.test_loader 146 | return dataloader -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .common import ImageFolderWithPaths, SubsetSampler 5 | import numpy as np 6 | 7 | class EuroSAT: 8 | def __init__(self, 9 | preprocess, 10 | location=os.path.expanduser('~/data'), 11 | batch_size=32, 12 | num_workers=32, 13 | classnames='eurosat'): 14 | self.preprocess = preprocess 15 | self.location = location 16 | self.batch_size = batch_size 17 | self.num_workers = num_workers 18 | self.classnames = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 19 | 'Highway or Road', 'Industrial Building', 'Pasture Land', 20 | 'Permanent Crop Land', 'Residential Building', 'River', 'Sea or Lake'] 21 | 22 | self.populate_train() 23 | try: 24 | self.populate_test() 25 | except: 26 | self.test_dataset = None 27 | self.test_loader = None 28 | 29 | def populate_train(self): 30 | traindir = os.path.join(self.location, 'train') 31 | self.train_dataset = ImageFolderWithPaths( 32 | traindir, 33 | transform=self.preprocess) 34 | sampler = self.get_train_sampler() 35 | kwargs = {'shuffle' : True} if sampler is None else {} 36 | self.train_loader = torch.utils.data.DataLoader( 37 | self.train_dataset, 38 | sampler=sampler, 39 | batch_size=self.batch_size, 40 | num_workers=self.num_workers, 41 | **kwargs, 42 | ) 43 | 44 | def populate_test(self): 45 | self.test_dataset = self.get_test_dataset() 46 | self.test_loader = torch.utils.data.DataLoader( 47 | self.test_dataset, 48 | batch_size=self.batch_size, 49 | num_workers=self.num_workers, 50 | sampler=self.get_test_sampler() 51 | ) 52 | 53 | def get_test_path(self): 54 | test_path = os.path.join(self.location, 'val_in_folder') 55 | if not 
os.path.exists(test_path): 56 | test_path = os.path.join(self.location, 'val') 57 | return test_path 58 | 59 | def get_train_sampler(self): 60 | return None 61 | 62 | def get_test_sampler(self): 63 | return None 64 | 65 | def get_test_dataset(self): 66 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 67 | 68 | def name(self): 69 | return 'eurosat' 70 | 71 | class EuroSATrain(EuroSAT): 72 | 73 | def get_test_dataset(self): 74 | pass 75 | 76 | class EuroSATK(EuroSAT): 77 | 78 | def get_train_sampler(self): 79 | idxs = np.zeros(len(self.train_dataset.targets)) 80 | target_array = np.array(self.train_dataset.targets) 81 | for c in range(1000): 82 | m = target_array == c 83 | n = len(idxs[m]) 84 | arr = np.zeros(n) 85 | arr[:self.k()] = 1 86 | np.random.shuffle(arr) 87 | idxs[m] = arr 88 | 89 | idxs = idxs.astype('int') 90 | sampler = SubsetSampler(np.where(idxs)[0]) 91 | return sampler 92 | 93 | 94 | def project_logits(logits, class_sublist_mask, device): 95 | if isinstance(logits, list): 96 | return [project_logits(l, class_sublist_mask, device) for l in logits] 97 | if logits.size(1) > sum(class_sublist_mask): 98 | return logits[:, class_sublist_mask].to(device) 99 | else: 100 | return logits.to(device) 101 | 102 | class EuroSATSubsample(EuroSAT): 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask() 106 | self.classnames = [self.classnames[i] for i in class_sublist] 107 | 108 | def get_class_sublist_and_mask(self): 109 | raise NotImplementedError() 110 | 111 | def populate_train(self): 112 | pass 113 | 114 | def project_logits(self, logits, device): 115 | return project_logits(logits, self.class_sublist_mask, device) 116 | 117 | class EuroSATSubsampleValClasses(EuroSAT): 118 | def get_class_sublist_and_mask(self): 119 | raise NotImplementedError() 120 | 121 | def populate_train(self): 122 | pass 123 | 124 | def get_test_sampler(self): 125 | self.class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask() 126 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.class_sublist] 127 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist]) 128 | 129 | sampler = SubsetSampler(idx_subsample_list) 130 | return sampler 131 | 132 | def project_labels(self, labels, device): 133 | projected_labels = [self.class_sublist.index(int(label)) for label in labels] 134 | return torch.LongTensor(projected_labels).to(device) 135 | 136 | def project_logits(self, logits, device): 137 | return project_logits(logits, self.class_sublist_mask, device) 138 | 139 | # ks = [1, 2, 4, 8, 16, 25, 32, 50, 64, 128, 600] 140 | 141 | # for k in ks: 142 | # cls_name = f"ImageNet{k}" 143 | # dyn_cls = type(cls_name, (ImageNetK, ), { 144 | # "k": lambda self, num_samples=k: num_samples, 145 | # }) 146 | # globals()[cls_name] = dyn_cls -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/fmow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import wilds 4 | 5 | from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 6 | from wilds.common.data_loaders import get_train_loader, get_eval_loader 7 | 8 | class FMOW: 9 | test_subset = None 10 | 11 | def __init__(self, 12 | preprocess, 13 | location=os.path.expanduser('~/data'), 14 | batch_size=128, 15 | num_workers=16, 16 | subset='test', 17 | 
classnames=None, 18 | **kwargs): 19 | 20 | self.dataset = wilds.get_dataset(dataset='fmow', root_dir=location) 21 | 22 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess) 23 | self.train_loader = get_train_loader("standard", self.train_dataset, num_workers=num_workers, batch_size=batch_size) 24 | 25 | self.test_dataset = self.dataset.get_subset(self.test_subset, transform=preprocess) 26 | self.test_loader = get_eval_loader("standard", self.test_dataset, num_workers=num_workers, batch_size=batch_size) 27 | 28 | self.classnames = [ 29 | "airport", "airport_hangar", "airport_terminal", "amusement_park", "aquaculture", 30 | "archaeological_site", "barn", "border_checkpoint", "burial_site", "car_dealership", 31 | "construction_site", "crop_field", "dam", "debris_or_rubble", "educational_institution", 32 | "electric_substation", "factory_or_powerplant", "fire_station", "flooded_road", "fountain", 33 | "gas_station", "golf_course", "ground_transportation_station", "helipad", "hospital", 34 | "impoverished_settlement", "interchange", "lake_or_pond", "lighthouse", "military_facility", 35 | "multi-unit_residential", "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park", 36 | "parking_lot_or_garage", "place_of_worship", "police_station", "port", "prison", "race_track", 37 | "railway_bridge", "recreational_facility", "road_bridge", "runway", "shipyard", "shopping_mall", 38 | "single-unit_residential", "smokestack", "solar_farm", "space_facility", "stadium", "storage_tank", 39 | "surface_mine", "swimming_pool", "toll_booth", "tower", "tunnel_opening", "waste_disposal", 40 | "water_treatment_facility", "wind_farm", "zoo" 41 | ] 42 | 43 | def post_loop_metrics(self, labels, preds, metadata, args): 44 | metadata = torch.stack(metadata) 45 | preds = preds.argmax(dim=1, keepdim=True).view_as(labels) 46 | results = self.dataset.eval(preds, labels, metadata) 47 | return results[0] 48 | 49 | class FMOWID(FMOW): 50 | def __init__(self, *args, **kwargs): 51 | self.test_subset = 'id_test' 52 | super().__init__(*args, **kwargs) 53 | 54 | class FMOWOOD(FMOW): 55 | def __init__(self, *args, **kwargs): 56 | self.test_subset = 'test' 57 | super().__init__(*args, **kwargs) -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .common import ImageFolderWithPaths, SubsetSampler 5 | from .imagenet_classnames import get_classnames 6 | import numpy as np 7 | 8 | class ImageNet: 9 | def __init__(self, 10 | preprocess, 11 | location=os.path.expanduser('~/data'), 12 | batch_size=32, 13 | num_workers=32, 14 | classnames='openai'): 15 | self.preprocess = preprocess 16 | self.location = location 17 | self.batch_size = batch_size 18 | self.num_workers = num_workers 19 | self.classnames = get_classnames(classnames) 20 | 21 | self.populate_train() 22 | try: 23 | self.populate_test() 24 | except: 25 | self.test_dataset = None 26 | self.test_loader = None 27 | 28 | def populate_train(self): 29 | traindir = os.path.join(self.location, 'train') 30 | self.train_dataset = ImageFolderWithPaths( 31 | traindir, 32 | transform=self.preprocess) 33 | sampler = self.get_train_sampler() 34 | kwargs = {'shuffle' : True} if sampler is None else {} 35 | self.train_loader = torch.utils.data.DataLoader( 36 | self.train_dataset, 37 | sampler=sampler, 38 | batch_size=self.batch_size, 39 | 
num_workers=self.num_workers, 40 | **kwargs, 41 | ) 42 | 43 | def populate_test(self): 44 | self.test_dataset = self.get_test_dataset() 45 | self.test_loader = torch.utils.data.DataLoader( 46 | self.test_dataset, 47 | batch_size=self.batch_size, 48 | num_workers=self.num_workers, 49 | sampler=self.get_test_sampler() 50 | ) 51 | 52 | def get_test_path(self): 53 | test_path = os.path.join(self.location, 'val_in_folder') 54 | if not os.path.exists(test_path): 55 | test_path = os.path.join(self.location, 'val') 56 | return test_path 57 | 58 | def get_train_sampler(self): 59 | return None 60 | 61 | def get_test_sampler(self): 62 | return None 63 | 64 | def get_test_dataset(self): 65 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 66 | 67 | def name(self): 68 | return 'imagenet' 69 | 70 | class ImageNetTrain(ImageNet): 71 | 72 | def get_test_dataset(self): 73 | pass 74 | 75 | class ImageNetK(ImageNet): 76 | 77 | def get_train_sampler(self): 78 | idxs = np.zeros(len(self.train_dataset.targets)) 79 | target_array = np.array(self.train_dataset.targets) 80 | for c in range(1000): 81 | m = target_array == c 82 | n = len(idxs[m]) 83 | arr = np.zeros(n) 84 | arr[:self.k()] = 1 85 | np.random.shuffle(arr) 86 | idxs[m] = arr 87 | 88 | idxs = idxs.astype('int') 89 | sampler = SubsetSampler(np.where(idxs)[0]) 90 | return sampler 91 | 92 | 93 | def project_logits(logits, class_sublist_mask, device): 94 | if isinstance(logits, list): 95 | return [project_logits(l, class_sublist_mask, device) for l in logits] 96 | if logits.size(1) > sum(class_sublist_mask): 97 | return logits[:, class_sublist_mask].to(device) 98 | else: 99 | return logits.to(device) 100 | 101 | class ImageNetSubsample(ImageNet): 102 | def __init__(self, *args, **kwargs): 103 | super().__init__(*args, **kwargs) 104 | class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask() 105 | self.classnames = [self.classnames[i] for i in class_sublist] 106 | 107 | def get_class_sublist_and_mask(self): 108 | raise NotImplementedError() 109 | 110 | def populate_train(self): 111 | pass 112 | 113 | def project_logits(self, logits, device): 114 | return project_logits(logits, self.class_sublist_mask, device) 115 | 116 | class ImageNetSubsampleValClasses(ImageNet): 117 | def get_class_sublist_and_mask(self): 118 | raise NotImplementedError() 119 | 120 | def populate_train(self): 121 | pass 122 | 123 | def get_test_sampler(self): 124 | self.class_sublist, self.class_sublist_mask = self.get_class_sublist_and_mask() 125 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self.class_sublist] 126 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist]) 127 | 128 | sampler = SubsetSampler(idx_subsample_list) 129 | return sampler 130 | 131 | def project_labels(self, labels, device): 132 | projected_labels = [self.class_sublist.index(int(label)) for label in labels] 133 | return torch.LongTensor(projected_labels).to(device) 134 | 135 | def project_logits(self, logits, device): 136 | return project_logits(logits, self.class_sublist_mask, device) 137 | 138 | ks = [1, 2, 4, 8, 16, 25, 32, 50, 64, 128, 600] 139 | 140 | for k in ks: 141 | cls_name = f"ImageNet{k}" 142 | dyn_cls = type(cls_name, (ImageNetK, ), { 143 | "k": lambda self, num_samples=k: num_samples, 144 | }) 145 | globals()[cls_name] = dyn_cls -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/imagenet_a.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses 6 | import numpy as np 7 | 8 | 9 | CLASS_SUBLIST = [ 10 | 6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107, 11 | 108, 110, 12 | 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307, 13 | 308, 309, 14 | 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 15 | 400, 401, 16 | 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488, 17 | 492, 496, 18 | 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 19 | 626, 627, 20 | 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 21 | 774, 776, 22 | 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 23 | 879, 880, 24 | 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 25 | 981, 984, 26 | 986, 987, 988] 27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)] 28 | 29 | 30 | class ImageNetAValClasses(ImageNetSubsampleValClasses): 31 | def get_class_sublist_and_mask(self): 32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 33 | 34 | 35 | class ImageNetA(ImageNetSubsample): 36 | def get_class_sublist_and_mask(self): 37 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 38 | 39 | def get_test_path(self): 40 | return os.path.join(self.location, 'imagenet-a') 41 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | from .imagenet import ImageNetSubsample, ImageNetSubsampleValClasses 6 | import numpy as np 7 | 8 | 9 | CLASS_SUBLIST = [ 10 | 1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 11 | 113, 122, 12 | 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 13 | 207, 208, 219, 14 | 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 15 | 291, 292, 293, 16 | 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 17 | 353, 355, 361, 18 | 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 19 | 448, 457, 462, 20 | 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 21 | 617, 621, 629, 22 | 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 23 | 866, 875, 883, 24 | 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 25 | 967, 980, 981, 26 | 983, 988] 27 | CLASS_SUBLIST_MASK = [(i in CLASS_SUBLIST) for i in range(1000)] 28 | 29 | 30 | class ImageNetRValClasses(ImageNetSubsampleValClasses): 31 | def get_class_sublist_and_mask(self): 32 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 33 | 34 | class 
ImageNetR(ImageNetSubsample): 35 | def get_class_sublist_and_mask(self): 36 | return CLASS_SUBLIST, CLASS_SUBLIST_MASK 37 | 38 | def get_test_path(self): 39 | return os.path.join(self.location, 'imagenet-r') -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .imagenet import ImageNet 3 | 4 | 5 | class ImageNetSketch(ImageNet): 6 | 7 | def populate_train(self): 8 | pass 9 | 10 | def get_test_path(self): 11 | return os.path.join(self.location, 'sketch') 12 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from imagenetv2_pytorch import ImageNetV2Dataset 4 | 5 | from .imagenet import ImageNet 6 | 7 | class ImageNetV2DatasetWithPaths(ImageNetV2Dataset): 8 | def __getitem__(self, i): 9 | img, label = Image.open(self.fnames[i]), int(self.fnames[i].parent.name) 10 | if self.transform is not None: 11 | img = self.transform(img) 12 | return { 13 | 'images': img, 14 | 'labels': label, 15 | 'image_paths': str(self.fnames[i]) 16 | } 17 | 18 | class ImageNetV2(ImageNet): 19 | def get_test_dataset(self): 20 | return ImageNetV2DatasetWithPaths(transform=self.preprocess, location=self.location) 21 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/iwildcam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import json 4 | import numpy as np 5 | import pathlib 6 | 7 | import wilds 8 | from wilds.common.data_loaders import get_train_loader, get_eval_loader 9 | from wilds.datasets.wilds_dataset import WILDSSubset 10 | 11 | 12 | def get_mask_non_empty(dataset): 13 | metadf = pd.read_csv(dataset._data_dir / 'metadata.csv') 14 | filename = os.path.expanduser(dataset._data_dir / 'iwildcam2020_megadetector_results.json') 15 | with open(filename, 'r') as f: 16 | md_data = json.load(f) 17 | id_to_maxdet = {x['id']: x['max_detection_conf'] for x in md_data['images']} 18 | threshold = 0.95 19 | mask_non_empty = [id_to_maxdet[x] >= threshold for x in metadf['image_id']] 20 | return mask_non_empty 21 | 22 | 23 | def get_nonempty_subset(dataset, split, frac=1.0, transform=None): 24 | if split not in dataset.split_dict: 25 | raise ValueError(f"Split {split} not found in dataset's split_dict.") 26 | split_mask = dataset.split_array == dataset.split_dict[split] 27 | 28 | # intersect split mask with non_empty. 
here is the only place this fn differs 29 | # from https://github.com/p-lambda/wilds/blob/main/wilds/datasets/wilds_dataset.py#L56 30 | mask_non_empty = get_mask_non_empty(dataset) 31 | split_mask = split_mask & mask_non_empty 32 | 33 | split_idx = np.where(split_mask)[0] 34 | if frac < 1.0: 35 | num_to_retain = int(np.round(float(len(split_idx)) * frac)) 36 | split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain]) 37 | subset = WILDSSubset(dataset, split_idx, transform) 38 | return subset 39 | 40 | 41 | class IWildCam: 42 | def __init__(self, 43 | preprocess, 44 | location=os.path.expanduser('~/data'), 45 | remove_non_empty=False, 46 | batch_size=128, 47 | num_workers=16, 48 | classnames=None, 49 | subset='train'): 50 | self.dataset = wilds.get_dataset(dataset='iwildcam', root_dir=location) 51 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess) 52 | self.train_loader = get_train_loader("standard", self.train_dataset, num_workers=num_workers, batch_size=batch_size) 53 | 54 | if remove_non_empty: 55 | self.train_dataset = get_nonempty_subset(self.dataset, 'train', transform=preprocess) 56 | else: 57 | self.train_dataset = self.dataset.get_subset('train', transform=preprocess) 58 | 59 | if remove_non_empty: 60 | self.test_dataset = get_nonempty_subset(self.dataset, subset, transform=preprocess) 61 | else: 62 | self.test_dataset = self.dataset.get_subset(subset, transform=preprocess) 63 | 64 | self.test_loader = get_eval_loader( 65 | "standard", self.test_dataset, 66 | num_workers=num_workers, 67 | batch_size=batch_size) 68 | 69 | labels_csv = pathlib.Path(__file__).parent / 'iwildcam_metadata' / 'labels.csv' 70 | df = pd.read_csv(labels_csv) 71 | df = df[df['y'] < 99999] 72 | 73 | self.classnames = [s.lower() for s in list(df['english'])] 74 | 75 | def post_loop_metrics(self, labels, preds, metadata, args): 76 | preds = preds.argmax(dim=1, keepdim=True).view_as(labels) 77 | results = self.dataset.eval(preds, labels, metadata) 78 | return results[0] 79 | 80 | 81 | class IWildCamID(IWildCam): 82 | def __init__(self, *args, **kwargs): 83 | kwargs['subset'] = 'id_test' 84 | super().__init__(*args, **kwargs) 85 | 86 | 87 | class IWildCamOOD(IWildCam): 88 | def __init__(self, *args, **kwargs): 89 | kwargs['subset'] = 'test' 90 | super().__init__(*args, **kwargs) 91 | 92 | 93 | class IWildCamNonEmpty(IWildCam): 94 | def __init__(self, *args, **kwargs): 95 | kwargs['subset'] = 'train' 96 | super().__init__(*args, **kwargs) 97 | 98 | 99 | class IWildCamIDNonEmpty(IWildCam): 100 | def __init__(self, *args, **kwargs): 101 | kwargs['subset'] = 'id_test' 102 | super().__init__(*args, **kwargs) 103 | 104 | 105 | class IWildCamOODNonEmpty(IWildCam): 106 | def __init__(self, *args, **kwargs): 107 | kwargs['subset'] = 'test' 108 | super().__init__(*args, **kwargs) 109 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/objectnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | import PIL 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torchvision import datasets 10 | from torchvision.transforms import Compose 11 | 12 | from .common import ImageFolderWithPaths, SubsetSampler 13 | from .imagenet import ImageNet, ImageNetSubsampleValClasses 14 | 15 | 16 | def get_metadata(): 17 | metadata = Path(__file__).parent / 'objectnet_metadata' 18 | 19 | with open(metadata / 'folder_to_objectnet_label.json', 
'r') as f: 20 | folder_map = json.load(f) 21 | folder_map = {v: k for k, v in folder_map.items()} 22 | with open(metadata / 'objectnet_to_imagenet_1k.json', 'r') as f: 23 | objectnet_map = json.load(f) 24 | 25 | with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r') as f: 26 | pytorch_map = json.load(f) 27 | pytorch_map = {v: k for k, v in pytorch_map.items()} 28 | 29 | with open(metadata / 'imagenet_to_label_2012_v2', 'r') as f: 30 | imagenet_map = {v.strip(): str(pytorch_map[i]) for i, v in enumerate(f)} 31 | 32 | folder_to_ids, class_sublist = {}, [] 33 | classnames = [] 34 | for objectnet_name, imagenet_names in objectnet_map.items(): 35 | imagenet_names = imagenet_names.split('; ') 36 | imagenet_ids = [int(imagenet_map[imagenet_name]) for imagenet_name in imagenet_names] 37 | class_sublist.extend(imagenet_ids) 38 | folder_to_ids[folder_map[objectnet_name]] = imagenet_ids 39 | 40 | class_sublist = sorted(class_sublist) 41 | class_sublist_mask = [(i in class_sublist) for i in range(1000)] 42 | classname_map = {v: k for k, v in folder_map.items()} 43 | return class_sublist, class_sublist_mask, folder_to_ids, classname_map 44 | 45 | 46 | def crop(img): 47 | width, height = img.size 48 | cropArea = (2, 2, width - 2, height - 2) 49 | img = img.crop(cropArea) 50 | return img 51 | 52 | 53 | class ObjectNetDataset(datasets.ImageFolder): 54 | 55 | def __init__(self, label_map, path, transform): 56 | self.label_map = label_map 57 | super().__init__(path, transform=transform) 58 | self.samples = [ 59 | d for d in self.samples 60 | if os.path.basename(os.path.dirname(d[0])) in self.label_map 61 | ] 62 | self.imgs = self.samples 63 | 64 | def __len__(self): 65 | return len(self.samples) 66 | 67 | def __getitem__(self, index): 68 | path, target = self.samples[index] 69 | sample = self.loader(path) 70 | if self.transform is not None: 71 | sample = self.transform(sample) 72 | label = os.path.basename(os.path.dirname(path)) 73 | return { 74 | 'images': sample, 75 | 'labels': self.label_map[label], 76 | 'image_paths': path 77 | } 78 | 79 | 80 | class ObjectNetBase(ImageNet): 81 | def __init__(self, *args, **kwargs): 82 | (self._class_sublist, 83 | self.class_sublist_mask, 84 | self.folders_to_ids, 85 | self.classname_map) = get_metadata() 86 | 87 | super().__init__(*args, **kwargs) 88 | 89 | self.classnames = sorted(list(self.folders_to_ids.keys())) 90 | self.rev_class_idx_map = {} 91 | self.class_idx_map = {} 92 | for idx, name in enumerate(self.classnames): 93 | self.rev_class_idx_map[idx] = self.folders_to_ids[name] 94 | for imagenet_idx in self.rev_class_idx_map[idx]: 95 | self.class_idx_map[imagenet_idx] = idx 96 | 97 | self.crop = crop 98 | self.preprocess = Compose([crop, self.preprocess]) 99 | self.classnames = [self.classname_map[c].lower() for c in self.classnames] 100 | 101 | def populate_train(self): 102 | pass 103 | 104 | def get_test_dataset(self): 105 | subdir = 'objectnet-1.0/images' 106 | valdir = os.path.join(self.location, subdir) 107 | label_map = {name: idx for idx, name in enumerate(sorted(list(self.folders_to_ids.keys())))} 108 | return ObjectNetDataset(label_map, valdir, transform=self.preprocess) 109 | 110 | def project_logits(self, logits, device): 111 | if isinstance(logits, list) or isinstance(logits, tuple): 112 | return [self.project_logits(l, device) for l in logits] 113 | if logits.shape[1] == 113: 114 | return logits 115 | if torch.is_tensor(logits): 116 | logits = logits.cpu().numpy() 117 | logits_projected = np.zeros((logits.shape[0], 113)) 118 | for k, v in 
self.rev_class_idx_map.items(): 119 | logits_projected[:, k] = np.max(logits[:, v], axis=1).squeeze() 120 | return torch.tensor(logits_projected).to(device) 121 | 122 | def scatter_weights(self, weights): 123 | if weights.size(1) == 1000: 124 | return weights 125 | new_weights = torch.ones((weights.size(0), 1000)).to(weights.device) * -10e8 126 | for k, v in self.rev_class_idx_map.items(): 127 | for vv in v: 128 | new_weights[:, vv] = weights[:, k] 129 | return new_weights 130 | 131 | 132 | 133 | def accuracy(logits, targets, img_paths, args): 134 | assert logits.shape[1] == 113 135 | preds = logits.argmax(dim=1) 136 | if torch.is_tensor(preds): 137 | preds = preds.cpu().numpy() 138 | if torch.is_tensor(targets): 139 | targets = targets.cpu().numpy() 140 | return np.sum(preds == targets), len(preds) 141 | 142 | 143 | class ObjectNetValClasses(ObjectNetBase): 144 | 145 | def get_test_sampler(self): 146 | idx_subsample_list = [range(x * 50, (x + 1) * 50) for x in self._class_sublist] 147 | idx_subsample_list = sorted([item for sublist in idx_subsample_list for item in sublist]) 148 | 149 | sampler = SubsetSampler(idx_subsample_list) 150 | return sampler 151 | 152 | def get_test_dataset(self): 153 | return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) 154 | 155 | def project_labels(self, labels, device): 156 | projected_labels = [self.class_idx_map[int(label)] for label in labels] 157 | return torch.LongTensor(projected_labels).to(device) 158 | 159 | 160 | class ObjectNet(ObjectNetBase): 161 | 162 | def accuracy(self, logits, targets, img_paths, args): 163 | return accuracy(logits, targets, img_paths, args) 164 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/objectnet_metadata/objectnet_to_imagenet_1k.json: -------------------------------------------------------------------------------- 1 | { 2 | "Alarm clock": "analog clock; digital clock", 3 | "Backpack": "backpack, back pack, knapsack, packsack, rucksack, haversack", 4 | "Banana": "banana", 5 | "Band Aid": "Band Aid", 6 | "Basket": "shopping basket", 7 | "Bath towel": "bath towel", 8 | "Beer bottle": "beer bottle", 9 | "Bench": "park bench", 10 | "Bicycle": "mountain bike, all-terrain bike, off-roader; bicycle-built-for-two, tandem bicycle, tandem", 11 | "Binder (closed)": "binder, ring-binder", 12 | "Bottle cap": "bottlecap", 13 | "Bread loaf": "French loaf", 14 | "Broom": "broom", 15 | "Bucket": "bucket, pail", 16 | "Butcher's knife": "cleaver, meat cleaver, chopper", 17 | "Can opener": "can opener, tin opener", 18 | "Candle": "candle, taper, wax light", 19 | "Cellphone": "cellular telephone, cellular phone, cellphone, cell, mobile phone", 20 | "Chair": "barber chair; folding chair; rocking chair, rocker", 21 | "Clothes hamper": "hamper", 22 | "Coffee/French press": "espresso maker", 23 | "Combination lock": "combination lock", 24 | "Computer mouse": "mouse, computer mouse", 25 | "Desk lamp": "table lamp", 26 | "Dishrag or hand towel": "dishrag, dishcloth", 27 | "Doormat": "doormat, welcome mat", 28 | "Dress shoe (men)": "Loafer", 29 | "Drill": "power drill", 30 | "Drinking Cup": "cup", 31 | "Drying rack for plates": "plate rack", 32 | "Envelope": "envelope", 33 | "Fan": "electric fan, blower", 34 | "Frying pan": "frying pan, frypan, skillet", 35 | "Dress": "gown", 36 | "Hair dryer": "hand blower, blow dryer, blow drier, hair dryer, hair drier", 37 | "Hammer": "hammer", 38 | "Helmet": "football helmet; crash helmet", 39 | "Iron (for clothes)": 
"iron, smoothing iron", 40 | "Jeans": "jean, blue jean, denim", 41 | "Keyboard": "computer keyboard, keypad", 42 | "Ladle": "ladle", 43 | "Lampshade": "lampshade, lamp shade", 44 | "Laptop (open)": "laptop, laptop computer", 45 | "Lemon": "lemon", 46 | "Letter opener": "letter opener, paper knife, paperknife", 47 | "Lighter": "lighter, light, igniter, ignitor", 48 | "Lipstick": "lipstick, lip rouge", 49 | "Match": "matchstick", 50 | "Measuring cup": "measuring cup", 51 | "Microwave": "microwave, microwave oven", 52 | "Mixing / Salad Bowl": "mixing bowl", 53 | "Monitor": "monitor", 54 | "Mug": "coffee mug", 55 | "Nail (fastener)": "nail", 56 | "Necklace": "necklace", 57 | "Orange": "orange", 58 | "Padlock": "padlock", 59 | "Paintbrush": "paintbrush", 60 | "Paper towel": "paper towel", 61 | "Pen": "ballpoint, ballpoint pen, ballpen, Biro; quill, quill pen; fountain pen", 62 | "Pill bottle": "pill bottle", 63 | "Pillow": "pillow", 64 | "Pitcher": "pitcher, ewer", 65 | "Plastic bag": "plastic bag", 66 | "Plate": "plate", 67 | "Plunger": "plunger, plumber's helper", 68 | "Pop can": "pop bottle, soda bottle", 69 | "Portable heater": "space heater", 70 | "Printer": "printer", 71 | "Remote control": "remote control, remote", 72 | "Ruler": "rule, ruler", 73 | "Running shoe": "running shoe", 74 | "Safety pin": "safety pin", 75 | "Salt shaker": "saltshaker, salt shaker", 76 | "Sandal": "sandal", 77 | "Screw": "screw", 78 | "Shovel": "shovel", 79 | "Skirt": "hoopskirt, crinoline; miniskirt, mini; overskirt", 80 | "Sleeping bag": "sleeping bag", 81 | "Soap dispenser": "soap dispenser", 82 | "Sock": "sock", 83 | "Soup Bowl": "soup bowl", 84 | "Spatula": "spatula", 85 | "Speaker": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", 86 | "Still Camera": "Polaroid camera, Polaroid Land camera; reflex camera", 87 | "Strainer": "strainer", 88 | "Stuffed animal": "teddy, teddy bear", 89 | "Suit jacket": "suit, suit of clothes", 90 | "Sunglasses": "sunglasses, dark glasses, shades", 91 | "Sweater": "sweatshirt", 92 | "Swimming trunks": "swimming trunks, bathing trunks", 93 | "T-shirt": "jersey, T-shirt, tee shirt", 94 | "TV": "television, television system", 95 | "Teapot": "teapot", 96 | "Tennis racket": "racket, racquet", 97 | "Tie": "bow tie, bow-tie, bowtie; Windsor tie", 98 | "Toaster": "toaster", 99 | "Toilet paper roll": "toilet tissue, toilet paper, bathroom tissue", 100 | "Trash bin": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", 101 | "Tray": "tray", 102 | "Umbrella": "umbrella", 103 | "Vacuum cleaner": "vacuum, vacuum cleaner", 104 | "Vase": "vase", 105 | "Wallet": "wallet, billfold, notecase, pocketbook", 106 | "Watch": "digital watch", 107 | "Water bottle": "water bottle", 108 | "Weight (exercise)": "dumbbell", 109 | "Weight scale": "scale, weighing machine", 110 | "Wheel": "car wheel; paddlewheel, paddle wheel", 111 | "Whistle": "whistle", 112 | "Wine bottle": "wine bottle", 113 | "Winter glove": "mitten", 114 | "Wok": "wok" 115 | } 116 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/cal_mean_std.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/PASS_dataset/train',transform=transforms.ToTensor()) 5 | # dataset = 
datasets.ImageFolder('/opt/tiger/filter_transfer/data/imagenet/ILSVRC2012_img_train/train',transform=transforms.ToTensor()) 6 | dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/PASS_dataset/train', 7 | transform=transforms.Compose([transforms.Resize(256), 8 | transforms.CenterCrop(224), 9 | transforms.ToTensor()])) 10 | 11 | # --------- PASS 12 | # mean: tensor([0.4646, 0.4484, 0.4129]) 13 | # std: tensor([0.2750, 0.2689, 0.2885]) 14 | 15 | # dataset = datasets.ImageFolder('/opt/tiger/filter_transfer/data/imagenet/ILSVRC2012_img_train/train', 16 | # transform=transforms.Compose([transforms.Resize(256), 17 | # transforms.CenterCrop(224), 18 | # transforms.ToTensor()])) 19 | 20 | loader = torch.utils.data.DataLoader(dataset, 21 | batch_size=1000, 22 | num_workers=8, 23 | shuffle=False) 24 | 25 | # mean = 0. 26 | # meansq = 0. 27 | # i=0 28 | # for data,_ in loader: 29 | # print('{}/{}'.format(i,len(loader))) 30 | # i+=1 31 | # mean = data.mean() 32 | # meansq = (data ** 2).mean() 33 | # 34 | # std = torch.sqrt(meansq - mean ** 2) 35 | # print("mean: " + str(mean)) 36 | # print("std: " + str(std)) 37 | # print() 38 | 39 | # mean = 0.0 40 | # i=0 41 | # for images, _ in loader: 42 | # batch_samples = images.size(0) 43 | # images = images.view(batch_samples, images.size(1), -1) 44 | # mean += images.mean(2).sum(0) 45 | # print('{}/{}'.format(i, len(loader))) 46 | # i+=1 47 | # print(mean / i / 1000) 48 | # mean = mean / len(loader.dataset) / 1000 49 | 50 | # import ipdb 51 | # ipdb.set_trace(context=20) 52 | # mean = torch.FloatTensor([0.485, 0.456, 0.406]) 53 | mean = torch.FloatTensor([0.4646, 0.4484, 0.4129]) 54 | var = 0.0 55 | i=0 56 | for images, _ in loader: 57 | batch_samples = images.size(0) 58 | images = images.view(batch_samples, images.size(1), -1) 59 | var += ((images - mean.unsqueeze(1))**2).sum([0,2]) 60 | print('{}/{}'.format(i, len(loader))) 61 | i += 1 62 | print(torch.sqrt(var / (i*224*224))) 63 | print(torch.sqrt(var / (i*1000*224*224))) 64 | std = torch.sqrt(var / (len(loader.dataset)*224*224)) 65 | 66 | import ipdb 67 | ipdb.set_trace(context=20) 68 | 69 | a=1 70 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/constants.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from .Randaug import RandAugment 3 | 4 | prefix = "data" 5 | 6 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_100k/" 7 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_128k/" 8 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_200k/" 9 | IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train/" 10 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_train_plus15tasks/" 11 | # IMGNET_PATH = prefix + "/imagenet/imagenet.10_1000/" 12 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_traintrain_200cls_640shot_128.0k/" 13 | # IMGNET_PATH = prefix + "/imagenet/ILSVRC2012_img_traintrain_500cls_256shot_128.0k/" 14 | # IMGNET_PATH = prefix + "/place_128k" 15 | # IMGNET_PATH = prefix + "/pass/PASS_128k" 16 | 17 | # Planes dataset 18 | FGVC_PATH = prefix + "/fgvc-aircraft-2013b/" 19 | 20 | # Oxford Flowers dataset 21 | # FLOWERS_PATH = prefix + "/oxford_flowers_pytorch/" 22 | FLOWERS_PATH = prefix + "/flowers_new/" 23 | 24 | # DTD dataset 25 | DTD_PATH = prefix + "/dtd/" 26 | 27 | # Stanford Cars dataset 28 | CARS_PATH = prefix + "/cars_new" 29 | 30 | # SUN397 dataset 31 | SUN_PATH = prefix + "/SUN397/splits_01/" 
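# A minimal sketch, assuming a hypothetical TRANSFER_DATA_ROOT environment
# variable, for relocating the dataset roots defined above (the file itself
# hard-codes prefix = "data", so every *_PATH resolves relative to ./data):
#
#     import os
#     prefix = os.environ.get("TRANSFER_DATA_ROOT", "data")
#     FGVC_PATH = prefix + "/fgvc-aircraft-2013b/"
#     DTD_PATH = prefix + "/dtd/"
#
# TRANSFER_DATA_ROOT is illustrative only and is not read anywhere in this repo.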
32 | 33 | # FOOD dataset 34 | FOOD_PATH = prefix + "/food-101" 35 | 36 | # BIRDS dataset 37 | BIRDS_PATH = prefix + "/birdsnap" 38 | 39 | # CUB-200-2011 birds 40 | CUB_PATH = prefix + "/CUB_200_2011" 41 | 42 | # COCO 43 | COCO_PATH = prefix + "/coco_cls" 44 | 45 | # ade20k 46 | ADE20K_PATH = prefix + "/ade20k_cls" 47 | 48 | # Mix seg: cs voc ade 49 | MIX_SEG_PATH = prefix + "/mix_seg" 50 | 51 | # PETS dataset 52 | PETS_PATH = prefix + "" 53 | 54 | # Caltech datasets 55 | CALTECH101_PATH = prefix + "" 56 | CALTECH256_PATH = prefix + "" 57 | 58 | value_scale = 255 59 | mean = [0.485, 0.456, 0.406] 60 | mean = [0.48145466, 0.4578275, 0.40821073] 61 | mean = [item * value_scale for item in mean] 62 | std = [0.229, 0.224, 0.225] 63 | std = [0.26862954, 0.26130258, 0.27577711] 64 | std = [item * value_scale for item in std] 65 | 66 | # Data Augmentation defaults 67 | TRAIN_TRANSFORMS = transforms.Compose([ 68 | # transforms.Resize(32), 69 | transforms.RandomResizedCrop(224), 70 | # transforms.RandomResizedCrop(224, scale=(0.08,1.0), ratio=(0.75,1.333333)), 71 | # transforms.RandomResizedCrop(224, scale=(0.08,1.0), ratio=(0.5,2.0)), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor(), 74 | # transforms.Normalize(mean=mean, std=std), 75 | ]) 76 | # TRAIN_TRANSFORMS = transforms.Compose([ 77 | # # transforms.Resize(32), 78 | # transforms.Resize(256), 79 | # transforms.CenterCrop(224), 80 | # transforms.ToTensor(), 81 | # # transforms.Normalize(mean=mean, std=std), 82 | # ]) 83 | 84 | TEST_TRANSFORMS = transforms.Compose([ 85 | # transforms.Resize(32), 86 | transforms.Resize(256), 87 | transforms.CenterCrop(224), 88 | transforms.ToTensor(), 89 | # transforms.Normalize(mean=mean, std=std), 90 | ]) 91 | 92 | # from PIL import Image 93 | # BICUBIC = Image.BICUBIC 94 | # TEST_TRANSFORMS = transforms.Compose([ 95 | # # transforms.Resize(32), 96 | # transforms.Resize(224,interpolation=BICUBIC), 97 | # transforms.CenterCrop(224), 98 | # transforms.ToTensor(), 99 | # # transforms.Normalize(mean=mean, std=std), 100 | # ]) 101 | 102 | # Add RandAugment with N, M(hyperparameter) 103 | # N=3 104 | # M=9 105 | # TRAIN_TRANSFORMS.transforms.insert(0, RandAugment(N, M)) -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class CUB(Dataset): 8 | 9 | def __init__(self, path, train=True, transform=None, target_transform=None): 10 | 11 | self.root = path 12 | self.is_train = train 13 | self.transform = transform 14 | self.target_transform = target_transform 15 | self.images_path = {} 16 | with open(os.path.join(self.root, 'images.txt')) as f: 17 | for line in f: 18 | image_id, path = line.split() 19 | self.images_path[image_id] = path 20 | 21 | self.class_ids = {} 22 | with open(os.path.join(self.root, 'image_class_labels.txt')) as f: 23 | for line in f: 24 | image_id, class_id = line.split() 25 | self.class_ids[image_id] = class_id 26 | 27 | self.data_id = [] 28 | if self.is_train: 29 | with open(os.path.join(self.root, 'train_test_split.txt')) as f: 30 | for line in f: 31 | image_id, is_train = line.split() 32 | if int(is_train): 33 | self.data_id.append(image_id) 34 | if not self.is_train: 35 | with open(os.path.join(self.root, 'train_test_split.txt')) as f: 36 | for line in f: 37 | image_id, is_train = line.split() 38 | if not 
int(is_train): 39 | self.data_id.append(image_id) 40 | 41 | def __len__(self): 42 | return len(self.data_id) 43 | 44 | def __getitem__(self, index): 45 | """ 46 | Args: 47 | index: index of training dataset 48 | Returns: 49 | image and its corresponding label 50 | """ 51 | image_id = self.data_id[index] 52 | class_id = int(self._get_class_by_id(image_id)) - 1 53 | path = self._get_path_by_id(image_id) 54 | image = cv2.imread(os.path.join(self.root, 'images', path)) 55 | 56 | if self.transform: 57 | image = self.transform(image) 58 | 59 | if self.target_transform: 60 | class_id = self.target_transform(class_id) 61 | return image, class_id 62 | 63 | def _get_path_by_id(self, image_id): 64 | 65 | return self.images_path[image_id] 66 | 67 | def _get_class_by_id(self, image_id): 68 | 69 | return self.class_ids[image_id] -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/cub_for_robust_codebase.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | def main(): 7 | path= "/opt/tiger/filter_transfer/data/CUB_200_2011" 8 | root = path 9 | images_path = {} 10 | 11 | os.system('mkdir train') 12 | os.system('cp -r images/* train/') 13 | os.system('mkdir val') 14 | os.system('cp -r images/* val/') 15 | 16 | with open(os.path.join(root, 'images.txt')) as f: 17 | for line in f: 18 | image_id, path = line.split() 19 | images_path[image_id] = path 20 | 21 | class_ids = {} 22 | with open(os.path.join(root, 'image_class_labels.txt')) as f: 23 | for line in f: 24 | image_id, class_id = line.split() 25 | class_ids[image_id] = class_id 26 | 27 | train_id = [] # train not val 28 | with open(os.path.join(root, 'train_test_split.txt')) as f: 29 | for line in f: 30 | image_id, is_train = line.split() 31 | if int(is_train): 32 | train_id.append(image_id) 33 | 34 | with open(os.path.join(root, 'images.txt')) as f: 35 | for line in f: 36 | image_id, path = line.split() 37 | if image_id in train_id: 38 | os.system('rm val/{}'.format(path)) 39 | else: 40 | # import ipdb 41 | # ipdb.set_trace(context=20) 42 | os.system('rm train/{}'.format(path)) 43 | 44 | 45 | 46 | 47 | 48 | 49 | if __name__ == "__main__": 50 | main() -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/dtd.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | # from . 
import constants as cs 3 | from torch.utils.data.dataset import Dataset 4 | from torch.utils.data import DataLoader 5 | from os.path import join as osj 6 | from PIL import Image 7 | from torchvision import transforms 8 | import os 9 | 10 | TRAIN_TRANSFORMS = transforms.Compose([ 11 | # transforms.Resize(32), 12 | transforms.RandomResizedCrop(224), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | # transforms.Normalize(mean=mean, std=std), 16 | ]) 17 | 18 | TEST_TRANSFORMS = transforms.Compose([ 19 | # transforms.Resize(32), 20 | transforms.Resize(256), 21 | transforms.CenterCrop(224), 22 | transforms.ToTensor(), 23 | # transforms.Normalize(mean=mean, std=std), 24 | ]) 25 | 26 | class DTD(Dataset): 27 | def __init__(self, split="1", train=False, transform=TRAIN_TRANSFORMS): 28 | super().__init__() 29 | DTD_PATH='/opt/tiger/filter_transfer/data/dtd' 30 | train_path = osj(DTD_PATH, f"labels/train{split}.txt") 31 | val_path = osj(DTD_PATH, f"labels/val{split}.txt") 32 | test_path = osj(DTD_PATH, f"labels/test{split}.txt") 33 | if train: 34 | print(DTD_PATH) 35 | self.ims = open(train_path).readlines() + \ 36 | open(val_path).readlines() 37 | else: 38 | self.ims = open(test_path).readlines() 39 | 40 | self.full_ims = [osj(DTD_PATH, "images", x) for x in self.ims] 41 | 42 | pth = osj(DTD_PATH, f"labels/classes.txt") 43 | self.c_to_t = {x.strip(): i for i, x in enumerate(open(pth).readlines())} 44 | 45 | # self.transform = TRAIN_TRANSFORMS if train else TEST_TRANSFORMS 46 | self.transform = transform 47 | self.labels = [self.c_to_t[x.split("/")[0]] for x in self.ims] 48 | 49 | def __getitem__(self, index): 50 | im = Image.open(self.full_ims[index].strip()) 51 | im = self.transform(im) 52 | return im, self.labels[index] 53 | 54 | def __len__(self): 55 | return len(self.ims) 56 | 57 | if __name__ == "__main__": 58 | dtd = DTD(train=True) 59 | # import ipdb 60 | # ipdb.set_trace(context=20) 61 | target_folder = "/opt/tiger/filter_transfer/data/dtd/mix_dtd/" 62 | for im in dtd.full_ims: 63 | img = im[:-1] 64 | category = img.split('/')[-2] 65 | if not os.path.exists(target_folder+category):os.makedirs(target_folder+category) 66 | os.system('cp {} {}'.format(img, target_folder+category)) 67 | a=1 -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/fine_tunify.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from robustness.tools.custom_modules import SequentialWithArgs 3 | 4 | def ft(model_name, model_ft, num_classes, additional_hidden=0): 5 | if model_name in ['clip_resnest50d','resnet50_feat_pca_pre_relu_multi_pool','resnet18_feat_pre_relu_multi_pool','resnet50_feat_pca_pre_relu_multi','resnet18_feat_pre_relu_multi','resnet50_feat_interpolate_multi','resnet18_multi','resnet50_feat_pre_relu_multi','resnet50_overhaul','resnet18_feat_pre_relu_regressor','resnet18_custom','resnet18_feat_pre_relu',"resnet50_feat_interpolate","resnet50_feat_pca","resnet50_feat_nmf","resnet50_feat_lda","resnet50_feat_mag","resnet18_feat","resnet152_feat","resnet50_feat","resnet","resnet20_as_gift","resnet50_clean", "resnet18", "resnet34","resnet50", "wide_resnet50_2", "wide_resnet50_4", "resnext50_32x4d", 'shufflenet']: 6 | num_ftrs = model_ft.fc.in_features 7 | # The two cases are split just to allow loading 8 | # models trained prior to adding the additional_hidden argument 9 | # without errors 10 | if additional_hidden == 0: 11 | model_ft.fc = 
nn.Linear(num_ftrs, num_classes) 12 | else: 13 | model_ft.fc = SequentialWithArgs( 14 | *list(sum([[nn.Linear(num_ftrs, num_ftrs), nn.ReLU()] for i in range(additional_hidden)], [])), 15 | nn.Linear(num_ftrs, num_classes) 16 | ) 17 | input_size = 224 18 | 19 | elif model_name == 'RN50': 20 | num_ftrs = 1024 21 | # model_ft.fc = nn.Linear(num_ftrs, num_classes) 22 | input_size = 224 23 | elif model_name == "alexnet": 24 | num_ftrs = model_ft.classifier[6].in_features 25 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 26 | input_size = 224 27 | elif "vgg" in model_name: 28 | num_ftrs = model_ft.classifier[6].in_features 29 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 30 | input_size = 224 31 | elif model_name == "squeezenet": 32 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 33 | model_ft.num_classes = num_classes 34 | input_size = 224 35 | elif model_name == "densenet": 36 | num_ftrs = model_ft.classifier.in_features 37 | model_ft.classifier = nn.Linear(num_ftrs, num_classes) 38 | input_size = 224 39 | elif model_name in ["mnasnet", "mobilenet"]: 40 | num_ftrs = model_ft.classifier.in_features 41 | model_ft.classifier = nn.Linear(num_ftrs, num_classes) 42 | input_size = 224 43 | else: 44 | pass 45 | # raise ValueError("Invalid model type, exiting...") 46 | 47 | return model_ft 48 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/food_101.py: -------------------------------------------------------------------------------- 1 | # pytorch imports 2 | import torch 3 | from torchvision import models, transforms, datasets 4 | from torch.utils.data import DataLoader 5 | from robustness import data_augmentation as da 6 | from . 
import constants as cs 7 | 8 | class FOOD101(): 9 | def __init__(self, transform=None): 10 | # self.TRAIN_PATH = cs.FOOD_PATH+"/train" 11 | self.TRAIN_PATH = "/opt/tiger/filter_transfer/data/food-101/train" 12 | # self.VALID_PATH = cs.FOOD_PATH+"/valid" 13 | self.VALID_PATH = "/opt/tiger/filter_transfer/data/food-101/valid" 14 | 15 | self.train_ds, self.valid_ds, self.train_cls, self.valid_cls = [None]*4 16 | self.transform = transform 17 | 18 | def _get_tfms(self): 19 | train_tfms = cs.TRAIN_TRANSFORMS 20 | valid_tfms = cs.TEST_TRANSFORMS 21 | return train_tfms, valid_tfms 22 | 23 | def get_dataset(self): 24 | # train_tfms, valid_tfms = self._get_tfms() # transformations 25 | train_tfms, valid_tfms = self.transform, self.transform 26 | self.train_ds = datasets.ImageFolder(root=self.TRAIN_PATH, 27 | transform=train_tfms) 28 | self.valid_ds = datasets.ImageFolder(root=self.VALID_PATH, 29 | transform=valid_tfms) 30 | self.train_classes = self.train_ds.classes 31 | self.valid_classes = self.valid_ds.classes 32 | 33 | # print(self.train_classes) 34 | 35 | assert self.train_classes==self.valid_classes 36 | return self.train_ds, self.valid_ds, self.train_classes 37 | 38 | def get_dls(self, train_ds, valid_ds, bs, **kwargs): 39 | return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs), 40 | DataLoader(valid_ds, batch_size=bs, shuffle=True, **kwargs)) 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/imbalance_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | 6 | 7 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 8 | cls_num = 10 9 | 10 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True, 11 | transform=None, target_transform=None, 12 | download=False): 13 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download) 14 | np.random.seed(rand_number) 15 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 16 | self.gen_imbalanced_data(img_num_list) 17 | 18 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 19 | img_max = len(self.data) / cls_num 20 | img_num_per_cls = [] 21 | if imb_type == 'exp': 22 | for cls_idx in range(cls_num): 23 | num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))) 24 | img_num_per_cls.append(int(num)) 25 | elif imb_type == 'step': 26 | for cls_idx in range(cls_num // 2): 27 | img_num_per_cls.append(int(img_max)) 28 | for cls_idx in range(cls_num // 2): 29 | img_num_per_cls.append(int(img_max * imb_factor)) 30 | else: 31 | img_num_per_cls.extend([int(img_max)] * cls_num) 32 | return img_num_per_cls 33 | 34 | def gen_imbalanced_data(self, img_num_per_cls): 35 | new_data = [] 36 | new_targets = [] 37 | targets_np = np.array(self.targets, dtype=np.int64) 38 | classes = np.unique(targets_np) 39 | # np.random.shuffle(classes) 40 | self.num_per_cls_dict = dict() 41 | for the_class, the_img_num in zip(classes, img_num_per_cls): 42 | self.num_per_cls_dict[the_class] = the_img_num 43 | idx = np.where(targets_np == the_class)[0] 44 | np.random.shuffle(idx) 45 | selec_idx = idx[:the_img_num] 46 | new_data.append(self.data[selec_idx, ...]) 47 | new_targets.extend([the_class, ] * the_img_num) 48 | new_data = np.vstack(new_data) 49 | self.data = new_data 50 | self.targets = new_targets 51 | 52 | def 
get_cls_num_list(self): 53 | cls_num_list = [] 54 | for i in range(self.cls_num): 55 | cls_num_list.append(self.num_per_cls_dict[i]) 56 | return cls_num_list 57 | 58 | 59 | class IMBALANCECIFAR100(IMBALANCECIFAR10): 60 | """`CIFAR100 `_ Dataset. 61 | This is a subclass of the `CIFAR10` Dataset. 62 | """ 63 | base_folder = 'cifar-100-python' 64 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 65 | filename = "cifar-100-python.tar.gz" 66 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 67 | train_list = [ 68 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 69 | ] 70 | 71 | test_list = [ 72 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 73 | ] 74 | meta = { 75 | 'filename': 'meta', 76 | 'key': 'fine_label_names', 77 | 'md5': '7973b15100ade9c7d40fb424638fde48', 78 | } 79 | cls_num = 100 80 | 81 | 82 | if __name__ == '__main__': 83 | transform = transforms.Compose( 84 | [transforms.ToTensor(), 85 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 86 | trainset = IMBALANCECIFAR100(root='./data', train=True, 87 | download=True, transform=transform) 88 | trainloader = iter(trainset) 89 | data, label = next(trainloader) 90 | import pdb; 91 | 92 | pdb.set_trace() -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_aircraft.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import numpy as np 5 | 6 | # --- get class names 7 | # read_filepath = 'data/variants.txt' 8 | # names=[] 9 | # with open(read_filepath, 'r') as f: 10 | # for line in f.readlines(): 11 | # print(line.strip()) 12 | # names.append(line.strip()) 13 | # 14 | # names = ["a "+n for n in names] 15 | # print(names) 16 | 17 | 18 | # --- find train images 19 | read_filepath = 'data/images_variant_trainval.txt' 20 | names=[] 21 | with open(read_filepath, 'r') as f: 22 | for line in f.readlines(): 23 | line=line.strip().split(' ') 24 | img_name = line[0] + '.jpg' 25 | print(img_name) 26 | names.append(img_name) 27 | 28 | # names = ["a "+n for n in names] 29 | print(names) 30 | 31 | for name in names: 32 | os.system('cp data/images/{} ../orig_pool15/aircraft/'.format(name)) -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_caltech101.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import numpy as np 5 | 6 | 7 | # --- find train images 8 | read_filepath = 'data/images_variant_trainval.txt' 9 | names=[] 10 | with open(read_filepath, 'r') as f: 11 | for line in f.readlines(): 12 | line=line.strip().split(' ') 13 | img_name = line[0] + '.jpg' 14 | print(img_name) 15 | names.append(img_name) 16 | 17 | # names = ["a "+n for n in names] 18 | print(names) 19 | 20 | for name in names: 21 | os.system('cp data/images/{} ../orig_pool15/aircraft/'.format(name)) 22 | os.system('pwd') -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | cat2name = {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", 
"39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"} 5 | 6 | cates = sorted(glob.glob('*'))[:-1] 7 | 8 | 9 | 10 | names = ["a "+ cat2name[cat] for cat in cates] 11 | import ipdb 12 | 13 | ipdb.set_trace(context=20) 14 | a=1 15 | 16 | 17 | b=[a[3:] for a in al] 18 | al = ["a "+ v for v in al] -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_imgnet.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | src_classes = sorted(glob.glob('ILSVRC2012_img_train/train/*')) 5 | os.system('mkdir -p ILSVRC2012_img_train_plus15tasks/train/img') 6 | 7 | i=0 8 | for src_cls in src_classes: 9 | print(i) 10 | i+=1 11 | os.system('cp {}/* ILSVRC2012_img_train_plus15tasks/train/img'.format(src_cls)) 12 | 13 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/process_dataset/pro_pool15.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | # pwd = /opt/tiger/filter_transfer/data 4 | new_ds = 'ds15img' 5 | all_ds = sorted(glob.glob('pool15/*')) 6 | os.system('mkdir {}'.format(new_ds)) 7 | 8 | for ds in all_ds: 9 | os.system('mv {}/train/*/* {}'.format(ds, new_ds)) 10 | 11 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/process_dataset/process_pets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 
3 | 4 | names=[] 5 | 6 | all = sorted(glob.glob('*')) 7 | for i in all: 8 | # os.system('mv {} train/'.format(i)) 9 | # os.system('mkdir val/{}'.format(i)) 10 | 11 | name = sorted(glob.glob('{}/*.jpg'.format(i)))[0] 12 | 13 | # name = 'a '+name.split('/')[-1].split('_')[0] 14 | name = 'a '+name.split('/')[-1].split('_1')[0] 15 | # name = sorted(glob.glob('train/{}/*.jpg'.format(i)))[0].split('_')[0] 16 | 17 | print(name) 18 | names.append(name) 19 | print(names) 20 | # os.system('cp {} val/{}/'.format(to_cp_name,i)) 21 | import ipdb 22 | ipdb.set_trace(context=20) 23 | print(names) 24 | 25 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/transfer_ds/utils.py: -------------------------------------------------------------------------------- 1 | from . import generic_dataset 2 | from . import food_101 3 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/ytbb-robust_metadata/class_idx_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "7": 1, 3 | "8": 1, 4 | "9": 1, 5 | "10": 1, 6 | "11": 1, 7 | "12": 1, 8 | "13": 1, 9 | "14": 1, 10 | "15": 1, 11 | "16": 1, 12 | "17": 1, 13 | "18": 1, 14 | "19": 1, 15 | "20": 1, 16 | "21": 1, 17 | "22": 1, 18 | "23": 1, 19 | "24": 1, 20 | "80": 1, 21 | "81": 1, 22 | "82": 1, 23 | "83": 1, 24 | "84": 1, 25 | "85": 1, 26 | "86": 1, 27 | "87": 1, 28 | "88": 1, 29 | "89": 1, 30 | "90": 1, 31 | "91": 1, 32 | "92": 1, 33 | "93": 1, 34 | "94": 1, 35 | "95": 1, 36 | "96": 1, 37 | "97": 1, 38 | "98": 1, 39 | "99": 1, 40 | "100": 1, 41 | "127": 1, 42 | "128": 1, 43 | "129": 1, 44 | "130": 1, 45 | "131": 1, 46 | "132": 1, 47 | "133": 1, 48 | "134": 1, 49 | "135": 1, 50 | "136": 1, 51 | "137": 1, 52 | "138": 1, 53 | "139": 1, 54 | "140": 1, 55 | "141": 1, 56 | "142": 1, 57 | "143": 1, 58 | "144": 1, 59 | "145": 1, 60 | "146": 1, 61 | "151": 19, 62 | "152": 19, 63 | "153": 19, 64 | "154": 19, 65 | "155": 19, 66 | "156": 19, 67 | "157": 19, 68 | "158": 19, 69 | "159": 19, 70 | "160": 19, 71 | "161": 19, 72 | "162": 19, 73 | "163": 19, 74 | "164": 19, 75 | "165": 19, 76 | "166": 19, 77 | "167": 19, 78 | "168": 19, 79 | "169": 19, 80 | "170": 19, 81 | "171": 19, 82 | "172": 19, 83 | "173": 19, 84 | "174": 19, 85 | "175": 19, 86 | "176": 19, 87 | "177": 19, 88 | "178": 19, 89 | "179": 19, 90 | "180": 19, 91 | "181": 19, 92 | "182": 19, 93 | "183": 19, 94 | "184": 19, 95 | "185": 19, 96 | "186": 19, 97 | "187": 19, 98 | "188": 19, 99 | "189": 19, 100 | "190": 19, 101 | "191": 19, 102 | "192": 19, 103 | "193": 19, 104 | "194": 19, 105 | "195": 19, 106 | "196": 19, 107 | "197": 19, 108 | "198": 19, 109 | "199": 19, 110 | "200": 19, 111 | "201": 19, 112 | "202": 19, 113 | "203": 19, 114 | "204": 19, 115 | "205": 19, 116 | "206": 19, 117 | "207": 19, 118 | "208": 19, 119 | "209": 19, 120 | "210": 19, 121 | "211": 19, 122 | "212": 19, 123 | "213": 19, 124 | "214": 19, 125 | "215": 19, 126 | "216": 19, 127 | "217": 19, 128 | "218": 19, 129 | "219": 19, 130 | "220": 19, 131 | "221": 19, 132 | "222": 19, 133 | "223": 19, 134 | "224": 19, 135 | "225": 19, 136 | "226": 19, 137 | "227": 19, 138 | "228": 19, 139 | "229": 19, 140 | "230": 19, 141 | "231": 19, 142 | "232": 19, 143 | "233": 19, 144 | "234": 19, 145 | "235": 19, 146 | "236": 19, 147 | "237": 19, 148 | "238": 19, 149 | "239": 19, 150 | "240": 19, 151 | "241": 19, 152 | "242": 19, 153 | "243": 19, 154 | "244": 19, 155 | "245": 19, 156 | "246": 
19, 157 | "247": 19, 158 | "248": 19, 159 | "249": 19, 160 | "250": 19, 161 | "251": 19, 162 | "252": 19, 163 | "253": 19, 164 | "254": 19, 165 | "255": 19, 166 | "256": 19, 167 | "257": 19, 168 | "258": 19, 169 | "259": 19, 170 | "260": 19, 171 | "261": 19, 172 | "262": 19, 173 | "263": 19, 174 | "264": 19, 175 | "265": 19, 176 | "266": 19, 177 | "267": 19, 178 | "268": 19, 179 | "281": 7, 180 | "282": 7, 181 | "283": 7, 182 | "284": 7, 183 | "285": 7, 184 | "286": 7, 185 | "287": 7, 186 | "294": 5, 187 | "295": 5, 188 | "296": 5, 189 | "297": 5, 190 | "339": 10, 191 | "340": 17, 192 | "385": 20, 193 | "386": 20, 194 | "404": 13, 195 | "407": 23, 196 | "436": 23, 197 | "444": 2, 198 | "466": 15, 199 | "468": 23, 200 | "472": 3, 201 | "499": 12, 202 | "511": 23, 203 | "554": 3, 204 | "555": 16, 205 | "569": 16, 206 | "576": 3, 207 | "609": 23, 208 | "623": 12, 209 | "625": 3, 210 | "627": 23, 211 | "654": 4, 212 | "656": 23, 213 | "661": 23, 214 | "665": 11, 215 | "671": 2, 216 | "675": 16, 217 | "717": 16, 218 | "734": 16, 219 | "751": 23, 220 | "779": 4, 221 | "814": 3, 222 | "817": 23, 223 | "864": 16, 224 | "867": 16, 225 | "874": 4, 226 | "879": 21, 227 | "914": 3, 228 | "981": 0, 229 | "982": 0, 230 | "983": 0, 231 | "985": 9, 232 | "986": 9 233 | } -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/ytbb-robust_metadata/rev_class_idx_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": [ 3 | 981, 4 | 982, 5 | 983 6 | ], 7 | "1": [ 8 | 7, 9 | 8, 10 | 9, 11 | 10, 12 | 11, 13 | 12, 14 | 13, 15 | 14, 16 | 15, 17 | 16, 18 | 17, 19 | 18, 20 | 19, 21 | 20, 22 | 21, 23 | 22, 24 | 23, 25 | 24, 26 | 80, 27 | 81, 28 | 82, 29 | 83, 30 | 84, 31 | 85, 32 | 86, 33 | 87, 34 | 88, 35 | 89, 36 | 90, 37 | 91, 38 | 92, 39 | 93, 40 | 94, 41 | 95, 42 | 96, 43 | 97, 44 | 98, 45 | 99, 46 | 100, 47 | 127, 48 | 128, 49 | 129, 50 | 130, 51 | 131, 52 | 132, 53 | 133, 54 | 134, 55 | 135, 56 | 136, 57 | 137, 58 | 138, 59 | 139, 60 | 140, 61 | 141, 62 | 142, 63 | 143, 64 | 144, 65 | 145, 66 | 146 67 | ], 68 | "2": [ 69 | 444, 70 | 671 71 | ], 72 | "3": [ 73 | 472, 74 | 554, 75 | 576, 76 | 625, 77 | 814, 78 | 914 79 | ], 80 | "4": [ 81 | 654, 82 | 779, 83 | 874 84 | ], 85 | "5": [ 86 | 294, 87 | 295, 88 | 296, 89 | 297 90 | ], 91 | "7": [ 92 | 281, 93 | 282, 94 | 283, 95 | 284, 96 | 285, 97 | 286, 98 | 287 99 | ], 100 | "9": [ 101 | 985, 102 | 986 103 | ], 104 | "10": [ 105 | 339 106 | ], 107 | "11": [ 108 | 665 109 | ], 110 | "12": [ 111 | 499, 112 | 623 113 | ], 114 | "13": [ 115 | 404 116 | ], 117 | "15": [ 118 | 466 119 | ], 120 | "16": [ 121 | 555, 122 | 569, 123 | 675, 124 | 717, 125 | 734, 126 | 864, 127 | 867 128 | ], 129 | "17": [ 130 | 340 131 | ], 132 | "19": [ 133 | 151, 134 | 152, 135 | 153, 136 | 154, 137 | 155, 138 | 156, 139 | 157, 140 | 158, 141 | 159, 142 | 160, 143 | 161, 144 | 162, 145 | 163, 146 | 164, 147 | 165, 148 | 166, 149 | 167, 150 | 168, 151 | 169, 152 | 170, 153 | 171, 154 | 172, 155 | 173, 156 | 174, 157 | 175, 158 | 176, 159 | 177, 160 | 178, 161 | 179, 162 | 180, 163 | 181, 164 | 182, 165 | 183, 166 | 184, 167 | 185, 168 | 186, 169 | 187, 170 | 188, 171 | 189, 172 | 190, 173 | 191, 174 | 192, 175 | 193, 176 | 194, 177 | 195, 178 | 196, 179 | 197, 180 | 198, 181 | 199, 182 | 200, 183 | 201, 184 | 202, 185 | 203, 186 | 204, 187 | 205, 188 | 206, 189 | 207, 190 | 208, 191 | 209, 192 | 210, 193 | 211, 194 | 212, 195 | 213, 196 | 214, 197 | 215, 198 | 216, 199 | 217, 200 
| 218, 201 | 219, 202 | 220, 203 | 221, 204 | 222, 205 | 223, 206 | 224, 207 | 225, 208 | 226, 209 | 227, 210 | 228, 211 | 229, 212 | 230, 213 | 231, 214 | 232, 215 | 233, 216 | 234, 217 | 235, 218 | 236, 219 | 237, 220 | 238, 221 | 239, 222 | 240, 223 | 241, 224 | 242, 225 | 243, 226 | 244, 227 | 245, 228 | 246, 229 | 247, 230 | 248, 231 | 249, 232 | 250, 233 | 251, 234 | 252, 235 | 253, 236 | 254, 237 | 255, 238 | 256, 239 | 257, 240 | 258, 241 | 259, 242 | 260, 243 | 261, 244 | 262, 245 | 263, 246 | 264, 247 | 265, 248 | 266, 249 | 267, 250 | 268 251 | ], 252 | "20": [ 253 | 385, 254 | 386 255 | ], 256 | "21": [ 257 | 879 258 | ], 259 | "23": [ 260 | 407, 261 | 436, 262 | 468, 263 | 511, 264 | 609, 265 | 627, 266 | 656, 267 | 661, 268 | 751, 269 | 817 270 | ] 271 | } -------------------------------------------------------------------------------- /src/classfier_tuning/src/datasets/ytbb-robust_metadata/ytbb_class_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "person", 3 | "1": "bird", 4 | "2": "bicycle", 5 | "3": "boat", 6 | "4": "bus", 7 | "5": "bear", 8 | "6": "cow", 9 | "7": "cat", 10 | "8": "giraffe", 11 | "9": "potted plant", 12 | "10": "horse", 13 | "11": "motorcycle", 14 | "12": "knife", 15 | "13": "airplane", 16 | "14": "skateboard", 17 | "15": "train", 18 | "16": "truck", 19 | "17": "zebra", 20 | "18": "toilet", 21 | "19": "dog", 22 | "20": "elephant", 23 | "21": "umbrella", 24 | "22": "none", 25 | "23": "car" 26 | } 27 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/eurosat_text_feature.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/src/eurosat_text_feature.pt -------------------------------------------------------------------------------- /src/classfier_tuning/src/get_classifier_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | 8 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier 9 | # from ..src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier 10 | 11 | 12 | # load_path = '/opt/tiger/filter_transfer/src/wise-ft/results/extest.r1/zeroshotEurosat.pt' 13 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/syn_init_tfda_55.68.pt' 14 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/syn_ref16.1_iters50_71.72.pt' 15 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/real_init_16.1_tfda_86.85.pt' 16 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_iters50_88.21.pt' 17 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_iters20_88.86.pt' 18 | # load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_16.1_v4_20k_87.86.pt' 19 | load_path = '/opt/tiger/filter_transfer/pretrained-models/clip_few_shot/eurosat/mix_4.1_iter50_81.72.pt' 20 | image_classifier = ImageClassifier.load(load_path) 21 | 22 | # import ipdb 23 | # ipdb.set_trace(context=20) 24 | 25 | head = image_classifier.classification_head 26 | weights = head.weight.detach().numpy() 27 | 28 | torch.save(weights, 
'/opt/tiger/filter_transfer/src/wise-ft/cache/Eurosat/mix_4.1_iter50_81.72_weights.pt') 29 | a=0 30 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/imagenet_text_feature.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/src/imagenet_text_feature.pt -------------------------------------------------------------------------------- /src/classfier_tuning/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongchaoz/diffusion_inversion/fa49f92d550d45f38bc2fe4157eb7a9f3d149158/src/classfier_tuning/src/models/__init__.py -------------------------------------------------------------------------------- /src/classfier_tuning/src/models/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from src.models import utils 8 | from src.datasets.common import get_dataloader, maybe_dictionarize 9 | 10 | import src.datasets as datasets 11 | 12 | 13 | def eval_single_dataset(image_classifier, dataset, args): 14 | if args.freeze_encoder and not args.data_aug_lp: 15 | model = image_classifier.classification_head 16 | input_key = 'features' 17 | image_enc = image_classifier.image_encoder 18 | else: 19 | model = image_classifier 20 | input_key = 'images' 21 | image_enc = None 22 | 23 | model.eval() 24 | dataloader = get_dataloader( 25 | dataset, is_train=False, args=args, image_encoder=image_enc) 26 | batched_data = enumerate(dataloader) 27 | device = args.device 28 | 29 | if hasattr(dataset, 'post_loop_metrics'): 30 | # keep track of labels, predictions and metadata 31 | all_labels, all_preds, all_metadata = [], [], [] 32 | 33 | with torch.no_grad(): 34 | top1, correct, n = 0., 0., 0. 
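# Evaluation loop: each batch is moved to args.device, logits are computed via
# utils.get_logits and optionally projected into the dataset's own label space
# (project_logits / project_labels), and top-1 accuracy is accumulated, unless
# the dataset supplies its own `accuracy`; labels, predictions and metadata are
# also buffered whenever the dataset defines post_loop_metrics.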
35 | for i, data in batched_data: 36 | data = maybe_dictionarize(data) 37 | x = data[input_key].to(device) 38 | y = data['labels'].to(device) 39 | 40 | if 'image_paths' in data: 41 | image_paths = data['image_paths'] 42 | 43 | logits = utils.get_logits(x, model) 44 | projection_fn = getattr(dataset, 'project_logits', None) 45 | if projection_fn is not None: 46 | logits = projection_fn(logits, device) 47 | 48 | if hasattr(dataset, 'project_labels'): 49 | y = dataset.project_labels(y, device) 50 | pred = logits.argmax(dim=1, keepdim=True).to(device) 51 | if hasattr(dataset, 'accuracy'): 52 | acc1, num_total = dataset.accuracy(logits, y, image_paths, args) 53 | correct += acc1 54 | n += num_total 55 | else: 56 | correct += pred.eq(y.view_as(pred)).sum().item() 57 | n += y.size(0) 58 | 59 | if hasattr(dataset, 'post_loop_metrics'): 60 | all_labels.append(y.cpu().clone().detach()) 61 | all_preds.append(logits.cpu().clone().detach()) 62 | metadata = data['metadata'] if 'metadata' in data else image_paths 63 | all_metadata.extend(metadata) 64 | 65 | top1 = correct / n 66 | 67 | if hasattr(dataset, 'post_loop_metrics'): 68 | all_labels = torch.cat(all_labels) 69 | all_preds = torch.cat(all_preds) 70 | metrics = dataset.post_loop_metrics(all_labels, all_preds, all_metadata, args) 71 | if 'acc' in metrics: 72 | metrics['top1'] = metrics['acc'] 73 | else: 74 | metrics = {} 75 | if 'top1' not in metrics: 76 | metrics['top1'] = top1 77 | 78 | return metrics 79 | 80 | def evaluate(image_classifier, args): 81 | if args.eval_datasets is None: 82 | return 83 | info = vars(args) 84 | for i, dataset_name in enumerate(args.eval_datasets): 85 | print('Evaluating on', dataset_name) 86 | dataset_class = getattr(datasets, dataset_name) 87 | dataset = dataset_class( 88 | image_classifier.val_preprocess, 89 | location=args.data_location, 90 | batch_size=args.batch_size 91 | ) 92 | 93 | results = eval_single_dataset(image_classifier, dataset, args) 94 | 95 | if 'top1' in results: 96 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}") 97 | for key, val in results.items(): 98 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key: 99 | print(f"{dataset_name} {key}: {val:.4f}") 100 | info[dataset_name + ':' + key] = val 101 | 102 | if args.results_db is not None: 103 | dirname = os.path.dirname(args.results_db) 104 | if dirname: 105 | os.makedirs(dirname, exist_ok=True) 106 | with open(args.results_db, 'a+') as f: 107 | f.write(json.dumps(info) + '\n') 108 | print(f'Results saved to {args.results_db}.') 109 | else: 110 | print('Results not saved (to do so, use --results_db to specify a path).') 111 | 112 | return info -------------------------------------------------------------------------------- /src/classfier_tuning/src/models/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | import clip.clip as clip 5 | 6 | from src.models import utils 7 | 8 | 9 | class ImageEncoder(torch.nn.Module): 10 | def __init__(self, args, keep_lang=False): 11 | super().__init__() 12 | 13 | self.model, self.train_preprocess = clip.load( 14 | args.model, args.device, jit=False) 15 | self.val_preprocess = copy.deepcopy(self.train_preprocess) 16 | 17 | self.cache_dir = args.cache_dir 18 | 19 | if not keep_lang and hasattr(self.model, 'transformer'): 20 | delattr(self.model, 'transformer') 21 | 22 | def forward(self, images): 23 | assert self.model is not None 24 | return self.model.encode_image(images) 25 | 26 | def save(self, filename): 27 | 
print(f'Saving image encoder to {filename}') 28 | utils.torch_save(self, filename) 29 | 30 | @classmethod 31 | def load(cls, filename): 32 | print(f'Loading image encoder from {filename}') 33 | return utils.torch_load(filename) 34 | 35 | 36 | class ClassificationHead(torch.nn.Linear): 37 | def __init__(self, normalize, weights, biases=None): 38 | output_size, input_size = weights.shape 39 | super().__init__(input_size, output_size) 40 | self.normalize = normalize 41 | if weights is not None: 42 | self.weight = torch.nn.Parameter(weights.clone()) 43 | if biases is not None: 44 | self.bias = torch.nn.Parameter(biases.clone()) 45 | else: 46 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias)) 47 | 48 | def forward(self, inputs): 49 | if self.normalize: 50 | inputs = inputs / inputs.norm(dim=-1, keepdim=True) 51 | return super().forward(inputs) 52 | 53 | def save(self, filename): 54 | print(f'Saving classification head to {filename}') 55 | utils.torch_save(self, filename) 56 | 57 | @classmethod 58 | def load(cls, filename): 59 | print(f'Loading classification head from {filename}') 60 | return utils.torch_load(filename) 61 | 62 | 63 | class ImageClassifier(torch.nn.Module): 64 | def __init__(self, image_encoder, classification_head, process_images=True): 65 | super().__init__() 66 | self.image_encoder = image_encoder 67 | self.classification_head = classification_head 68 | self.process_images = process_images 69 | if self.image_encoder is not None: 70 | self.train_preprocess = self.image_encoder.train_preprocess 71 | self.val_preprocess = self.image_encoder.val_preprocess 72 | 73 | def forward(self, inputs): 74 | if self.process_images: 75 | inputs = self.image_encoder(inputs) 76 | outputs = self.classification_head(inputs) 77 | return outputs 78 | 79 | def save(self, filename): 80 | print(f'Saving image classifier to {filename}') 81 | utils.torch_save(self, filename) 82 | 83 | @classmethod 84 | def load(cls, filename): 85 | print(f'Loading image classifier from {filename}') 86 | return utils.torch_load(filename) 87 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/models/zeroshot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | 8 | import clip.clip as clip 9 | 10 | import src.templates as templates 11 | import src.datasets as datasets 12 | 13 | from src.args import parse_arguments 14 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier 15 | from src.models.eval import evaluate 16 | 17 | 18 | def get_classnames_zs(args, clip_model): 19 | assert args.template is not None 20 | assert args.train_dataset is not None 21 | template = getattr(templates, args.template) 22 | logit_scale = clip_model.logit_scale 23 | dataset_class = getattr(datasets, args.train_dataset) 24 | dataset = dataset_class( 25 | None, 26 | location=args.data_location, 27 | batch_size=args.batch_size, 28 | classnames=args.classnames 29 | ) 30 | classnames = dataset.classnames 31 | return classnames 32 | 33 | 34 | 35 | def get_zeroshot_classifier(args, clip_model): 36 | assert args.template is not None 37 | assert args.train_dataset is not None 38 | template = getattr(templates, args.template) 39 | print('Template', template) 40 | logit_scale = clip_model.logit_scale 41 | dataset_class = getattr(datasets, args.train_dataset) 42 | dataset = dataset_class( 43 | None, 44 | location=args.data_location, 45 | 
batch_size=args.batch_size, 46 | classnames=args.classnames 47 | ) 48 | device = args.device 49 | clip_model.eval() 50 | clip_model.to(device) 51 | 52 | print('Getting zeroshot weights.') 53 | with torch.no_grad(): 54 | zeroshot_weights = [] 55 | 56 | for classname in tqdm(dataset.classnames): 57 | texts = [] 58 | for t in template: 59 | texts.append(t(classname)) 60 | texts = clip.tokenize(texts).to(device) # tokenize 61 | embeddings = clip_model.encode_text(texts) # embed with text encoder 62 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 63 | 64 | embeddings = embeddings.mean(dim=0, keepdim=True) 65 | embeddings /= embeddings.norm() 66 | 67 | zeroshot_weights.append(embeddings) 68 | 69 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 70 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 71 | 72 | zeroshot_weights *= logit_scale.exp() 73 | 74 | zeroshot_weights = zeroshot_weights.squeeze().float() 75 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 76 | # import ipdb 77 | # ipdb.set_trace(context=20) 78 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 79 | 80 | return classification_head 81 | 82 | 83 | def eval(args): 84 | args.freeze_encoder = True 85 | if args.load is not None: 86 | classifier = ImageClassifier.load(args.load) 87 | else: 88 | image_encoder = ImageEncoder(args, keep_lang=True) 89 | classification_head = get_zeroshot_classifier(args, image_encoder.model) 90 | delattr(image_encoder.model, 'transformer') 91 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False) 92 | 93 | evaluate(classifier, args) 94 | 95 | if args.save is not None: 96 | classifier.save(args.save) 97 | 98 | 99 | if __name__ == '__main__': 100 | args = parse_arguments() 101 | eval(args) -------------------------------------------------------------------------------- /src/classfier_tuning/src/models/zeroshot_retina.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | 8 | import clip.clip as clip 9 | 10 | import src.templates as templates 11 | import src.datasets as datasets 12 | 13 | from src.args import parse_arguments 14 | from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier 15 | from src.models.eval import evaluate 16 | 17 | 18 | def get_zeroshot_classifier(args, clip_model, classnames=None): 19 | assert args.template is not None 20 | assert args.train_dataset is not None 21 | 22 | template = [lambda c: f"a photo of retina with disease {c}."] 23 | logit_scale = clip_model.logit_scale 24 | 25 | device = args.device 26 | clip_model.eval() 27 | clip_model.to(device) 28 | 29 | print('Getting zeroshot weights.') 30 | with torch.no_grad(): 31 | zeroshot_weights = [] 32 | 33 | for classname in tqdm(classnames): 34 | texts = [] 35 | for t in template: 36 | texts.append(t(classname)) 37 | texts = clip.tokenize(texts).to(device) # tokenize 38 | embeddings = clip_model.encode_text(texts) # embed with text encoder 39 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 40 | 41 | embeddings = embeddings.mean(dim=0, keepdim=True) 42 | embeddings /= embeddings.norm() 43 | 44 | zeroshot_weights.append(embeddings) 45 | 46 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 47 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 48 | 49 | zeroshot_weights *= logit_scale.exp() 50 | 51 | zeroshot_weights = 
zeroshot_weights.squeeze().float() 52 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 53 | # import ipdb 54 | # ipdb.set_trace(context=20) 55 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 56 | 57 | return classification_head 58 | 59 | 60 | def eval(args): 61 | args.freeze_encoder = True 62 | if args.load is not None: 63 | classifier = ImageClassifier.load(args.load) 64 | else: 65 | image_encoder = ImageEncoder(args, keep_lang=True) 66 | classification_head = get_zeroshot_classifier(args, image_encoder.model) 67 | delattr(image_encoder.model, 'transformer') 68 | classifier = ImageClassifier(image_encoder, classification_head, process_images=False) 69 | 70 | evaluate(classifier, args) 71 | 72 | if args.save is not None: 73 | classifier.save(args.save) 74 | 75 | 76 | if __name__ == '__main__': 77 | args = parse_arguments() 78 | eval(args) -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/__init__.py: -------------------------------------------------------------------------------- 1 | from .openai_imagenet_template import openai_imagenet_template 2 | from .simple_template import simple_template 3 | from .fmow_template import fmow_template 4 | from .iwildcam_template import iwildcam_template 5 | from .transfer_ds_template import * -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/fmow_template.py: -------------------------------------------------------------------------------- 1 | from .utils import append_proper_article, get_plural 2 | 3 | fmow_template = [ 4 | lambda c : f"satellite photo of a {c}.", 5 | lambda c : f"aerial photo of a {c}.", 6 | lambda c : f"satellite photo of {append_proper_article(c)}.", 7 | lambda c : f"aerial photo of {append_proper_article(c)}.", 8 | lambda c : f"satellite photo of a {c} in asia.", 9 | lambda c : f"aerial photo of a {c} in asia.", 10 | lambda c : f"satellite photo of a {c} in africa.", 11 | lambda c : f"aerial photo of a {c} in africa.", 12 | lambda c : f"satellite photo of a {c} in the americas.", 13 | lambda c : f"aerial photo of a {c} in the americas.", 14 | lambda c : f"satellite photo of a {c} in europe.", 15 | lambda c : f"aerial photo of a {c} in europe.", 16 | lambda c : f"satellite photo of a {c} in oceania.", 17 | lambda c : f"aerial photo of a {c} in oceania.", 18 | lambda c: f"a photo of a {c}.", 19 | lambda c: f"{c}.", 20 | ] 21 | -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/iwildcam_template.py: -------------------------------------------------------------------------------- 1 | from .utils import append_proper_article, get_plural 2 | 3 | iwildcam_template = [ 4 | lambda c: f"a photo of {c}.", 5 | lambda c: f"{c} in the wild.", 6 | ] -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/openai_imagenet_template.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | openai_imagenet_template = [ 5 | lambda c: f'a bad photo of a {c}.', 6 | lambda c: f'a photo of many {c}.', 7 | lambda c: f'a sculpture of a {c}.', 8 | lambda c: f'a photo of the hard to see {c}.', 9 | lambda c: f'a low resolution photo of the {c}.', 10 | lambda c: f'a rendering of a {c}.', 11 | lambda c: f'graffiti of a {c}.', 12 | lambda c: f'a bad photo of the {c}.', 13 | lambda c: f'a cropped photo of the {c}.', 14 
| lambda c: f'a tattoo of a {c}.', 15 | lambda c: f'the embroidered {c}.', 16 | lambda c: f'a photo of a hard to see {c}.', 17 | lambda c: f'a bright photo of a {c}.', 18 | lambda c: f'a photo of a clean {c}.', 19 | lambda c: f'a photo of a dirty {c}.', 20 | lambda c: f'a dark photo of the {c}.', 21 | lambda c: f'a drawing of a {c}.', 22 | lambda c: f'a photo of my {c}.', 23 | lambda c: f'the plastic {c}.', 24 | lambda c: f'a photo of the cool {c}.', 25 | lambda c: f'a close-up photo of a {c}.', 26 | lambda c: f'a black and white photo of the {c}.', 27 | lambda c: f'a painting of the {c}.', 28 | lambda c: f'a painting of a {c}.', 29 | lambda c: f'a pixelated photo of the {c}.', 30 | lambda c: f'a sculpture of the {c}.', 31 | lambda c: f'a bright photo of the {c}.', 32 | lambda c: f'a cropped photo of a {c}.', 33 | lambda c: f'a plastic {c}.', 34 | lambda c: f'a photo of the dirty {c}.', 35 | lambda c: f'a jpeg corrupted photo of a {c}.', 36 | lambda c: f'a blurry photo of the {c}.', 37 | lambda c: f'a photo of the {c}.', 38 | lambda c: f'a good photo of the {c}.', 39 | lambda c: f'a rendering of the {c}.', 40 | lambda c: f'a {c} in a video game.', 41 | lambda c: f'a photo of one {c}.', 42 | lambda c: f'a doodle of a {c}.', 43 | lambda c: f'a close-up photo of the {c}.', 44 | lambda c: f'a photo of a {c}.', 45 | lambda c: f'the origami {c}.', 46 | lambda c: f'the {c} in a video game.', 47 | lambda c: f'a sketch of a {c}.', 48 | lambda c: f'a doodle of the {c}.', 49 | lambda c: f'a origami {c}.', 50 | lambda c: f'a low resolution photo of a {c}.', 51 | lambda c: f'the toy {c}.', 52 | lambda c: f'a rendition of the {c}.', 53 | lambda c: f'a photo of the clean {c}.', 54 | lambda c: f'a photo of a large {c}.', 55 | lambda c: f'a rendition of a {c}.', 56 | lambda c: f'a photo of a nice {c}.', 57 | lambda c: f'a photo of a weird {c}.', 58 | lambda c: f'a blurry photo of a {c}.', 59 | lambda c: f'a cartoon {c}.', 60 | lambda c: f'art of a {c}.', 61 | lambda c: f'a sketch of the {c}.', 62 | lambda c: f'a embroidered {c}.', 63 | lambda c: f'a pixelated photo of a {c}.', 64 | lambda c: f'itap of the {c}.', 65 | lambda c: f'a jpeg corrupted photo of the {c}.', 66 | lambda c: f'a good photo of a {c}.', 67 | lambda c: f'a plushie {c}.', 68 | lambda c: f'a photo of the nice {c}.', 69 | lambda c: f'a photo of the small {c}.', 70 | lambda c: f'a photo of the weird {c}.', 71 | lambda c: f'the cartoon {c}.', 72 | lambda c: f'art of the {c}.', 73 | lambda c: f'a drawing of the {c}.', 74 | lambda c: f'a photo of the large {c}.', 75 | lambda c: f'a black and white photo of a {c}.', 76 | lambda c: f'the plushie {c}.', 77 | lambda c: f'a dark photo of a {c}.', 78 | lambda c: f'itap of a {c}.', 79 | lambda c: f'graffiti of the {c}.', 80 | lambda c: f'a toy {c}.', 81 | lambda c: f'itap of my {c}.', 82 | lambda c: f'a photo of a cool {c}.', 83 | lambda c: f'a photo of a small {c}.', 84 | lambda c: f'a tattoo of the {c}.', 85 | ] -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/simple_template.py: -------------------------------------------------------------------------------- 1 | from src.templates.utils import append_proper_article 2 | 3 | simple_template = [ 4 | lambda c: f"a photo of a {c}." 5 | # lambda c: f"a sketch of a {c}." 
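# Note (annotation): each template is a callable mapping a class name to a caption;
# the zero-shot classifier builders (see zeroshot.py above) tokenize every caption,
# embed it with the CLIP text encoder, and average the normalized embeddings per class.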
6 | ] -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/transfer_ds_template.py: -------------------------------------------------------------------------------- 1 | from src.templates.utils import append_proper_article 2 | 3 | aircraft_template = [ 4 | lambda c: f"a photo of a {c}, a type of aircraft.", 5 | # lambda c: f"a photo of the {c}, a type of aircraft." 6 | ] 7 | 8 | birds_template = [ 9 | lambda c: f"a photo of a {c}, a type of bird." 10 | ] 11 | 12 | eurosat_template = [ 13 | lambda c: f"a centered satellite photo of {c}." 14 | ] 15 | 16 | # eurosat_template = [ 17 | # lambda c: f"a centered satellite photo of {c}.", 18 | # lambda c: f'a centered satellite photo of a {c}.', 19 | # lambda c: f'a centered satellite photo of the {c}.', 20 | # ] 21 | 22 | 23 | flowers_template = [ 24 | lambda c: f"a photo of a {c}, a type of flower." 25 | ] 26 | 27 | food_template = [ 28 | lambda c: f"a photo of a {c}, a type of food." 29 | ] 30 | 31 | pets_template = [ 32 | lambda c: f"a photo of a {c}, a type of pet." 33 | ] 34 | 35 | imagenet_template = [ 36 | lambda c: f"itap of a {c}.", 37 | lambda c: f"a bad photo of the {c}.", 38 | lambda c: f"a origami {c}.", 39 | lambda c: f"a photo of the large {c}.", 40 | lambda c: f"a {c} in a video game.", 41 | lambda c: f"art of the {c}.", 42 | lambda c: f"a photo of the small {c}."] 43 | 44 | cifar100_template = [ 45 | lambda c: f'a photo of a {c}.', 46 | lambda c: f'a blurry photo of a {c}.', 47 | lambda c: f'a black and white photo of a {c}.', 48 | lambda c: f'a low contrast photo of a {c}.', 49 | lambda c: f'a high contrast photo of a {c}.', 50 | lambda c: f'a bad photo of a {c}.', 51 | lambda c: f'a good photo of a {c}.', 52 | lambda c: f'a photo of a small {c}.', 53 | lambda c: f'a photo of a big {c}.', 54 | lambda c: f'a photo of the {c}.', 55 | lambda c: f'a blurry photo of the {c}.', 56 | lambda c: f'a black and white photo of the {c}.', 57 | lambda c: f'a low contrast photo of the {c}.', 58 | lambda c: f'a high contrast photo of the {c}.', 59 | lambda c: f'a bad photo of the {c}.', 60 | lambda c: f'a good photo of the {c}.', 61 | lambda c: f'a photo of the small {c}.', 62 | lambda c: f'a photo of the big {c}.', 63 | ] 64 | 65 | cifar10_templates = [ 66 | lambda c: f'a photo of a {c}.', 67 | lambda c: f'a blurry photo of a {c}.', 68 | lambda c: f'a black and white photo of a {c}.', 69 | lambda c: f'a low contrast photo of a {c}.', 70 | lambda c: f'a high contrast photo of a {c}.', 71 | lambda c: f'a bad photo of a {c}.', 72 | lambda c: f'a good photo of a {c}.', 73 | lambda c: f'a photo of a small {c}.', 74 | lambda c: f'a photo of a big {c}.', 75 | lambda c: f'a photo of the {c}.', 76 | lambda c: f'a blurry photo of the {c}.', 77 | lambda c: f'a black and white photo of the {c}.', 78 | lambda c: f'a low contrast photo of the {c}.', 79 | lambda c: f'a high contrast photo of the {c}.', 80 | lambda c: f'a bad photo of the {c}.', 81 | lambda c: f'a good photo of the {c}.', 82 | lambda c: f'a photo of the small {c}.', 83 | lambda c: f'a photo of the big {c}.', 84 | ] 85 | 86 | sun_template = [ 87 | lambda c: f'a photo of a {c}.', 88 | lambda c: f'a photo of the {c}.', 89 | ] 90 | 91 | cars_template = [ 92 | lambda c: f'a photo of a {c}.', 93 | lambda c: f'a photo of the {c}.', 94 | lambda c: f'a photo of my {c}.', 95 | lambda c: f'i love my {c}!', 96 | lambda c: f'a photo of my dirty {c}.', 97 | lambda c: f'a photo of my clean {c}.', 98 | lambda c: f'a photo of my new {c}.', 99 | 
lambda c: f'a photo of my old {c}.', 100 | ] 101 | 102 | dtd_template = [ 103 | lambda c: f"{c} texture." 104 | ] 105 | 106 | # dtd_template = [ 107 | # lambda c: f'a photo of a {c} texture.', 108 | # lambda c: f'a photo of a {c} pattern.', 109 | # lambda c: f'a photo of a {c} thing.', 110 | # lambda c: f'a photo of a {c} object.', 111 | # lambda c: f'a photo of the {c} texture.', 112 | # lambda c: f'a photo of the {c} pattern.', 113 | # lambda c: f'a photo of the {c} thing.', 114 | # lambda c: f'a photo of the {c} object.', 115 | # ] -------------------------------------------------------------------------------- /src/classfier_tuning/src/templates/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def get_plural(name): 3 | name = name.replace('_', ' ') 4 | if name[-2:] == 'sh': 5 | name = name + 'es' 6 | elif name[-2:] == 'ch': 7 | name = name + 'es' 8 | elif name[-1:] == 'y': 9 | name = name[:-1] + 'ies' 10 | elif name[-1:] == 's': 11 | name = name + 'es' 12 | elif name[-1:] == 'x': 13 | name = name + 'es' 14 | elif name[-3:] == 'man': 15 | name = name[:-3] + 'men' 16 | elif name == 'mouse': 17 | name = 'mice' 18 | elif name[-1:] == 'f': 19 | name = name[:-1] + 'ves' 20 | else: 21 | name = name + 's' 22 | return name 23 | 24 | 25 | def append_proper_article(name): 26 | name = name.replace('_', ' ') 27 | if name[0] in 'aeiou': 28 | return 'an ' + name 29 | return 'a ' + name 30 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from absl import logging 4 | 5 | import tqdm 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | 10 | 11 | def center_crop(x, resolution): 12 | shape = tf.shape(x) 13 | h, w = shape[0], shape[1] 14 | size = tf.minimum(h, w) 15 | begin = tf.cast([h - size, w - size], tf.float32) / 2.0 16 | begin = tf.cast(begin, tf.int32) 17 | begin = tf.concat([begin, [0]], axis=0) # Add channel dimension. 
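# The slice below takes the largest centered square from the image;
# resize_with_pad then rescales that square to resolution x resolution
# using area interpolation with antialiasing.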
18 | x = tf.slice(x, begin, [size, size, 3]) 19 | x = tf.image.resize_with_pad( 20 | x, resolution, resolution, method='area', antialias=True) 21 | return x 22 | 23 | 24 | def load_data(ds, img_shape, resolution=32): 25 | if resolution <= 64: 26 | batch_size = 5000 27 | else: 28 | batch_size = 1000 29 | 30 | size = len(ds) 31 | logging.info('Dataset size: {}'.format(size)) 32 | if None in img_shape: 33 | x = np.zeros(shape=(size, resolution, resolution, 3), dtype=np.uint8) 34 | else: 35 | x = np.zeros( 36 | shape=(size, img_shape[0], img_shape[1], img_shape[2]), dtype=np.uint8) 37 | 38 | if None in img_shape: 39 | ds = ds.map(lambda x, y: (center_crop( 40 | x, resolution), y), tf.data.AUTOTUNE) 41 | ds = ds.batch(batch_size=batch_size) 42 | ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE) 43 | 44 | y_list = [] 45 | count = 0 46 | for x_batch, y_batch in tqdm.tqdm(tfds.as_numpy(ds), desc='Process the data'): 47 | num = x_batch.shape[0] 48 | x_processed = np.array(x_batch) 49 | x[count:count + num] = x_processed 50 | y_list.append(y_batch) 51 | count += num 52 | 53 | return x, np.concatenate(y_list, axis=0) 54 | 55 | 56 | def configure_dataloader(ds, batch_size, x_transform=None, y_transform=None, train=False, shuffle=False, seed=0, resolution=None): 57 | if y_transform is None: 58 | def y_transform(x): return x 59 | else: 60 | y_transform = y_transform 61 | 62 | ds = ds.cache() 63 | if train: 64 | ds = ds.repeat() 65 | if shuffle: 66 | ds = ds.shuffle(16 * batch_size, seed=seed) 67 | 68 | if resolution is not None: 69 | ds = ds.map(lambda x, y: (tf.clip_by_value( 70 | tf.image.resize(x, [resolution, resolution], 'bilinear'), 71 | 0, 255), y), tf.data.AUTOTUNE) 72 | 73 | ds = ds.map(lambda x, y: (tf.cast(x, tf.float32), y), tf.data.AUTOTUNE) 74 | 75 | if x_transform: 76 | ds = ds.map(lambda x, y: (x_transform( 77 | x), y_transform(y)), tf.data.AUTOTUNE) 78 | else: 79 | ds = ds.map(lambda x, y: (x, y_transform(y)), tf.data.AUTOTUNE) # keep images unchanged; apply only the label transform 80 | 81 | ds = ds.batch(batch_size=batch_size) 82 | ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE) 83 | return ds 84 | 85 | 86 | def get_dataset(config, return_raw=False, resolution=None, train_only=True): 87 | dataset_name = config.name 88 | data_dir = config.data_dir 89 | 90 | if dataset_name in ['imagenette']: 91 | split = ['train', 'validation'] 92 | else: 93 | split = ['train', 'test'] 94 | 95 | if resolution is None: 96 | if dataset_name in ['cifar10', 'cifar100']: 97 | resolution = 32 98 | elif dataset_name in ['stl10']: 99 | resolution = 96 100 | elif dataset_name in ['imagenette']: 101 | resolution = 256 102 | 103 | ds_builder = tfds.builder(dataset_name, data_dir=data_dir) 104 | 105 | ds_builder.download_and_prepare() 106 | 107 | img_shape = ds_builder.info.features['image'].shape 108 | num_train, num_test = ds_builder.info.splits[split[0] 109 | ].num_examples, ds_builder.info.splits[split[1]].num_examples 110 | num_classes, class_names = ds_builder.info.features[ 111 | 'label'].num_classes, ds_builder.info.features['label'].names 112 | 113 | ds_train, ds_test = ds_builder.as_dataset(split=split, as_supervised=True) 114 | 115 | print('Number of training samples: {}'.format(num_train)) 116 | print('Number of test samples: {}'.format(num_test)) 117 | sys.stdout.flush() 118 | 119 | with config.unlocked(): 120 | config.img_shape = (resolution, resolution, 121 | 3) if None in img_shape else img_shape 122 | config.num_classes = num_classes 123 | config.class_names = class_names 124 | config.train_size = num_train 125 | config.test_size = num_test 126 | 127 | 
x_train, y_train = load_data(ds_train, img_shape, resolution) 128 | 129 | if train_only: 130 | return x_train, y_train 131 | 132 | x_test, y_test = load_data(ds_test, img_shape, resolution) 133 | 134 | if return_raw: 135 | return x_train, y_train, x_test, y_test 136 | else: 137 | ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) 138 | ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) 139 | return ds_train, ds_test -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .shufflenet import * 9 | from .shufflenetv2 import * 10 | from .resnet import * 11 | from .resnext import * 12 | from .preact_resnet import * 13 | from .mobilenet import * 14 | from .mobilenetv2 import * 15 | from .efficientnet import * 16 | from .regnet import * 17 | from .dla_simple import * 18 | from .dla import * 19 | -------------------------------------------------------------------------------- /src/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, 
num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | def test(): 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /src/models/dla.py: -------------------------------------------------------------------------------- 1 | '''DLA in PyTorch. 2 | 3 | Reference: 4 | Deep Layer Aggregation. https://arxiv.org/abs/1707.06484 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d( 17 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 20 | stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion*planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion*planes, 27 | kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Root(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size=1): 41 | super(Root, self).__init__() 42 | self.conv = nn.Conv2d( 43 | in_channels, out_channels, kernel_size, 44 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 45 | self.bn = nn.BatchNorm2d(out_channels) 46 | 47 | def forward(self, xs): 48 | x = torch.cat(xs, 1) 49 | out = F.relu(self.bn(self.conv(x))) 50 | return out 51 | 52 | 53 | class Tree(nn.Module): 54 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 55 | super(Tree, self).__init__() 56 | self.level = level 57 | if level == 1: 58 | self.root = Root(2*out_channels, out_channels) 59 | self.left_node = block(in_channels, out_channels, stride=stride) 60 | self.right_node = block(out_channels, out_channels, stride=1) 61 | else: 62 | self.root = Root((level+2)*out_channels, out_channels) 63 | for i in reversed(range(1, level)): 
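# For level > 1, recursively build child trees of decreasing depth and
# register each one as a 'level_<i>' attribute; forward() chains them and
# feeds every intermediate output into the Root node.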
64 | subtree = Tree(block, in_channels, out_channels, 65 | level=i, stride=stride) 66 | self.__setattr__('level_%d' % i, subtree) 67 | self.prev_root = block(in_channels, out_channels, stride=stride) 68 | self.left_node = block(out_channels, out_channels, stride=1) 69 | self.right_node = block(out_channels, out_channels, stride=1) 70 | 71 | def forward(self, x): 72 | xs = [self.prev_root(x)] if self.level > 1 else [] 73 | for i in reversed(range(1, self.level)): 74 | level_i = self.__getattr__('level_%d' % i) 75 | x = level_i(x) 76 | xs.append(x) 77 | x = self.left_node(x) 78 | xs.append(x) 79 | x = self.right_node(x) 80 | xs.append(x) 81 | out = self.root(xs) 82 | return out 83 | 84 | 85 | class DLA(nn.Module): 86 | def __init__(self, block=BasicBlock, num_classes=10): 87 | super(DLA, self).__init__() 88 | self.base = nn.Sequential( 89 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 90 | nn.BatchNorm2d(16), 91 | nn.ReLU(True) 92 | ) 93 | 94 | self.layer1 = nn.Sequential( 95 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 96 | nn.BatchNorm2d(16), 97 | nn.ReLU(True) 98 | ) 99 | 100 | self.layer2 = nn.Sequential( 101 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 102 | nn.BatchNorm2d(32), 103 | nn.ReLU(True) 104 | ) 105 | 106 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 107 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 108 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 109 | self.layer6 = Tree(block, 256, 512, level=1, stride=2) 110 | self.linear = nn.Linear(512, num_classes) 111 | 112 | def forward(self, x): 113 | out = self.base(x) 114 | out = self.layer1(out) 115 | out = self.layer2(out) 116 | out = self.layer3(out) 117 | out = self.layer4(out) 118 | out = self.layer5(out) 119 | out = self.layer6(out) 120 | out = F.avg_pool2d(out, 4) 121 | out = out.view(out.size(0), -1) 122 | out = self.linear(out) 123 | return out 124 | 125 | 126 | def test(): 127 | net = DLA() 128 | print(net) 129 | x = torch.randn(1, 3, 32, 32) 130 | y = net(x) 131 | print(y.size()) 132 | 133 | 134 | if __name__ == '__main__': 135 | test() 136 | -------------------------------------------------------------------------------- /src/models/dla_simple.py: -------------------------------------------------------------------------------- 1 | '''Simplified version of DLA in PyTorch. 2 | 3 | Note this implementation is not identical to the original paper version. 4 | But it seems to work fine. 5 | 6 | See dla.py for the original paper version. 7 | 8 | Reference: 9 | Deep Layer Aggregation. 
https://arxiv.org/abs/1707.06484 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 25 | stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | nn.Conv2d(in_planes, self.expansion*planes, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Root(nn.Module): 45 | def __init__(self, in_channels, out_channels, kernel_size=1): 46 | super(Root, self).__init__() 47 | self.conv = nn.Conv2d( 48 | in_channels, out_channels, kernel_size, 49 | stride=1, padding=(kernel_size - 1) // 2, bias=False) 50 | self.bn = nn.BatchNorm2d(out_channels) 51 | 52 | def forward(self, xs): 53 | x = torch.cat(xs, 1) 54 | out = F.relu(self.bn(self.conv(x))) 55 | return out 56 | 57 | 58 | class Tree(nn.Module): 59 | def __init__(self, block, in_channels, out_channels, level=1, stride=1): 60 | super(Tree, self).__init__() 61 | self.root = Root(2*out_channels, out_channels) 62 | if level == 1: 63 | self.left_tree = block(in_channels, out_channels, stride=stride) 64 | self.right_tree = block(out_channels, out_channels, stride=1) 65 | else: 66 | self.left_tree = Tree(block, in_channels, 67 | out_channels, level=level-1, stride=stride) 68 | self.right_tree = Tree(block, out_channels, 69 | out_channels, level=level-1, stride=1) 70 | 71 | def forward(self, x): 72 | out1 = self.left_tree(x) 73 | out2 = self.right_tree(out1) 74 | out = self.root([out1, out2]) 75 | return out 76 | 77 | 78 | class SimpleDLA(nn.Module): 79 | def __init__(self, block=BasicBlock, num_classes=10): 80 | super(SimpleDLA, self).__init__() 81 | self.base = nn.Sequential( 82 | nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False), 83 | nn.BatchNorm2d(16), 84 | nn.ReLU(True) 85 | ) 86 | 87 | self.layer1 = nn.Sequential( 88 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False), 89 | nn.BatchNorm2d(16), 90 | nn.ReLU(True) 91 | ) 92 | 93 | self.layer2 = nn.Sequential( 94 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False), 95 | nn.BatchNorm2d(32), 96 | nn.ReLU(True) 97 | ) 98 | 99 | self.layer3 = Tree(block, 32, 64, level=1, stride=1) 100 | self.layer4 = Tree(block, 64, 128, level=2, stride=2) 101 | self.layer5 = Tree(block, 128, 256, level=2, stride=2) 102 | self.layer6 = Tree(block, 256, 512, level=1, stride=2) 103 | self.linear = nn.Linear(512, num_classes) 104 | 105 | def forward(self, x): 106 | out = self.base(x) 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 | out = self.layer4(out) 111 | out = self.layer5(out) 112 | out = self.layer6(out) 113 | out = F.avg_pool2d(out, 4) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | 119 | def test(): 120 | net = SimpleDLA() 121 | print(net) 122 | x = torch.randn(1, 3, 32, 32) 123 | y = net(x) 
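# Sanity check: with the default num_classes=10, a 1x3x32x32 input should
# yield a logit tensor of shape (1, 10).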
124 | print(y.size()) 125 | 126 | 127 | if __name__ == '__main__': 128 | test() 129 | -------------------------------------------------------------------------------- /src/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': 
(3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /src/models/efficientnet.py: -------------------------------------------------------------------------------- 1 | '''EfficientNet in PyTorch. 2 | 3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def swish(x): 13 | return x * x.sigmoid() 14 | 15 | 16 | def drop_connect(x, drop_ratio): 17 | keep_ratio = 1.0 - drop_ratio 18 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 19 | mask.bernoulli_(keep_ratio) 20 | x.div_(keep_ratio) 21 | x.mul_(mask) 22 | return x 23 | 24 | 25 | class SE(nn.Module): 26 | '''Squeeze-and-Excitation block with Swish.''' 27 | 28 | def __init__(self, in_channels, se_channels): 29 | super(SE, self).__init__() 30 | self.se1 = nn.Conv2d(in_channels, se_channels, 31 | kernel_size=1, bias=True) 32 | self.se2 = nn.Conv2d(se_channels, in_channels, 33 | kernel_size=1, bias=True) 34 | 35 | def forward(self, x): 36 | out = F.adaptive_avg_pool2d(x, (1, 1)) 37 | out = swish(self.se1(out)) 38 | out = self.se2(out).sigmoid() 39 | out = x * out 40 | return out 41 | 42 | 43 | class Block(nn.Module): 44 | '''expansion + depthwise + pointwise + squeeze-excitation''' 45 | 46 | def __init__(self, 47 | in_channels, 48 | out_channels, 49 | kernel_size, 50 | stride, 51 | expand_ratio=1, 52 | se_ratio=0., 53 | drop_rate=0.): 54 | super(Block, self).__init__() 55 | self.stride = stride 56 | self.drop_rate = drop_rate 57 | self.expand_ratio = expand_ratio 58 | 59 | # Expansion 60 | channels = expand_ratio * in_channels 61 | self.conv1 = nn.Conv2d(in_channels, 62 | channels, 63 | kernel_size=1, 64 | stride=1, 65 | padding=0, 66 | bias=False) 67 | self.bn1 = nn.BatchNorm2d(channels) 68 | 69 | # Depthwise conv 70 | self.conv2 = nn.Conv2d(channels, 71 | channels, 72 | kernel_size=kernel_size, 73 | stride=stride, 74 | padding=(1 if kernel_size == 3 else 2), 75 | groups=channels, 76 | bias=False) 77 | self.bn2 = nn.BatchNorm2d(channels) 78 | 79 | # SE layers 80 | se_channels = int(in_channels * se_ratio) 81 | self.se = SE(channels, se_channels) 82 | 83 | # Output 84 | self.conv3 = nn.Conv2d(channels, 85 | out_channels, 86 | kernel_size=1, 87 | stride=1, 88 | padding=0, 89 | bias=False) 90 | self.bn3 = nn.BatchNorm2d(out_channels) 91 | 92 | # Skip connection if in and out shapes are the same (MV-V2 style) 93 | self.has_skip = (stride == 1) and (in_channels == out_channels) 94 | 95 | def forward(self, x): 96 | out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x))) 97 | out = swish(self.bn2(self.conv2(out))) 98 | out = self.se(out) 99 | out = self.bn3(self.conv3(out)) 100 | if self.has_skip: 101 | if self.training and self.drop_rate > 0: 102 | out = drop_connect(out, self.drop_rate) 103 | out = out + x 104 | return out 105 | 106 | 107 | class EfficientNet(nn.Module): 108 | def __init__(self, cfg, num_classes=10): 109 | super(EfficientNet, self).__init__() 110 | self.cfg = cfg 111 | self.conv1 = nn.Conv2d(3, 112 | 32, 113 | kernel_size=3, 114 | stride=1, 115 | padding=1, 116 | bias=False) 117 | self.bn1 = nn.BatchNorm2d(32) 118 | 
self.layers = self._make_layers(in_channels=32) 119 | self.linear = nn.Linear(cfg['out_channels'][-1], num_classes) 120 | 121 | def _make_layers(self, in_channels): 122 | layers = [] 123 | cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size', 124 | 'stride']] 125 | b = 0 126 | blocks = sum(self.cfg['num_blocks']) 127 | for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg): 128 | strides = [stride] + [1] * (num_blocks - 1) 129 | for stride in strides: 130 | drop_rate = self.cfg['drop_connect_rate'] * b / blocks 131 | layers.append( 132 | Block(in_channels, 133 | out_channels, 134 | kernel_size, 135 | stride, 136 | expansion, 137 | se_ratio=0.25, 138 | drop_rate=drop_rate)) 139 | in_channels = out_channels 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | out = swish(self.bn1(self.conv1(x))) 144 | out = self.layers(out) 145 | out = F.adaptive_avg_pool2d(out, 1) 146 | out = out.view(out.size(0), -1) 147 | dropout_rate = self.cfg['dropout_rate'] 148 | if self.training and dropout_rate > 0: 149 | out = F.dropout(out, p=dropout_rate) 150 | out = self.linear(out) 151 | return out 152 | 153 | 154 | def EfficientNetB0(num_classes=10): 155 | cfg = { 156 | 'num_blocks': [1, 2, 2, 3, 3, 4, 1], 157 | 'expansion': [1, 6, 6, 6, 6, 6, 6], 158 | 'out_channels': [16, 24, 40, 80, 112, 192, 320], 159 | 'kernel_size': [3, 3, 5, 3, 5, 5, 3], 160 | 'stride': [1, 2, 2, 2, 1, 2, 1], 161 | 'dropout_rate': 0.2, 162 | 'drop_connect_rate': 0.2, 163 | } 164 | return EfficientNet(cfg, num_classes=num_classes) 165 | 166 | 167 | def test(): 168 | net = EfficientNetB0() 169 | x = torch.randn(2, 3, 32, 32) 170 | y = net(x) 171 | print(y.shape) 172 | 173 | 174 | if __name__ == '__main__': 175 | test() 176 | -------------------------------------------------------------------------------- /src/models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 
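# CIFAR-style stem: a single 3x3 convolution keeps 32x32 inputs at full
# resolution, so the feature map is still 8x8 when it reaches the 8x8
# average pool at the end of forward().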
59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /src/models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /src/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
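Each Block below pairs a depthwise 3x3 convolution with a pointwise 1x1
convolution, i.e. the depthwise-separable factorization used in the paper.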
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /src/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
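This variant keeps stride 1 in the stem and uses adaptive average pooling so
that it also runs on CIFAR/STL-sized inputs (see the NOTE comments below).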
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | # (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 24, 2, 2), # NOTE: change for STL10 45 | (6, 32, 3, 2), 46 | (6, 64, 4, 2), 47 | (6, 96, 3, 1), 48 | (6, 160, 3, 2), 49 | (6, 320, 1, 1)] 50 | 51 | def __init__(self, num_classes=10): 52 | super(MobileNetV2, self).__init__() 53 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 54 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(32) 56 | self.layers = self._make_layers(in_planes=32) 57 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 58 | self.bn2 = nn.BatchNorm2d(1280) 59 | self.linear = nn.Linear(1280, num_classes) 60 | 61 | def _make_layers(self, in_planes): 62 | layers = [] 63 | for expansion, out_planes, num_blocks, stride in self.cfg: 64 | strides = [stride] + [1]*(num_blocks-1) 65 | for stride in strides: 66 | layers.append(Block(in_planes, out_planes, expansion, stride)) 67 | in_planes = out_planes 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = self.layers(out) 73 | out = F.relu(self.bn2(self.conv2(out))) 74 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 75 | # out = F.avg_pool2d(out, 4) 76 | out = F.adaptive_avg_pool2d(out, 1) # NOTE: change for STL10 77 | out = out.view(out.size(0), -1) 78 | out = self.linear(out) 79 | return out 80 | 81 | 82 | def test(): 83 | net = MobileNetV2() 84 | x = torch.randn(2,3,32,32) 85 | y = net(x) 86 | print(y.size()) 87 | 88 | # test() 89 | -------------------------------------------------------------------------------- /src/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 
2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 | # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | 
out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /src/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 
= self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /src/models/regnet.py: -------------------------------------------------------------------------------- 1 | '''RegNet in PyTorch. 2 | 3 | Paper: "Designing Network Design Spaces". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class SE(nn.Module): 13 | '''Squeeze-and-Excitation block.''' 14 | 15 | def __init__(self, in_planes, se_planes): 16 | super(SE, self).__init__() 17 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True) 18 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True) 19 | 20 | def forward(self, x): 21 | out = F.adaptive_avg_pool2d(x, (1, 1)) 22 | out = F.relu(self.se1(out)) 23 | out = self.se2(out).sigmoid() 24 | out = x * out 25 | return out 26 | 27 | 28 | class Block(nn.Module): 29 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio): 30 | super(Block, self).__init__() 31 | # 1x1 32 | w_b = int(round(w_out * bottleneck_ratio)) 33 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(w_b) 35 | # 3x3 36 | num_groups = w_b // group_width 37 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3, 38 | stride=stride, padding=1, groups=num_groups, bias=False) 39 | self.bn2 = nn.BatchNorm2d(w_b) 40 | # se 41 | self.with_se = se_ratio > 0 42 | if self.with_se: 43 | w_se = int(round(w_in * se_ratio)) 44 | self.se = SE(w_b, w_se) 45 | # 1x1 46 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(w_out) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or w_in != w_out: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(w_in, w_out, 53 | kernel_size=1, stride=stride, bias=False), 54 | nn.BatchNorm2d(w_out) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | if self.with_se: 61 | out = self.se(out) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class RegNet(nn.Module): 69 | def 
__init__(self, cfg, num_classes=10): 70 | super(RegNet, self).__init__() 71 | self.cfg = cfg 72 | self.in_planes = 64 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 74 | stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(0) 77 | self.layer2 = self._make_layer(1) 78 | self.layer3 = self._make_layer(2) 79 | self.layer4 = self._make_layer(3) 80 | self.linear = nn.Linear(self.cfg['widths'][-1], num_classes) 81 | 82 | def _make_layer(self, idx): 83 | depth = self.cfg['depths'][idx] 84 | width = self.cfg['widths'][idx] 85 | stride = self.cfg['strides'][idx] 86 | group_width = self.cfg['group_width'] 87 | bottleneck_ratio = self.cfg['bottleneck_ratio'] 88 | se_ratio = self.cfg['se_ratio'] 89 | 90 | layers = [] 91 | for i in range(depth): 92 | s = stride if i == 0 else 1 93 | layers.append(Block(self.in_planes, width, 94 | s, group_width, bottleneck_ratio, se_ratio)) 95 | self.in_planes = width 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | out = F.relu(self.bn1(self.conv1(x))) 100 | out = self.layer1(out) 101 | out = self.layer2(out) 102 | out = self.layer3(out) 103 | out = self.layer4(out) 104 | out = F.adaptive_avg_pool2d(out, (1, 1)) 105 | out = out.view(out.size(0), -1) 106 | out = self.linear(out) 107 | return out 108 | 109 | 110 | def RegNetX_200MF(): 111 | cfg = { 112 | 'depths': [1, 1, 4, 7], 113 | 'widths': [24, 56, 152, 368], 114 | 'strides': [1, 1, 2, 2], 115 | 'group_width': 8, 116 | 'bottleneck_ratio': 1, 117 | 'se_ratio': 0, 118 | } 119 | return RegNet(cfg) 120 | 121 | 122 | def RegNetX_400MF(): 123 | cfg = { 124 | 'depths': [1, 2, 7, 12], 125 | 'widths': [32, 64, 160, 384], 126 | 'strides': [1, 1, 2, 2], 127 | 'group_width': 16, 128 | 'bottleneck_ratio': 1, 129 | 'se_ratio': 0, 130 | } 131 | return RegNet(cfg) 132 | 133 | 134 | def RegNetY_400MF(): 135 | cfg = { 136 | 'depths': [1, 2, 7, 12], 137 | 'widths': [32, 64, 160, 384], 138 | 'strides': [1, 1, 2, 2], 139 | 'group_width': 16, 140 | 'bottleneck_ratio': 1, 141 | 'se_ratio': 0.25, 142 | } 143 | return RegNet(cfg) 144 | 145 | 146 | def test(): 147 | net = RegNetX_200MF() 148 | print(net) 149 | x = torch.randn(2, 3, 32, 32) 150 | y = net(x) 151 | print(y.shape) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion*planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=10, resolution=32): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | self.resolution = resolution 78 | 79 | print(">>>>> ResNet resolution: ", resolution, " <<<<<") 80 | if resolution == 32: 81 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 82 | stride=1, padding=1, bias=False) 83 | elif resolution == 28: 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 85 | stride=1, padding=3, bias=False) 86 | elif resolution == 96: 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 88 | stride=2, padding=3, bias=False) 89 | elif resolution == 256: 90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 91 | stride=2, padding=3, bias=False) 92 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 93 | 94 | self.bn1 = nn.BatchNorm2d(64) 95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 100 | self.linear = nn.Linear(512*block.expansion, num_classes) 101 | 102 | def _make_layer(self, block, planes, num_blocks, stride): 103 | strides = [stride] + [1]*(num_blocks-1) 104 | layers = [] 105 | for stride in strides: 106 | layers.append(block(self.in_planes, planes, 
stride)) 107 | self.in_planes = planes * block.expansion 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | out = F.relu(self.bn1(self.conv1(x))) 112 | if self.resolution==256: 113 | out = self.maxpool(out) 114 | out = self.layer1(out) 115 | out = self.layer2(out) 116 | out = self.layer3(out) 117 | out = self.layer4(out) 118 | out = self.avgpool(out) 119 | out = out.view(out.size(0), -1) 120 | out = self.linear(out) 121 | return out 122 | 123 | def ResNet18(num_classes=10, resolution=32): 124 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, resolution=resolution) 125 | 126 | 127 | def ResNet34(num_classes=10, resolution=32): 128 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution) 129 | 130 | 131 | def ResNet50(num_classes=10, resolution=32): 132 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution) 133 | 134 | 135 | def ResNet101(num_classes=10, resolution=32): 136 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, resolution=resolution) 137 | 138 | 139 | def ResNet152(num_classes=10, resolution=32): 140 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, resolution=resolution) 141 | 142 | 143 | def test(): 144 | net = ResNet18() 145 | y = net(torch.randn(1, 3, 32, 32)) 146 | print(y.size()) 147 | 148 | # test() 149 | -------------------------------------------------------------------------------- /src/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 
2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def _make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /src/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
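# Note: avg_pool2d over the full feature map gives `w` shape (N, C, 1, 1); fc1/fc2 form a
# 16x channel-reduction bottleneck, and the sigmoid gate rescales each channel of `out`
# (broadcast over the spatial dimensions) before the residual addition below.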
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /src/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]''' 17 | N,C,H,W = x.size() 18 | g = self.groups 19 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride, groups): 24 | super(Bottleneck, self).__init__() 25 | self.stride = stride 26 | 27 | mid_planes = out_planes // 4  # integer division keeps channel counts as ints under Python 3 28 | g = 1 if in_planes==24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 48 | return out 49 | 50 | 51 | class ShuffleNet(nn.Module): 52 | def __init__(self, cfg): 53 | super(ShuffleNet, self).__init__() 54 | out_planes = cfg['out_planes'] 55 | num_blocks = cfg['num_blocks'] 56 | groups = cfg['groups'] 57 | 58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': [240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | -------------------------------------------------------------------------------- /src/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 |
'''ShuffleNetV2 in PyTorch. 2 | 3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5): 34 | super(BasicBlock, self).__init__() 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | out = F.relu(self.bn3(self.conv3(out))) 53 | out = torch.cat([x1, out], 1) 54 | out = self.shuffle(out) 55 | return out 56 | 57 | 58 | class DownBlock(nn.Module): 59 | def __init__(self, in_channels, out_channels): 60 | super(DownBlock, self).__init__() 61 | mid_channels = out_channels // 2 62 | # left 63 | self.conv1 = nn.Conv2d(in_channels, in_channels, 64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 65 | self.bn1 = nn.BatchNorm2d(in_channels) 66 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 67 | kernel_size=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(mid_channels) 69 | # right 70 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 71 | kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(mid_channels) 73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 75 | self.bn4 = nn.BatchNorm2d(mid_channels) 76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn5 = nn.BatchNorm2d(mid_channels) 79 | 80 | self.shuffle = ShuffleBlock() 81 | 82 | def forward(self, x): 83 | # left 84 | out1 = self.bn1(self.conv1(x)) 85 | out1 = F.relu(self.bn2(self.conv2(out1))) 86 | # right 87 | out2 = F.relu(self.bn3(self.conv3(x))) 88 | out2 = self.bn4(self.conv4(out2)) 89 | out2 = F.relu(self.bn5(self.conv5(out2))) 90 | # concat 91 | out = torch.cat([out1, out2], 1) 92 | out = self.shuffle(out) 93 | return out 94 | 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__(self, net_size, num_classes=10): 98 | super(ShuffleNetV2, self).__init__() 99 | out_channels = configs[net_size]['out_channels'] 100 | num_blocks = configs[net_size]['num_blocks'] 101 | 102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 103 | stride=1, padding=1, bias=False) 104 | self.bn1 = 
nn.BatchNorm2d(24) 105 | self.in_channels = 24 106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 110 | kernel_size=1, stride=1, padding=0, bias=False) 111 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 112 | self.linear = nn.Linear(out_channels[3], num_classes) 113 | 114 | def _make_layer(self, out_channels, num_blocks): 115 | layers = [DownBlock(self.in_channels, out_channels)] 116 | for i in range(num_blocks): 117 | layers.append(BasicBlock(out_channels)) 118 | self.in_channels = out_channels 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | out = F.max_pool2d(out, 3, stride=2, padding=1) 124 | out = self.layer1(out) 125 | out = self.layer2(out) 126 | out = self.layer3(out) 127 | out = F.relu(self.bn2(self.conv2(out))) 128 | out = F.adaptive_avg_pool2d(out, 1) # NOTE: change for STL10 129 | out = out.view(out.size(0), -1) 130 | out = self.linear(out) 131 | return out 132 | 133 | 134 | configs = { 135 | 0.5: { 136 | 'out_channels': (48, 96, 192, 1024), 137 | 'num_blocks': (3, 7, 3) 138 | }, 139 | 140 | 1: { 141 | 'out_channels': (116, 232, 464, 1024), 142 | 'num_blocks': (3, 7, 3) 143 | }, 144 | 1.5: { 145 | 'out_channels': (176, 352, 704, 1024), 146 | 'num_blocks': (3, 7, 3) 147 | }, 148 | 2: { 149 | 'out_channels': (224, 488, 976, 2048), 150 | 'num_blocks': (3, 7, 3) 151 | } 152 | } 153 | 154 | 155 | def test(): 156 | net = ShuffleNetV2(net_size=0.5) 157 | x = torch.randn(3, 3, 32, 32) 158 | y = net(x) 159 | print(y.shape) 160 | 161 | 162 | # test() 163 | -------------------------------------------------------------------------------- /src/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name, num_classes=10): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, num_classes) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AdaptiveAvgPool2d((1, 1))] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /src/script/clip_retrieval_stl10.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import json 5 | import fire 6 | import numpy as np 7 | import pandas as pd 8 | import tqdm 9 | from clip_retrieval.clip_client import ClipClient 10 | 11 | def get_urls(img_folder, class_idx, k=45): 12 | # for each image, retrieve the top-k nearest-neighbor urls from the LAION index 13 | print(f'Number of images per query: {k}') 14 | client = ClipClient(url="https://knn.laion.ai/knn-service", 15 | indice_name="laion5B-L-14", num_images=k, 16 | use_safety_model=False, use_violence_detector=False, deduplicate=False) 17 | 18 | url_map = {} 19 | for name in tqdm.tqdm(os.listdir(f'{img_folder}/class_{class_idx:03d}')): 20 | img_path = f'{img_folder}/class_{class_idx:03d}/{name}' 21 | assert os.path.exists(img_path) 22 | try: 23 | results = client.query(image=img_path) 24 | url_map[name] = {'url': [ele['url'] for ele in results], 'id': [ele['id'] for ele in results]} 25 | except Exception: 26 | print('Error: class {} {}'.format(class_idx, name)) 27 | # print(url_map[name]) 28 | return url_map 29 | 30 | 31 | def url_analysis(url_map, key='id'): 32 | # count the unique urls 33 | unique_urls = set() 34 | 35 | for name, ele in url_map.items(): 36 | for url in ele[key]: 37 | unique_urls.add(url) 38 | 39 | print('Unique urls:') 40 | print(len(unique_urls)) 41 | 42 | # count occurrences of each url 43 | occurrences = {} 44 | for url in unique_urls: 45 | count = 0 46 | for name, ele in url_map.items(): 47 | if url in ele[key]: 48 | count += 1 49 | occurrences[url] = count 50 | # print('Occurrences:') 51 | # print(occurrences) 52 | 53 | # compute the average number of occurrences 54 | mean = sum(occurrences.values()) / len(occurrences) 55 | print('Mean occurrences:') 56 | print(mean) 57 | # A high mean indicates that many query images share the same retrieved neighbors, likely because the CLIP embedding does not capture the discriminative information of the images on the target domain.
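# An equivalent, faster tally of the occurrence statistic above (a sketch, assuming the same
# url_map layout {image_name: {'url': [...], 'id': [...]}} produced by get_urls):
#   from collections import Counter
#   occurrences = Counter(u for ele in url_map.values() for u in set(ele['id']))
#   mean = sum(occurrences.values()) / len(occurrences)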
58 | 59 | 60 | def main(class_idx, k=45): 61 | img_folder = "/ssd005/projects/diffusion_inversion/real_data/stl10/scaling/res96_bicubic" 62 | output_dir = f'clip_retrieval/stl10' 63 | if not os.path.exists(output_dir): 64 | os.makedirs(output_dir) 65 | 66 | # url_map = get_urls(img_folder, class_idx=class_idx, k=k) 67 | 68 | # with open(f"{output_dir}/cls{class_idx}_k{k}.json", "w") as f: 69 | # json.dump(url_map, f) 70 | 71 | with open(f"{output_dir}/cls{class_idx}_k{k}.json") as f: 72 | url_map = json.load(f) 73 | 74 | url_analysis(url_map) 75 | 76 | unique_urls = set() 77 | res = [] 78 | 79 | for name, ele in url_map.items(): 80 | for url in ele['url']: 81 | if url not in unique_urls: 82 | res.append({'url': url}) 83 | unique_urls.add(url) 84 | 85 | print('Unique urls:') 86 | print(len(unique_urls)) 87 | 88 | with open(f"{output_dir}/cls{class_idx}_k{k}_unique.json", mode="w") as f: 89 | json.dump(res, f) 90 | 91 | # image_folder = f"/ssd005/projects/diffusion_inversion/inversion_data/stl10/knn_retrieval/k{k}/class_{class_idx:03d}" 92 | # os.system( 93 | # f'img2dataset --input_format=json --url_list={output_dir}/cls{class_idx}_k{k}_unique.json --output_folder={image_folder} --thread_count=64 --image_size=96 --output_format=files') 94 | 95 | image_folder = f"/ssd005/projects/diffusion_inversion/inversion_data/stl10/knn_retrieval/k{k}_unresize/class_{class_idx:03d}" 96 | os.system( 97 | f'img2dataset --input_format=json --url_list={output_dir}/cls{class_idx}_k{k}_unique.json --output_folder={image_folder} --thread_count=64 --output_format=files') 98 | 99 | if __name__ == "__main__": 100 | fire.Fire(main) 101 | # python clip_retrieval_stl10.py --class_idx=0 --k=10 102 | # img2dataset --input_format=json --url_list=urls_10.json --output_folder=/ssd005/projects/diffusion_inversion/inversion_data/stl10/knn_retrieval/ --thread_count=64 --image_size=96 --output_format=files 103 | 104 | # https://github.com/rom1504/img2dataset 105 | -------------------------------------------------------------------------------- /src/script/compute_statistics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tqdm 4 | import fire 5 | import functools 6 | import ml_collections 7 | from absl import logging 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import tensorflow_datasets as tfds 12 | 13 | 14 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 15 | from dataset import get_dataset # NOQA 16 | 17 | tf.config.experimental.set_visible_devices([], "GPU") 18 | 19 | 20 | def D(**kwargs): 21 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 22 | 23 | 24 | def center_crop(x, resolution): 25 | shape = tf.shape(x) 26 | h, w = shape[0], shape[1] 27 | size = tf.minimum(h, w) 28 | begin = tf.cast([h - size, w - size], tf.float32) / 2.0 29 | begin = tf.cast(begin, tf.int32) 30 | begin = tf.concat([begin, [0]], axis=0) # Add channel dimension. 
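# Together with the slice/resize below, this crops the largest centered square
# (size = min(h, w)) starting at `begin` and resizes it to (resolution, resolution)
# using area interpolation with antialiasing.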
31 | x = tf.slice(x, begin, [size, size, 3]) 32 | x = tf.image.resize_with_pad( 33 | x, resolution, resolution, method='area', antialias=True) 34 | return x 35 | 36 | 37 | def compute_channel_mean_std_ds(ds, img_shape, resolution=32, batch_size=1000): 38 | if None in img_shape: 39 | dim = resolution * resolution 40 | else: 41 | dim = functools.reduce(lambda x, y: x * y, img_shape[:-1], 1) 42 | 43 | # ds = ds.map(lambda x, y: tf.cast( 44 | # x, dtype='float32') / 255.0, tf.data.AUTOTUNE) 45 | ds = ds.map(lambda x, y: tf.cast( 46 | x, dtype='float32'), tf.data.AUTOTUNE) 47 | if None in img_shape: 48 | ds = ds.map(lambda x: center_crop(x, resolution), tf.data.AUTOTUNE) 49 | ds = ds.map(lambda x: tf.reshape( 50 | x, shape=(dim, img_shape[-1])), tf.data.AUTOTUNE) 51 | ds = ds.batch(batch_size=batch_size) 52 | ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE) 53 | 54 | mean = np.zeros(shape=(img_shape[-1],)) 55 | var = np.zeros(shape=(img_shape[-1],)) 56 | count = 0 57 | 58 | for x_batch in tqdm.tqdm(tfds.as_numpy(ds), desc='Compute mean with batch size: {}'.format(batch_size)): 59 | mean = mean + np.sum(x_batch, axis=(0, 1)) 60 | count += x_batch.shape[0] 61 | 62 | mean = 1.0 / (count * dim) * mean 63 | 64 | for x_batch in tqdm.tqdm(tfds.as_numpy(ds), desc='Compute variance with batch size: {}'.format(batch_size)): 65 | var = var + np.sum(np.square(x_batch - mean), axis=(0, 1)) 66 | 67 | std = np.sqrt(1.0 / (count * dim) * var) 68 | 69 | logging.info( 70 | 'Total number of data: {}, mean: {}, std: {}'.format(count, mean, std)) 71 | 72 | return mean, std 73 | 74 | 75 | def main(dataset_name, data_dir, resolution, batch_size): 76 | dataset_config = D( 77 | name=dataset_name, 78 | data_dir=data_dir) 79 | 80 | x_train, y_train = get_dataset( 81 | dataset_config, return_raw=True, resolution=resolution, train_only=True) 82 | 83 | ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) 84 | mean, std = compute_channel_mean_std_ds( 85 | ds, (resolution, resolution, 3), resolution=resolution, batch_size=batch_size) 86 | print('Mean: {}, Std: {}'.format(mean, std)) 87 | 88 | 89 | if __name__ == "__main__": 90 | fire.Fire(main) 91 | # python compute_statistics.py --dataset_name=imagenette --data_dir=$HOME/tensorflow_datasets --resolution=256 --batch_size=1000 -------------------------------------------------------------------------------- /src/script/convert_vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import fire 4 | import ml_collections 5 | import torch 6 | import tensorflow as tf 7 | import numpy as np 8 | import PIL 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from torch.utils.data import Dataset 13 | 14 | from diffusers import AutoencoderKL 15 | 16 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 17 | from dataset import get_dataset # NOQA 18 | 19 | tf.config.experimental.set_visible_devices([], "GPU") 20 | 21 | 22 | PIL_INTERPOLATION = { 23 | "linear": PIL.Image.Resampling.BILINEAR, 24 | "bilinear": PIL.Image.Resampling.BILINEAR, 25 | "bicubic": PIL.Image.Resampling.BICUBIC, 26 | "lanczos": PIL.Image.Resampling.LANCZOS, 27 | "nearest": PIL.Image.Resampling.NEAREST, 28 | } 29 | 30 | 31 | def D(**kwargs): 32 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 33 | 34 | 35 | class EmbDataset(Dataset): 36 | def __init__(self, x, y, size=32, interpolation=Image.BICUBIC): 37 | self.x = x 38 | self.y = y 39 | self.size = size 40 | self.interpolation = interpolation 41 | 42 | def 
__len__(self): 43 | return self.x.shape[0] 44 | 45 | def __getitem__(self, idx): 46 | img = self.x[idx].astype(np.uint8) 47 | image = Image.fromarray(img) 48 | image = image.resize((self.size, self.size), 49 | resample=self.interpolation) 50 | image = np.array(image).astype(np.float32) 51 | image = (image / 127.5 - 1.0) 52 | image = torch.from_numpy(image).permute(2, 0, 1) 53 | 54 | return image, self.y[idx] 55 | 56 | 57 | def main(dataset_name, root_name, split='train', group_size=100, batch_size=50, sampling_resolution=128, 58 | interpolation='bicubic', save_resolution=32, num_classes=10): 59 | 60 | device = "cuda" 61 | dataset_config = D( 62 | name=dataset_name, 63 | data_dir='~/tensorflow_datasets' 64 | ) 65 | x_train, y_train, x_test, y_test = get_dataset( 66 | dataset_config, return_raw=True, train_only=False) 67 | 68 | root_name = os.path.join(root_name, f'{dataset_name}_{split}') 69 | model_id = "CompVis/stable-diffusion-v1-4" 70 | ae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device) 71 | 72 | config_name = f'{model_id.replace("/", "-")}_sample{sampling_resolution}_{interpolation}_save{save_resolution}' 73 | 74 | for i in range(num_classes): 75 | os.makedirs(f'{root_name}/class_{i:03d}/{config_name}', exist_ok=True) 76 | 77 | if split == 'train': 78 | dataset = EmbDataset(x_train, y_train, size=sampling_resolution, 79 | interpolation=PIL_INTERPOLATION[interpolation]) 80 | elif split == 'test': 81 | dataset = EmbDataset(x_test, y_test, size=sampling_resolution, 82 | interpolation=PIL_INTERPOLATION[interpolation]) 83 | else: 84 | raise ValueError(f'Unknown split {split}') 85 | 86 | dataloader = torch.utils.data.DataLoader( 87 | dataset, batch_size=batch_size, shuffle=False, num_workers=4) 88 | 89 | for i, (img, lb) in tqdm(enumerate(dataloader)): 90 | with torch.no_grad(): 91 | latents = ae.encode(img.to(device)).latent_dist.mode() 92 | image = ae.decode(latents).sample 93 | image = (image / 2 + 0.5).clamp(0, 1) 94 | images = image.cpu().permute(0, 2, 3, 1).float().numpy() 95 | images = (images * 255).round().astype("uint8") 96 | images = [Image.fromarray(image) for image in images] 97 | 98 | for idx, img in enumerate(images): 99 | if sampling_resolution != save_resolution: 100 | img = img.resize((save_resolution, save_resolution), 101 | resample=PIL_INTERPOLATION[interpolation]) 102 | img.save( 103 | f'{root_name}/class_{lb[idx]:03d}/{config_name}/group{(i*batch_size+idx)//group_size:02d}_sample{(i*batch_size+idx):05d}.png') 104 | 105 | 106 | if __name__ == '__main__': 107 | fire.Fire(main) 108 | # python convert_vae.py --dataset_name=cifar10 --root_name=$HOME/inversion_data --batch_size=100 --interpolation=bicubic --save_resolution=128 --num_classes=10 --split=train --sampling_resolution=128 109 | -------------------------------------------------------------------------------- /src/script/split_euro_dataset.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | import os 4 | import sys 5 | import json 6 | import shutil 7 | 8 | def main(): 9 | input_dir = "$HOME/tensorflow_datasets/downloads/manual/eurosat/2750" 10 | output_dir = "$HOME/tensorflow_datasets/eurosat" 11 | 12 | with open("$HOME/Project/diffusion_inversion/src/split_zhou_EuroSAT.json") as f: 13 | data = json.load(f) 14 | 15 | for k, v in data.items(): 16 | print(k, len(v)) 17 | if not os.path.exists(output_dir): 18 | os.makedirs('{}/{}'.format(output_dir, k)) 19 | 20 | for k, v in data.items(): 21 | for i in v: 22 | if not 
os.path.exists(os.path.join(output_dir, k, i[0].split('/')[0])): 23 | os.makedirs(os.path.join(output_dir, k, i[0].split('/')[0])) 24 | # print(os.path.join(input_dir, i[0]), os.path.join(output_dir, k, i[0])) 25 | shutil.copy(os.path.join(input_dir, i[0]), os.path.join(output_dir, k, i[0])) 26 | 27 | if __name__ == '__main__': 28 | fire.Fire(main) --------------------------------------------------------------------------------