├── spawrious
│   ├── __init__.py
│   ├── tf.py
│   └── torch.py
├── .gitignore
├── overview.png
├── requirements.txt
├── twitter_gif_preview_m2m_HQ.gif
├── twitter_gif_preview_o2o_HQ.gif
├── setup.py
├── README.md
├── LICENSE
├── example.py
└── generate_dataset.py

/spawrious/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | *pycache*
3 | .idea/
4 | .DS_Store
5 | 
6 | 
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aengusl/spawrious/HEAD/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow==9.5.0
2 | torch==1.13.1
3 | torchvision==0.14.1
4 | tqdm==4.64.1
--------------------------------------------------------------------------------
/twitter_gif_preview_m2m_HQ.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aengusl/spawrious/HEAD/twitter_gif_preview_m2m_HQ.gif
--------------------------------------------------------------------------------
/twitter_gif_preview_o2o_HQ.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aengusl/spawrious/HEAD/twitter_gif_preview_o2o_HQ.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name="spawrious",
5 |     version="0.1.0",
6 |     description="The Spawrious out-of-distribution image classification benchmark",
7 |     author="Aengus Lynch et al",
8 |     author_email="aengus.lynch.17@ucl.ac.uk",
9 |     url="https://github.com/aengusl/spawrious",
10 |     packages=find_packages(),
11 |     license="CC-BY-SA 4.0",
12 |     classifiers=[
13 |         "Programming Language :: Python :: 3",
14 |         "License :: OSI Approved :: Creative Commons Attribution-Share Alike 4.0 International License",
15 |         "Operating System :: OS Independent",
16 |     ],
17 | )
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spawrious
2 | 
3 | [Leaderboard results](https://aengusl.github.io/spawrious.github.io/)
4 | 
5 | ## One-to-one Spurious Correlations
6 | ![gif](twitter_gif_preview_o2o_HQ.gif)
7 | 
8 | ## Many-to-many Spurious Correlations
9 | ![gif](twitter_gif_preview_m2m_HQ.gif)
10 | 
11 | Spawrious is a challenging OOD image classification benchmark ([link to paper](https://arxiv.org/abs/2303.05470)). It consists of 6 separate OOD challenges, split into two types: one-to-one and many-to-many spurious correlation challenges.
12 | 
13 | The dataset contains images of 4 dog breeds, found in 6 locations. The entire dataset consists of ~152,000 images, but each challenge only requires a subset of it, so the repo lets users download only the minimal dataset required for a given Spawrious challenge. 
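The extracted data follows a fixed on-disk layout, which the loaders in `spawrious/torch.py` and `spawrious/tf.py` resolve paths against (inferred from `_check_images_availability`; the `domain_adaptation_ds` group ships only with the entire dataset):

```
<data_dir>/spawrious224/<group>/<location>/<breed>/*.png
# <group>:    0 or 1 (plus domain_adaptation_ds in the entire dataset)
# <location>: desert, jungle, dirt, mountain, snow, beach
# <breed>:    bulldog, corgi, dachshund, labrador
```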
14 | 
15 | ## Example script
16 | 
17 | Datasets take the following names:
18 | - `entire_dataset`
19 | - `o2o_easy`
20 | - `o2o_medium`
21 | - `o2o_hard`
22 | - `m2m_easy`
23 | - `m2m_medium`
24 | - `m2m_hard`
25 | 
26 | Running the command below fetches the requested dataset into a user-specified directory (downloading it first if it is not already available), trains a classifier (a frozen timm backbone with a linear head by default, or a [resnet18](https://pytorch.org/hub/pytorch_vision_resnet/) with `--model resnet18`), and evaluates the results on the OOD test set.
27 | 
28 | ```
29 | python example.py --data_dir <path/to/data> --dataset <dataset_name>
30 | ```
31 | 
32 | ## Installation
33 | ```
34 | pip install git+https://github.com/aengusl/spawrious.git
35 | ```
36 | 
37 | ## HParams
38 | 
39 | - [ResNet50](https://huggingface.co/datasets/aengusl/spawrious_resnet50_hparams_dict?row=0)
40 | - [ResNet18](https://huggingface.co/datasets/aengusl/spawrious_resnet18_hparams_dict)
41 | 
42 | 
43 | ## Using the datasets
44 | ```python
45 | import torch
46 | 
47 | from spawrious.torch import get_spawrious_dataset, set_model_name
48 | # for tensorflow or jax, use get_tensorflow_dataset from spawrious.tf
49 | 
50 | dataset = "m2m_medium"
51 | data_dir = "./data/"
52 | val_split = 0.2
53 | 
54 | # the torch loader resolves its input transforms from a timm backbone,
55 | # so a model name must be set before the dataset is built
56 | set_model_name("resnet18")
57 | 
58 | spawrious = get_spawrious_dataset(dataset_name=dataset, root_dir=data_dir)
59 | train_set = spawrious.get_train_dataset()
60 | test_set = spawrious.get_test_dataset()
61 | val_size = int(len(train_set) * val_split)
62 | train_set, val_set = torch.utils.data.random_split(
63 |     train_set, [len(train_set) - val_size, val_size]
64 | )
65 | ```
66 | 
67 | ### Click to download the datasets:
68 | - [entire_dataset](https://www.dropbox.com/s/e40j553480h3f3s/spawrious224.tar.gz?dl=1)
69 | - [one-to-one easy](https://www.dropbox.com/s/kwhiv60ihxe3owy/spawrious__o2o_easy.tar.gz?dl=1)
70 | - [one-to-one medium](https://www.dropbox.com/s/x03gkhdwar5kht4/spawrious224__o2o_medium.tar.gz?dl=1)
71 | - [one-to-one hard](https://www.dropbox.com/s/p1ry121m2gjj158/spawrious__o2o_hard.tar.gz?dl=1)
72 | - [many-to-many (all)](https://www.dropbox.com/s/5usem63nfub266y/spawrious__m2m.tar.gz?dl=1)
73 | 
74 | ## Generate your own data
75 | 
76 | If you want to generate your own data, or to understand how we generated ours, take a look at `generate_dataset.py`. To run this file, you additionally need to install `diffusers` and `transformers`.
77 | 
78 | ## Citation
79 | 
80 | ```
81 | @misc{lynch2023spawrious,
82 |       title={Spawrious: A Benchmark for Fine Control of Spurious Correlation Biases},
83 |       author={Aengus Lynch and Gbètondji J-S Dovonon and Jean Kaddour and Ricardo Silva},
84 |       year={2023},
85 |       eprint={2303.05470},
86 |       archivePrefix={arXiv},
87 |       primaryClass={cs.CV}
88 | }
89 | ```
90 | ## Licensing
91 | 
92 | Shield: [![CC BY 4.0][cc-by-shield]][cc-by]
93 | 
94 | This work is licensed under a
95 | [Creative Commons Attribution 4.0 International License][cc-by].
96 | 
97 | [![CC BY 4.0][cc-by-image]][cc-by]
98 | 
99 | [cc-by]: http://creativecommons.org/licenses/by/4.0/
100 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png
101 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Creative Commons Legal Code
2 | 
3 | CC0 1.0 Universal
4 | 
5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
7 | ATTORNEY-CLIENT RELATIONSHIP. 
CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. 
To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. 
Affirmer disclaims responsibility for clearing rights of other persons
114 |    that may apply to the Work or any use thereof, including without
115 |    limitation any person's Copyright and Related Rights in the Work.
116 |    Further, Affirmer disclaims responsibility for obtaining any necessary
117 |    consents, permissions or other rights required for any use of the
118 |    Work.
119 | d. Affirmer understands and acknowledges that Creative Commons is not a
120 |    party to this document and has no duty or obligation with respect to
121 |    this CC0 or use of the Work.
122 | 
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import timm
4 | import torch
5 | import torch.optim as optim
6 | import wandb
7 | from torch import nn
8 | from torch.nn import Module
9 | from torch.optim import Optimizer
10 | from torch.utils.data import DataLoader
11 | from torchvision import models
12 | from tqdm.auto import tqdm
13 | 
14 | from spawrious.torch import get_spawrious_dataset, set_model_name
15 | 
16 | # MODEL_NAME is assigned per run in the __main__ sweep below and mirrored into
17 | # spawrious.torch via set_model_name(); candidate backbones are listed there.
18 | MODEL_NAME = None
19 | 
20 | 
21 | def parse_args():
22 |     parser = argparse.ArgumentParser(
23 |         description="Train a classifier on a Spawrious dataset."
24 |     )
25 |     parser.add_argument(
26 |         "--dataset",
27 |         type=str,
28 |         default="m2m_hard",
29 |         help="name of the dataset",
30 |         choices=[
31 |             "o2o_easy",
32 |             "o2o_medium",
33 |             "o2o_hard",
34 |             "m2m_easy",
35 |             "m2m_medium",
36 |             "m2m_hard",
37 |         ],
38 |     )
39 |     parser.add_argument(
40 |         "--val_split",
41 |         type=float,
42 |         default=0.1,
43 |     )
44 | 
45 |     parser.add_argument(
46 |         "--data_dir", type=str, default="./data/", help="path to the dataset directory"
47 |     )
48 |     parser.add_argument("--batch_size", type=int, default=128, help="batch size")
49 |     parser.add_argument(
50 |         "--num_workers", type=int, default=2, help="number of workers for data loading"
51 |     )
52 |     parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
53 |     parser.add_argument("--momentum", type=float, default=0.9, help="momentum")
54 |     parser.add_argument("--num_epochs", type=int, default=3, help="number of epochs")
55 |     parser.add_argument("--model", type=str, default="siglip", help="model name")
56 |     return parser.parse_args()
57 | 
58 | 
59 | def train(
60 |     model: Module,
61 |     train_loader: DataLoader,
62 |     val_loader: DataLoader,
63 |     optimizer: Optimizer,
64 |     criterion: Module,
65 |     num_epochs: int,
66 |     device: torch.device,
67 | ) -> None:
68 |     for epoch in tqdm(range(num_epochs), desc="Training epochs", leave=False):
69 |         model.train()  # evaluate() switches the model to eval mode each epoch
70 |         running_loss = 0.0
71 |         for inputs, labels, _ in tqdm(train_loader):  # third item is the location label
72 |             inputs, labels = inputs.to(device), labels.to(device)
73 |             optimizer.zero_grad()
74 |             outputs = model(inputs)
75 |             loss = criterion(outputs, labels)
76 |             loss.backward()
77 |             optimizer.step()
78 |             running_loss += loss.item()
79 |         print(
80 |             f"Epoch {epoch + 1}: Training Loss: {running_loss / len(train_loader):.3f}"
81 |         )
82 |         print("Evaluating on validation set...")
83 |         val_acc = evaluate(model, val_loader, device)
84 |         wandb.log(
85 |             {"train_loss": running_loss / len(train_loader), "val_acc": val_acc},
86 |             step=epoch,
87 |         )
88 | 
89 | 
90 | def evaluate(model: Module, loader: DataLoader, device: torch.device) -> float:
91 |     model.eval()
92 |     correct = 0
93 |     total = 0
94 |     with torch.no_grad():
95 |         for inputs, labels, _ in tqdm(
96 |             loader, desc="Evaluating", leave=False
97 |         ):  # third item is the location label
98 |             inputs, labels = inputs.to(device), labels.to(device)
99 |             outputs = model(inputs)
100 |             _, predicted = torch.max(outputs.data, 1)
101 |             total += labels.size(0)
102 |             correct += (predicted == labels).sum().item()
103 |     acc = 100 * correct / total
104 |     print(f"Acc: {acc:.3f}%")
105 |     return acc
106 | 
107 | 
108 | class ClassifierOnTop(nn.Module):
109 |     """A frozen timm backbone with a trainable linear head."""
110 | 
111 |     def __init__(self, num_classes):
112 |         super().__init__()
113 |         self.backbone = timm.create_model(
114 |             MODEL_NAME,
115 |             pretrained=True,
116 |             num_classes=0,  # return pooled features instead of logits
117 |         ).eval()
118 |         # num_features is the backbone's embedding width, so the head matches
119 |         # whichever backbone MODEL_NAME selects
120 |         self.linear = nn.Linear(self.backbone.num_features, num_classes)
121 | 
122 |     def forward(self, x):
123 |         with torch.no_grad():
124 |             x = self.backbone(x)
125 |         return self.linear(x)
126 | 
127 | 
128 | def get_model(args: argparse.Namespace) -> Module:
129 |     if args.model == "siglip":
130 |         model = ClassifierOnTop(num_classes=4)
131 |     else:
132 |         model = models.resnet18(pretrained=True)
133 |         model.fc = torch.nn.Linear(512, 4)
134 |     return model
135 | 
136 | 
137 | def main(dataset: str) -> None:
138 |     args = parse_args()
139 |     args.dataset = dataset  # the __main__ sweep overrides the CLI choice
140 |     experiment_name = f"{dataset}_{MODEL_NAME.split('_')[0]}-e={args.num_epochs}-lr={args.lr}"
141 |     wandb.init(project="spawrious", name=experiment_name, config=args)
142 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
143 |     spawrious = get_spawrious_dataset(dataset_name=args.dataset, root_dir=args.data_dir)
144 |     train_set = spawrious.get_train_dataset()
145 |     test_set = spawrious.get_test_dataset()
146 |     val_size = int(len(train_set) * args.val_split)
147 |     train_set, val_set = torch.utils.data.random_split(
148 |         train_set, [len(train_set) - val_size, val_size]
149 |     )
150 |     train_loader = torch.utils.data.DataLoader(
151 |         train_set,
152 |         batch_size=args.batch_size,
153 |         shuffle=True,
154 |         num_workers=args.num_workers,
155 |     )
156 |     val_loader = torch.utils.data.DataLoader(
157 |         val_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
158 |     )
159 |     test_loader = torch.utils.data.DataLoader(
160 |         test_set,
161 |         batch_size=args.batch_size,
162 |         shuffle=False,
163 |         num_workers=args.num_workers,
164 |     )
165 | 
166 |     model = get_model(args)
167 |     model.to(device)
168 | 
169 |     criterion = nn.CrossEntropyLoss()
170 |     optimizer = optim.AdamW(model.parameters(), lr=args.lr)
171 |     train(
172 |         model,
173 |         train_loader,
174 |         val_loader,
175 |         optimizer,
176 |         criterion,
177 |         args.num_epochs,
178 |         device,
179 |     )
180 |     print("Finished training, now evaluating on test set.")
181 |     torch.save(model.state_dict(), f"{experiment_name}.pt")
182 |     test_acc = evaluate(model, test_loader, device)
183 |     wandb.log({"final_test_acc": test_acc}, step=args.num_epochs)
184 | 
185 | 
186 | if __name__ == "__main__":
187 |     dataset_choices = [
188 |         "o2o_easy",
189 |         "o2o_medium",
190 |         "o2o_hard",
191 |         "m2m_easy",
192 |         "m2m_medium",
193 |         "m2m_hard",
194 |     ]
195 |     model_name_choices = [
196 |         # 'vit_so400m_patch14_siglip_384',
197 |         'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
198 |         'deit3_base_patch16_224.fb_in22k_ft_in1k',
199 |         'beit_base_patch16_224.in22k_ft_in22k_in1k',
200 |         # 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k',
201 |         'levit_128s.fb_dist_in1k',
202 |     ]
203 |     for dataset in dataset_choices:
204 |         for model_name in model_name_choices:
205 |             MODEL_NAME = model_name
206 |             set_model_name(MODEL_NAME)
207 |             main(dataset)
--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------
1 | # Import statements
2 | 
3 | import os
4 | from diffusers import StableDiffusionPipeline
5 | from tqdm import tqdm
6 | import itertools
7 | from datetime import datetime
8 | import argparse
9 | from ml_collections import ConfigDict
10 | import torch
11 | import random
12 | import numpy as np
13 | from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
14 | from PIL import Image
15 | 
16 | # Get variables from command line
17 | def get_config():
18 |     # Define function to convert string to boolean
19 |     def str2bool(v):
20 |         if isinstance(v, bool):
21 |             return v
22 |         if v.lower() in ("yes", "true", "t", "y", "1"):
23 |             return True
24 |         elif v.lower() in ("no", "false", "f", "n", "0"):
25 |             return False
26 |         else:
27 |             raise argparse.ArgumentTypeError("Boolean value expected.")
28 | 
29 |     # Define argument parser
30 |     parser = argparse.ArgumentParser()
31 |     parser.add_argument("--global_save_label", type=str, required=True, help="path to dataset folder")
32 |     parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
33 |     parser.add_argument("--minibatch_size", type=int, default=4, help="Minibatch size")
34 |     parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")
35 |     parser.add_argument("--seed", type=int, default=1, help="Random seed")
36 |     parser.add_argument("--num_iters", type=int, default=1, help="Number of balanced datasets to generate")
37 |     parser.add_argument("--machine_name", type=str, default="default", help="Machine name - a label to distinguish between different 
machines") 38 | parser.add_argument("--animals_to_generate", type=str, nargs="+", default='all', help="List of animals to generate") 39 | parser.add_argument("--locations_to_generate", type=str, nargs="+", default='all', help="List of locations to generate") 40 | parser.add_argument("--locations_to_avoid", type=str, nargs="+", default=[], help="List of locations to avoid") 41 | 42 | config = ConfigDict(vars(parser.parse_args())) 43 | return config 44 | 45 | # Define all necessary functions 46 | def set_seed(seed=1): 47 | random.seed(seed) 48 | np.random.seed(seed) 49 | torch.manual_seed(seed) 50 | torch.cuda.manual_seed(seed) 51 | torch.backends.cudnn.deterministic = True 52 | torch.backends.cudnn.benchmark = False 53 | 54 | model = VisionEncoderDecoderModel.from_pretrained( 55 | "nlpconnect/vit-gpt2-image-captioning" 56 | ) 57 | feature_extractor = ViTFeatureExtractor.from_pretrained( 58 | "nlpconnect/vit-gpt2-image-captioning" 59 | ) 60 | tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning") 61 | 62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | model.to(device) 64 | max_length = 16 65 | num_beams = 4 66 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams} 67 | 68 | config = get_config() 69 | global_save_label = config.global_save_label 70 | batch_size = config.batch_size 71 | minibatch_size = config.minibatch_size 72 | device = config.device 73 | seed = config.seed 74 | num_iters = config.num_iters 75 | machine_name = config.machine_name 76 | 77 | animals_to_generate = config.animals_to_generate 78 | locations_to_generate = config.locations_to_generate 79 | locations_to_avoid = config.locations_to_avoid 80 | 81 | now = datetime.now() 82 | begin_exp_time = now.strftime("%d%b_%H%M%S") 83 | 84 | # The model 85 | pipe = StableDiffusionPipeline.from_pretrained( 86 | "CompVis/stable-diffusion-v1-4", use_auth_token=True 87 | ) 88 | pipe = pipe.to(device) 89 | 90 | def predict_step(image_paths: list[str]) -> list[str]: 91 | images = [] 92 | for image_path in image_paths: 93 | i_image = Image.open(image_path) 94 | if i_image.mode != "RGB": 95 | i_image = i_image.convert(mode="RGB") 96 | 97 | images.append(i_image) 98 | 99 | pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values 100 | pixel_values = pixel_values.to(device) 101 | 102 | output_ids = model.generate(pixel_values, **gen_kwargs) 103 | 104 | preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 105 | preds = [pred.strip() for pred in preds] 106 | return preds 107 | 108 | def caption_from_images(images: list) -> list[str]: 109 | """ 110 | Generate captions from a list of images 111 | """ 112 | pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values 113 | pixel_values = pixel_values.to(device) 114 | 115 | output_ids = model.generate(pixel_values, **gen_kwargs) 116 | 117 | preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 118 | preds = [pred.strip() for pred in preds] 119 | return preds 120 | 121 | def dirty_image_keyword_filter(images: list, keywords: list) -> bool: 122 | """ 123 | Filter images by whether the caption contains the keywords for the object and background 124 | """ 125 | dirty_bool = False 126 | preds = caption_from_images(images) 127 | for caption in preds: 128 | caption_words = caption.strip().split(" ") 129 | if not set(keywords) & set(caption_words): 130 | dirty_bool = True 131 | break 132 | return dirty_bool 133 | 134 | def generate_batch(prompt: str, save_label: 
str, keywords: list = ['dog'], negative_prompt: str = "human, blurry, painting, cartoon, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, two, multiple", num_inference_steps: int = 150, batch_size: int = 3, minibatch_size: int = 4, additional_label: str = None):
135 |     """
136 |     Generate a batch of images from a prompt.
137 |     Save the images to a folder specified by the save label, under the format
138 |     <save_label>/<machine_name>_[<additional_label>_]<batch_index>.png
139 |     """
140 |     print('Generating images with prompt: \n', prompt)
141 |     os.makedirs(save_label, exist_ok=True)
142 |     prompt_list = [prompt] * minibatch_size
143 |     negative_prompt_list = [negative_prompt] * minibatch_size
144 |     batch_count = 0
145 |     cleaning_total = 0
146 |     with tqdm(total=batch_size) as pbar:
147 |         while batch_count < batch_size:
148 |             # Generate a batch of images
149 |             output = pipe(
150 |                 prompt_list,
151 |                 negative_prompt=negative_prompt_list,
152 |                 num_inference_steps=num_inference_steps,
153 |             )
154 |             images = output.images
155 | 
156 |             # Filter out dirty images
157 |             dirty_bool = dirty_image_keyword_filter(images, keywords)
158 |             nsfw_bool = sum(output.nsfw_content_detected) > 0  # True if any images are nsfw
159 |             if dirty_bool or nsfw_bool:
160 |                 print('Bad images detected: \n', 'dirty_bool:', dirty_bool, ', nsfw_bool:', nsfw_bool)
161 |                 if dirty_bool:
162 |                     cleaning_total += 1
163 |                 continue
164 | 
165 |             # Save the images
166 |             for idx, image in enumerate(images):
167 |                 save_path = os.path.join(save_label, f"{machine_name}_{batch_count+idx}.png")
168 |                 if additional_label is not None:
169 |                     save_path = os.path.join(save_label, f"{machine_name}_{additional_label}_{batch_count+idx}.png")
170 |                 image.save(save_path, format="png")
171 |                 pbar.update(1)
172 |             batch_count += len(images)
173 |     return cleaning_total
174 | 
175 | 
176 | """
177 | Create prompt list dictionary of form:
178 | {'animal-background': [prompt1,..,promptD]}
179 | """
180 | 
181 | animal_list = [
182 |     "labrador",
183 |     "welsh corgi dog",
184 |     "bulldog",
185 |     "dachshund",
186 | ]
187 | one_word_animal_list = [
188 |     "labrador",
189 |     "corgi",
190 |     "bulldog",
191 |     "dachshund",
192 | ]
193 | animal_dict = {
194 |     "labrador": "labrador",
195 |     "corgi": "welsh corgi dog",
196 |     "bulldog": "bulldog",
197 |     "dachshund": "dachshund"
198 | }
199 | location_list = [
200 |     'in a jungle',
201 |     'on a rocky mountain',
202 |     'in a hot, dry desert with cactuses around',
203 |     'in a park, with puddles, bushes and dirt in the background',
204 |     'playing fetch on a beach with a pier and ocean in the background',
205 |     'in a snowy landscape with a cabin and a snowball in the background',
206 | ]
207 | one_word_location_list = [
208 |     'jungle',
209 |     'mountain',
210 |     'desert',
211 |     'dirt',
212 |     'beach',
213 |     'snow',
214 | ]
215 | location_dict = {
216 |     'jungle': 'in a jungle',
217 |     'mountain': 'on a rocky mountain',
218 |     'desert': 'in a hot, dry desert with cactuses around',
219 |     'dirt': 'in a park, with puddles, bushes and dirt in the background',
220 |     'beach': 'playing fetch on a beach with a pier and ocean in the background',
221 |     'snow': 'in a snowy landscape with a cabin and a snowball in the background',
222 | }
223 | fur_list = [
224 |     "black",
225 |     "brown",
226 |     "white",
227 |     "",
228 | ]
229 | pose_list = [
230 |     "sitting",
231 |     "",
232 |     "running",
233 | ]
234 | tod_list = [
235 |     "pale sunrise",
236 |     "sunset",
237 |     "rainy day",
238 |     "foggy day",
239 |     "bright sunny day",
240 |     "bright sunny day",  # appears twice, which upweights sunny days in the prompt sweep
241 | ]
242 | prompt_template = "(((one {fur} {animal} {pose}))) {location}, {tod}. highly detailed, with cinematic lighting, 4k resolution, beautiful composition, hyperrealistic, trending, cinematic, masterpiece, close up"
243 | 
244 | # nargs="+" makes these flags parse to lists; the defaults are the string 'all'
245 | # and an empty list, so handle both forms.
246 | if animals_to_generate != 'all':
247 |     assert all(animal in one_word_animal_list for animal in animals_to_generate)
248 |     one_word_animal_list = list(animals_to_generate)
249 | 
250 | if locations_to_generate != 'all':
251 |     assert all(location in one_word_location_list for location in locations_to_generate)
252 |     one_word_location_list = list(locations_to_generate)
253 | 
254 | for loc in locations_to_avoid:
255 |     one_word_location_list.remove(loc)
256 | 
257 | prompt_list_dict = {}
258 | for animal_word in one_word_animal_list:
259 |     for location_word in one_word_location_list:
260 |         animal = animal_dict[animal_word]
261 |         location = location_dict[location_word]
262 |         prompt_list_dict[f'{animal_word}-{location_word}'] = []
263 |         for fur in fur_list:
264 |             for pose in pose_list:
265 |                 for tod in tod_list:
266 |                     prompt = prompt_template.format(fur=fur, animal=animal, pose=pose, location=location, tod=tod)
267 |                     prompt_list_dict[f'{animal_word}-{location_word}'].append(prompt)
268 | 
269 | # %%
270 | """
271 | Generate a mini dataset with samples from each prompt
272 | """
273 | for iteration in tqdm(range(num_iters)):
274 |     print('\n\n\n\n\n\nIteration:', iteration, '\n\n\n\n\n\n')
275 |     cleaning_total = 0
276 |     for animal_loc in tqdm(prompt_list_dict.keys()):
277 |         print('\n\n\nAnimal-Location:', animal_loc, '\n\n\n')
278 |         prompt_count = 0
279 |         for prompt in prompt_list_dict[animal_loc]:
280 |             animal_str = animal_loc.split('-')[0]
281 |             location_str = animal_loc.split('-')[1]
282 |             save_label = os.path.join(global_save_label.format(iteration), f"{location_str}", f"{animal_str}")
283 |             os.makedirs(save_label, exist_ok=True)
284 |             cleaning_total += generate_batch(
285 |                 prompt=prompt,
286 |                 save_label=save_label,
287 |                 keywords=['dog'],
288 |                 batch_size=batch_size,
289 |                 minibatch_size=minibatch_size,
290 |                 num_inference_steps=100,
291 |                 additional_label=f"prompt_{prompt_count}"
292 |             )
293 |             prompt_count += 1
--------------------------------------------------------------------------------
/spawrious/tf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tarfile
3 | import urllib
4 | import urllib.request
5 | from typing import Any, Tuple
6 | from tqdm import tqdm
7 | 
8 | import tensorflow as tf
9 | from PIL import Image
10 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
11 | from tensorflow.keras.preprocessing import image_dataset_from_directory
12 | 
13 | def _extract_dataset_from_tar(
14 |     tar_file_name: str, data_dir: str, remove_tar_after_extracting: bool = True
15 | ) -> None:
16 |     tar_file_dst = os.path.join(data_dir, tar_file_name)
17 |     print("Extracting dataset...")
18 |     tar = tarfile.open(tar_file_dst, "r:gz")
19 |     tar.extractall(os.path.dirname(tar_file_dst))
20 |     tar.close()
21 |     print("Dataset extracted.")
22 |     if remove_tar_after_extracting:
23 |         os.remove(tar_file_dst)
24 | 
25 | 
26 | def _download_dataset_if_not_available(
27 |     dataset_name: str, data_dir: str, remove_tar_after_extracting: bool = True
28 | ) -> None:
29 |     """
30 |     Checks whether the extracted images are already present in data_dir (via _check_images_availability); if they are, the tar file is not downloaded again. 
31 | """ 32 | data_dir = data_dir.split("/spawrious224/")[0] # in case people pass in the wrong root_dir 33 | os.makedirs(data_dir, exist_ok=True) 34 | dataset_name = dataset_name.lower() 35 | if dataset_name.split("_")[0] == "m2m": 36 | dataset_name = "entire_dataset" 37 | url_dict = { 38 | "entire_dataset": "https://www.dropbox.com/s/hofkueo8qvaqlp3/spawrious224__entire_dataset.tar.gz?dl=1", 39 | "o2o_easy": "https://www.dropbox.com/s/kwhiv60ihxe3owy/spawrious224__o2o_easy.tar.gz?dl=1", 40 | "o2o_medium": "https://www.dropbox.com/s/x03gkhdwar5kht4/spawrious224__o2o_medium.tar.gz?dl=1", 41 | "o2o_hard": "https://www.dropbox.com/s/p1ry121m2gjj158/spawrious224__o2o_hard.tar.gz?dl=1", 42 | # "m2m": "https://www.dropbox.com/s/5usem63nfub266y/spawrious__m2m.tar.gz?dl=1", 43 | } 44 | tar_file_name = f"spawrious224__{dataset_name}.tar.gz" 45 | tar_file_dst = os.path.join(data_dir, tar_file_name) 46 | url = url_dict[dataset_name] 47 | 48 | # check if the dataset is already extracted 49 | if _check_images_availability(data_dir, dataset_name): 50 | print("Dataset already downloaded and extracted.") 51 | return 52 | # check if the tar file is already downloaded 53 | else: 54 | if os.path.exists(tar_file_dst): 55 | print("Dataset already downloaded. Extracting...") 56 | _extract_dataset_from_tar( 57 | tar_file_name, data_dir, remove_tar_after_extracting 58 | ) 59 | return 60 | # download the tar file and extract from it 61 | else: 62 | print('Dataset not found. Downloading...') 63 | response = urllib.request.urlopen(url) 64 | total_size = int(response.headers.get("Content-Length", 0)) 65 | block_size = 1024 66 | # Track progress of download 67 | progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True) 68 | with open(tar_file_dst, "wb") as f: 69 | while True: 70 | buffer = response.read(block_size) 71 | if not buffer: 72 | break 73 | f.write(buffer) 74 | progress_bar.update(len(buffer)) 75 | progress_bar.close() 76 | print("Dataset downloaded. 
Extracting...") 77 | _extract_dataset_from_tar( 78 | tar_file_name, data_dir, remove_tar_after_extracting 79 | ) 80 | return 81 | 82 | class ConcatDataset(tf.data.Dataset): 83 | def __init__(self, dataset_list): 84 | self.dataset_list = dataset_list 85 | self.total_len = sum([len(ds) for ds in dataset_list]) 86 | 87 | def __len__(self): 88 | return self.total_len 89 | 90 | def __getitem__(self, index): 91 | for ds in self.dataset_list: 92 | if index < len(ds): 93 | return ds[index] 94 | index -= len(ds) 95 | raise IndexError("Index out of range") 96 | 97 | class CustomImageFolder(tf.keras.utils.Sequence): 98 | """ 99 | A class that takes one folder at a time and loads a set number of images in a folder and assigns them a specific class 100 | """ 101 | 102 | def __init__(self, folder_path, class_index, location_index, limit=None, preprocess_func=None): 103 | self.folder_path = folder_path 104 | self.class_index = class_index 105 | self.location_index = location_index 106 | self.image_paths = [ 107 | os.path.join(folder_path, img) 108 | for img in os.listdir(folder_path) 109 | if img.endswith((".png", ".jpg", ".jpeg")) 110 | ] 111 | if limit: 112 | self.image_paths = self.image_paths[:limit] 113 | self.preprocess_func = preprocess_func 114 | 115 | def __len__(self): 116 | return len(self.image_paths) 117 | 118 | def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: 119 | img_path = self.image_paths[index] 120 | img = Image.open(img_path).convert("RGB") 121 | img = tf.keras.preprocessing.image.img_to_array(img) 122 | 123 | if self.preprocess_func: 124 | img = self.preprocess_func(img) 125 | 126 | class_label = tf.convert_to_tensor(self.class_index, dtype=tf.int64) 127 | location_label = tf.convert_to_tensor(self.location_index, dtype=tf.int64) 128 | return img, class_label, location_label 129 | 130 | 131 | class MultipleDomainDataset: 132 | N_STEPS = 5001 # Default, subclasses may override 133 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 134 | N_WORKERS = 8 # Default, subclasses may override 135 | ENVIRONMENTS = None # Subclasses should override 136 | INPUT_SHAPE = None # Subclasses should override 137 | 138 | def __getitem__(self, index): 139 | return self.datasets[index] 140 | 141 | def __len__(self): 142 | return len(self.datasets) 143 | 144 | 145 | def build_combination(benchmark_type, group, test, filler=None): 146 | total = 3168 147 | combinations = {} 148 | if "m2m" in benchmark_type: 149 | counts = [total, total] 150 | combinations["train_combinations"] = { 151 | ("bulldog",): [(group[0], counts[0]), (group[1], counts[1])], 152 | ("dachshund",): [(group[1], counts[0]), (group[0], counts[1])], 153 | ("labrador",): [(group[2], counts[0]), (group[3], counts[1])], 154 | ("corgi",): [(group[3], counts[0]), (group[2], counts[1])], 155 | } 156 | combinations["test_combinations"] = { 157 | ("bulldog",): [test[0], test[1]], 158 | ("dachshund",): [test[1], test[0]], 159 | ("labrador",): [test[2], test[3]], 160 | ("corgi",): [test[3], test[2]], 161 | } 162 | else: 163 | counts = [int(0.97 * total), int(0.87 * total)] 164 | combinations["train_combinations"] = { 165 | ("bulldog",): [(group[0], counts[0]), (group[0], counts[1])], 166 | ("dachshund",): [(group[1], counts[0]), (group[1], counts[1])], 167 | ("labrador",): [(group[2], counts[0]), (group[2], counts[1])], 168 | ("corgi",): [(group[3], counts[0]), (group[3], counts[1])], 169 | ("bulldog", "dachshund", "labrador", "corgi"): [ 170 | (filler, total - counts[0]), 171 | (filler, total - counts[1]), 172 | ], 173 | } 174 | 
combinations["test_combinations"] = { 175 | ("bulldog",): [test[0], test[0]], 176 | ("dachshund",): [test[1], test[1]], 177 | ("labrador",): [test[2], test[2]], 178 | ("corgi",): [test[3], test[3]], 179 | } 180 | return combinations 181 | 182 | 183 | def get_combinations(benchmark_type: str) -> Tuple[dict, dict]: 184 | combinations = { 185 | "o2o_easy": ( 186 | ["desert", "jungle", "dirt", "snow"], 187 | ["dirt", "snow", "desert", "jungle"], 188 | "beach", 189 | ), 190 | "o2o_medium": ( 191 | ["mountain", "beach", "dirt", "jungle"], 192 | ["jungle", "dirt", "beach", "snow"], 193 | "desert", 194 | ), 195 | "o2o_hard": ( 196 | ["jungle", "mountain", "snow", "desert"], 197 | ["mountain", "snow", "desert", "jungle"], 198 | "beach", 199 | ), 200 | "m2m_hard": ( 201 | ["dirt", "jungle", "snow", "beach"], 202 | ["snow", "beach", "dirt", "jungle"], 203 | None, 204 | ), 205 | "m2m_easy": ( 206 | ["desert", "mountain", "dirt", "jungle"], 207 | ["dirt", "jungle", "mountain", "desert"], 208 | None, 209 | ), 210 | "m2m_medium": ( 211 | ["beach", "snow", "mountain", "desert"], 212 | ["desert", "mountain", "beach", "snow"], 213 | None, 214 | ), 215 | } 216 | if benchmark_type not in combinations: 217 | raise ValueError("Invalid benchmark type") 218 | group, test, filler = combinations[benchmark_type] 219 | return build_combination(benchmark_type, group, test, filler) 220 | 221 | 222 | class SpawriousBenchmark(MultipleDomainDataset): 223 | ENVIRONMENTS = ["Test", "SC_group_1", "SC_group_2"] 224 | input_shape = (3, 224, 224) 225 | num_classes = 4 226 | class_list = ["bulldog", "corgi", "dachshund", "labrador"] 227 | locations_list = ["desert", "jungle", "dirt", "mountain", "snow", "beach"] 228 | 229 | def __init__(self, benchmark, root_dir, augment=True): 230 | combinations = get_combinations(benchmark.lower()) 231 | self.type1 = benchmark.lower().startswith("o2o") 232 | train_datasets, test_datasets = self._prepare_data_lists( 233 | combinations["train_combinations"], 234 | combinations["test_combinations"], 235 | root_dir, 236 | augment, 237 | ) 238 | self.datasets = [ConcatDataset(test_datasets)] + train_datasets 239 | 240 | # Prepares the train and test data lists by applying the necessary transformations. 241 | def _prepare_data_lists( 242 | self, train_combinations, test_combinations, root_dir, augment 243 | ): 244 | preprocess_input = tf.keras.applications.resnet.preprocess_input 245 | 246 | if augment: 247 | train_transforms = ImageDataGenerator( 248 | preprocessing_function=preprocess_input, 249 | horizontal_flip=True, 250 | width_shift_range=0.1, 251 | height_shift_range=0.1, 252 | zoom_range=0.2, 253 | ) 254 | else: 255 | train_transforms = ImageDataGenerator( 256 | preprocessing_function=preprocess_input, 257 | ) 258 | 259 | test_transforms = ImageDataGenerator( 260 | preprocessing_function=preprocess_input, 261 | ) 262 | 263 | train_data_list = self._create_data_list( 264 | train_combinations, root_dir, train_transforms 265 | ) 266 | test_data_list = self._create_data_list( 267 | test_combinations, root_dir, test_transforms 268 | ) 269 | 270 | return train_data_list, test_data_list 271 | 272 | # Creates a list of datasets based on the given combinations and transformations. 273 | def _create_data_list(self, combinations, root_dir, transforms): 274 | data_list = [] 275 | if isinstance(combinations, dict): 276 | 277 | # Build class groups for a given set of combinations, root directory, and transformations. 
278 | for_each_class_group = [] 279 | cg_index = 0 280 | for classes, comb_list in combinations.items(): 281 | for_each_class_group.append([]) 282 | for ind, location_limit in enumerate(comb_list): 283 | if isinstance(location_limit, tuple): 284 | location, limit = location_limit 285 | else: 286 | location, limit = location_limit, None 287 | cg_data_list = [] 288 | for cls in classes: 289 | path = os.path.join( 290 | root_dir, f"{0 if not self.type1 else ind}/{location}/{cls}" 291 | ) 292 | data = CustomImageFolder( 293 | folder_path=path, 294 | class_index=self.class_list.index(cls), 295 | location_index=self.locations_list.index(location), 296 | limit=limit, 297 | transform=transforms, 298 | ) 299 | cg_data_list.append(data) 300 | 301 | for_each_class_group[cg_index].append(ConcatDataset(cg_data_list)) 302 | cg_index += 1 303 | 304 | for group in range(len(for_each_class_group[0])): 305 | data_list.append( 306 | ConcatDataset( 307 | [ 308 | for_each_class_group[k][group] 309 | for k in range(len(for_each_class_group)) 310 | ] 311 | ) 312 | ) 313 | else: 314 | for location in combinations: 315 | path = os.path.join(root_dir, f"{0}/{location}/") 316 | data = image_dataset_from_directory( 317 | directory=path, 318 | labels="inferred", 319 | label_mode="categorical", 320 | class_names=None, 321 | color_mode="rgb", 322 | batch_size=32, 323 | image_size=(224, 224), 324 | shuffle=True, 325 | seed=None, 326 | validation_split=None, 327 | subset=None, 328 | interpolation="bilinear", 329 | follow_links=False, 330 | ) 331 | data_list.append(data) 332 | 333 | return data_list 334 | 335 | def _check_images_availability(root_dir: str, dataset_type: str) -> bool: 336 | # Get the combinations for the given dataset type 337 | root_dir = root_dir.split("/spawrious224/")[0] # in case people pass in the wrong root_dir 338 | combinations = get_combinations(dataset_type.lower()) 339 | 340 | # Extract the train and test combinations 341 | train_combinations = combinations["train_combinations"] 342 | test_combinations = combinations["test_combinations"] 343 | 344 | # Check if the relevant images for each combination are present in the root directory 345 | for combination in [train_combinations, test_combinations]: 346 | for classes, comb_list in combination.items(): 347 | for ind, location_limit in enumerate(comb_list): 348 | if isinstance(location_limit, tuple): 349 | location, limit = location_limit 350 | else: 351 | location, limit = location_limit, None 352 | 353 | for cls in classes: 354 | path = os.path.join( 355 | root_dir, 356 | "spawrious224", 357 | f"{0 if not dataset_type.lower().startswith('o2o') else ind}/{location}/{cls}", 358 | ) 359 | 360 | # If the path does not exist or there are no relevant images, return False 361 | if not os.path.exists(path) or not any( 362 | img.endswith((".png", ".jpg", ".jpeg")) for img in os.listdir(path) 363 | ): 364 | return False 365 | 366 | # If all the required images are present, return True 367 | return True 368 | def get_tensorflow_dataset(dataset_name: str, root_dir: str): 369 | """ 370 | Returns the dataset as a tensorflow dataset, and downloads it if it is not already available. 
371 | """ 372 | root_dir = root_dir.split("/spawrious224/")[0] # in case people pass in the wrong root_dir 373 | assert dataset_name.lower() in { 374 | "o2o_easy", 375 | "o2o_medium", 376 | "o2o_hard", 377 | "m2m_easy", 378 | "m2m_medium", 379 | "m2m_hard", 380 | "m2m", 381 | "entire_dataset", 382 | }, f"Invalid dataset type: {dataset_name}" 383 | _download_dataset_if_not_available(dataset_name, root_dir) 384 | return SpawriousBenchmark(dataset_name, root_dir, augment=True) 385 | 386 | if __name__ == '__main__': 387 | # get_spawrious_dataset('./test_dir','m2m_easy') 388 | root_dir = "/home/aengusl/Desktop/Projects/OOD_workshop/spawrious/data/" 389 | dataset_type = "m2m_easy" 390 | result = _check_images_availability(root_dir, dataset_type) 391 | print(result) 392 | -------------------------------------------------------------------------------- /spawrious/torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | import urllib 4 | import urllib.request 5 | from typing import Any, Tuple 6 | 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import ConcatDataset, Dataset 10 | from torchvision import transforms 11 | from torchvision.datasets import ImageFolder 12 | from tqdm import tqdm 13 | import timm 14 | from PIL import ImageFile 15 | 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | 18 | # MODEL_NAME = "vit_so400m_patch14_siglip_384" 19 | # MODEL_NAME = 'swin_base_patch4_window7_224.ms_in22k_ft_in1k' 20 | # MODEL_NAME = 'deit3_base_patch16_224.fb_in22k_ft_in1k' 21 | # MODEL_NAME = 'beit_base_patch16_224.in22k_ft_in22k_in1k' 22 | # MODEL_NAME = 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k' 23 | # MODEL_NAME = 'levit_128s.fb_dist_in1k' 24 | 25 | MODEL_NAME = None 26 | 27 | def set_model_name(name): 28 | global MODEL_NAME 29 | MODEL_NAME = name 30 | 31 | 32 | def _extract_dataset_from_tar( 33 | tar_file_name: str, data_dir: str, remove_tar_after_extracting: bool = True 34 | ) -> None: 35 | tar_file_dst = os.path.join(data_dir, tar_file_name) 36 | print("Extracting dataset...") 37 | tar = tarfile.open(tar_file_dst, "r:gz") 38 | tar.extractall(os.path.dirname(tar_file_dst)) 39 | tar.close() 40 | print("Dataset extracted. Delete tar file.") 41 | if remove_tar_after_extracting: 42 | os.remove(tar_file_dst) 43 | 44 | 45 | def _download_dataset_if_not_available( 46 | dataset_name: str, data_dir: str, remove_tar_after_extracting: bool = True 47 | ) -> None: 48 | """ 49 | datasets.txt file, which is present in the data_dir, is used to check if the dataset is already extracted. If the dataset is already extracted, then the tar file is not downloaded again. 
50 | """ 51 | data_dir = data_dir.split("/spawrious224/")[ 52 | 0 53 | ] # in case people pass in the wrong root_dir 54 | os.makedirs(data_dir, exist_ok=True) 55 | dataset_name = dataset_name.lower() 56 | if dataset_name.split("_")[0] == "m2m": 57 | dataset_name = "entire_dataset" 58 | url_dict = { 59 | "entire_dataset": "https://www.dropbox.com/s/hofkueo8qvaqlp3/spawrious224__entire_dataset.tar.gz?dl=1", 60 | "o2o_easy": "https://www.dropbox.com/s/kwhiv60ihxe3owy/spawrious224__o2o_easy.tar.gz?dl=1", 61 | "o2o_medium": "https://www.dropbox.com/s/x03gkhdwar5kht4/spawrious224__o2o_medium.tar.gz?dl=1", 62 | "o2o_hard": "https://www.dropbox.com/s/p1ry121m2gjj158/spawrious224__o2o_hard.tar.gz?dl=1", 63 | # "m2m": "https://www.dropbox.com/s/5usem63nfub266y/spawrious__m2m.tar.gz?dl=1", 64 | } 65 | tar_file_name = f"spawrious224__{dataset_name}.tar.gz" 66 | tar_file_dst = os.path.join(data_dir, tar_file_name) 67 | url = url_dict[dataset_name] 68 | 69 | # check if the dataset is already extracted 70 | if _check_images_availability(data_dir, dataset_name): 71 | print("Dataset already downloaded and extracted.") 72 | return 73 | # check if the tar file is already downloaded 74 | else: 75 | if os.path.exists(tar_file_dst): 76 | print("Dataset already downloaded. Extracting...") 77 | _extract_dataset_from_tar( 78 | tar_file_name, data_dir, remove_tar_after_extracting 79 | ) 80 | return 81 | # download the tar file and extract from it 82 | else: 83 | print("Dataset not found. Downloading...") 84 | response = urllib.request.urlopen(url) 85 | total_size = int(response.headers.get("Content-Length", 0)) 86 | block_size = 1024 87 | # Track progress of download 88 | progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True) 89 | with open(tar_file_dst, "wb") as f: 90 | while True: 91 | buffer = response.read(block_size) 92 | if not buffer: 93 | break 94 | f.write(buffer) 95 | progress_bar.update(len(buffer)) 96 | progress_bar.close() 97 | print("Dataset downloaded. 
Extracting...") 98 | _extract_dataset_from_tar( 99 | tar_file_name, data_dir, remove_tar_after_extracting 100 | ) 101 | return 102 | 103 | 104 | class CustomImageFolder(Dataset): 105 | """ 106 | A class that takes one folder at a time and loads a set number of images in a folder and assigns them a specific class 107 | """ 108 | 109 | def __init__( 110 | self, folder_path, class_index, location_index, limit=None, transform=None 111 | ): 112 | self.folder_path = folder_path 113 | self.class_index = class_index 114 | self.location_index = location_index 115 | self.image_paths = [ 116 | os.path.join(folder_path, img) 117 | for img in os.listdir(folder_path) 118 | if img.endswith((".png", ".jpg", ".jpeg")) 119 | ] 120 | if limit: 121 | self.image_paths = self.image_paths[:limit] 122 | self.transform = transform 123 | 124 | def __len__(self): 125 | return len(self.image_paths) 126 | 127 | def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: 128 | img_path = self.image_paths[index] 129 | img = Image.open(img_path).convert("RGB") 130 | 131 | if self.transform: 132 | img = self.transform(img) 133 | 134 | class_label = torch.tensor(self.class_index, dtype=torch.long) 135 | location_label = torch.tensor(self.location_index, dtype=torch.long) 136 | return img, class_label, location_label 137 | 138 | 139 | class MultipleDomainDataset: 140 | N_STEPS = 5001 # Default, subclasses may override 141 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 142 | N_WORKERS = 8 # Default, subclasses may override 143 | ENVIRONMENTS = None # Subclasses should override 144 | INPUT_SHAPE = None # Subclasses should override 145 | 146 | def __getitem__(self, index): 147 | return self.datasets[index] 148 | 149 | def __len__(self): 150 | return len(self.datasets) 151 | 152 | 153 | def build_combination(benchmark_type, group, test, filler=None): 154 | total = 3168 155 | combinations = {} 156 | if "m2m" in benchmark_type: 157 | counts = [total, total] 158 | combinations["train_combinations"] = { 159 | ("bulldog",): [(group[0], counts[0]), (group[1], counts[1])], 160 | ("dachshund",): [(group[1], counts[0]), (group[0], counts[1])], 161 | ("labrador",): [(group[2], counts[0]), (group[3], counts[1])], 162 | ("corgi",): [(group[3], counts[0]), (group[2], counts[1])], 163 | } 164 | combinations["test_combinations"] = { 165 | ("bulldog",): [test[0], test[1]], 166 | ("dachshund",): [test[1], test[0]], 167 | ("labrador",): [test[2], test[3]], 168 | ("corgi",): [test[3], test[2]], 169 | } 170 | else: 171 | counts = [int(0.97 * total), int(0.87 * total)] 172 | combinations["train_combinations"] = { 173 | ("bulldog",): [(group[0], counts[0]), (group[0], counts[1])], 174 | ("dachshund",): [(group[1], counts[0]), (group[1], counts[1])], 175 | ("labrador",): [(group[2], counts[0]), (group[2], counts[1])], 176 | ("corgi",): [(group[3], counts[0]), (group[3], counts[1])], 177 | ("bulldog", "dachshund", "labrador", "corgi"): [ 178 | (filler, total - counts[0]), 179 | (filler, total - counts[1]), 180 | ], 181 | } 182 | combinations["test_combinations"] = { 183 | ("bulldog",): [test[0], test[0]], 184 | ("dachshund",): [test[1], test[1]], 185 | ("labrador",): [test[2], test[2]], 186 | ("corgi",): [test[3], test[3]], 187 | } 188 | return combinations 189 | 190 | 191 | def _get_combinations(benchmark_type: str) -> Tuple[dict, dict]: 192 | combinations = { 193 | "o2o_easy": ( 194 | ["desert", "jungle", "dirt", "snow"], 195 | ["dirt", "snow", "desert", "jungle"], 196 | "beach", 197 | ), 198 | "o2o_medium": ( 199 | ["mountain", "beach", 
"dirt", "jungle"], 200 | ["jungle", "dirt", "beach", "snow"], 201 | "desert", 202 | ), 203 | "o2o_hard": ( 204 | ["jungle", "mountain", "snow", "desert"], 205 | ["mountain", "snow", "desert", "jungle"], 206 | "beach", 207 | ), 208 | "m2m_hard": ( 209 | ["dirt", "jungle", "snow", "beach"], 210 | ["snow", "beach", "dirt", "jungle"], 211 | None, 212 | ), 213 | "m2m_easy": ( 214 | ["desert", "mountain", "dirt", "jungle"], 215 | ["dirt", "jungle", "mountain", "desert"], 216 | None, 217 | ), 218 | "m2m_medium": ( 219 | ["beach", "snow", "mountain", "desert"], 220 | ["desert", "mountain", "beach", "snow"], 221 | None, 222 | ), 223 | } 224 | if benchmark_type not in combinations: 225 | raise ValueError("Invalid benchmark type") 226 | group, test, filler = combinations[benchmark_type] 227 | return build_combination(benchmark_type, group, test, filler) 228 | 229 | 230 | class SpawriousBenchmark(MultipleDomainDataset): 231 | ENVIRONMENTS = ["Test", "SC_group_1", "SC_group_2"] 232 | input_shape = (3, 224, 224) 233 | num_classes = 4 234 | class_list = ["bulldog", "corgi", "dachshund", "labrador"] 235 | locations_list = ["desert", "jungle", "dirt", "mountain", "snow", "beach"] 236 | 237 | def __init__(self, benchmark, root_dir, augment=True): 238 | combinations = _get_combinations(benchmark.lower()) 239 | self.type1 = benchmark.lower().startswith("o2o") 240 | train_datasets, test_datasets = self._prepare_data_lists( 241 | combinations["train_combinations"], 242 | combinations["test_combinations"], 243 | root_dir, 244 | augment, 245 | ) 246 | self.datasets = [ConcatDataset(test_datasets)] + train_datasets 247 | 248 | def get_train_dataset(self): 249 | return torch.utils.data.ConcatDataset(self.datasets[1:]) 250 | 251 | def get_test_dataset(self): 252 | return self.datasets[0] 253 | 254 | # Prepares the train and test data lists by applying the necessary transformations. 255 | def _prepare_data_lists( 256 | self, train_combinations, test_combinations, root_dir, augment 257 | ): 258 | backbone = timm.create_model( 259 | # "vit_so400m_patch14_siglip_384", 260 | MODEL_NAME, 261 | pretrained=True, 262 | num_classes=0, 263 | ).eval() 264 | self.data_config = timm.data.resolve_model_data_config(backbone) 265 | test_transforms = timm.data.create_transform( 266 | **self.data_config, is_training=False 267 | ) 268 | 269 | # test_transforms = transforms.Compose( 270 | # [ 271 | # transforms.Resize((self.input_shape[1], self.input_shape[2])), 272 | # transforms.transforms.ToTensor(), 273 | # transforms.Normalize( 274 | # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 275 | # ), 276 | # ] 277 | # ) 278 | 279 | if augment: 280 | train_transforms = timm.data.create_transform( 281 | **self.data_config, is_training=True 282 | ) 283 | else: 284 | train_transforms = test_transforms 285 | 286 | train_data_list = self._create_data_list( 287 | train_combinations, root_dir, train_transforms 288 | ) 289 | test_data_list = self._create_data_list( 290 | test_combinations, root_dir, test_transforms 291 | ) 292 | 293 | return train_data_list, test_data_list 294 | 295 | # Creates a list of datasets based on the given combinations and transformations. 296 | def _create_data_list(self, combinations, root_dir, transforms): 297 | data_list = [] 298 | if isinstance(combinations, dict): 299 | # Build class groups for a given set of combinations, root directory, and transformations. 
300 | for_each_class_group = [] 301 | cg_index = 0 302 | for classes, comb_list in combinations.items(): 303 | for_each_class_group.append([]) 304 | for ind, location_limit in enumerate(comb_list): 305 | if isinstance(location_limit, tuple): 306 | location, limit = location_limit 307 | else: 308 | location, limit = location_limit, None 309 | cg_data_list = [] 310 | for cls in classes: 311 | path = os.path.join( 312 | root_dir, 313 | "spawrious224", 314 | f"{0 if not self.type1 else ind}/{location}/{cls}", 315 | ) 316 | data = CustomImageFolder( 317 | folder_path=path, 318 | class_index=self.class_list.index(cls), 319 | location_index=self.locations_list.index(location), 320 | limit=limit, 321 | transform=transforms, 322 | ) 323 | cg_data_list.append(data) 324 | 325 | for_each_class_group[cg_index].append(ConcatDataset(cg_data_list)) 326 | cg_index += 1 327 | 328 | for group in range(len(for_each_class_group[0])): 329 | data_list.append( 330 | ConcatDataset( 331 | [ 332 | for_each_class_group[k][group] 333 | for k in range(len(for_each_class_group)) 334 | ] 335 | ) 336 | ) 337 | else: 338 | for location in combinations: 339 | path = os.path.join(root_dir, f"{0}/{location}/") 340 | data = ImageFolder(root=path, transform=transforms) 341 | data_list.append(data) 342 | 343 | return data_list 344 | 345 | 346 | def _check_images_availability(root_dir: str, dataset_type: str) -> bool: 347 | # Get the combinations for the given dataset type 348 | root_dir = root_dir.split("/spawrious224/")[ 349 | 0 350 | ] # in case people pass in the wrong root_dir 351 | if dataset_type == "entire_dataset": 352 | for dataset in ["0", "1", "domain_adaptation_ds"]: 353 | for location in ["snow", "jungle", "desert", "dirt", "mountain", "beach"]: 354 | for cls in ["bulldog", "corgi", "dachshund", "labrador"]: 355 | path = os.path.join( 356 | root_dir, "spawrious224", f"{dataset}/{location}/{cls}" 357 | ) 358 | if not os.path.exists(path) or not any( 359 | img.endswith((".png", ".jpg", ".jpeg")) 360 | for img in os.listdir(path) 361 | ): 362 | return False 363 | return True 364 | combinations = _get_combinations(dataset_type.lower()) 365 | 366 | # Extract the train and test combinations 367 | train_combinations = combinations["train_combinations"] 368 | test_combinations = combinations["test_combinations"] 369 | 370 | # Check if the relevant images for each combination are present in the root directory 371 | for combination in [train_combinations, test_combinations]: 372 | for classes, comb_list in combination.items(): 373 | for ind, location_limit in enumerate(comb_list): 374 | if isinstance(location_limit, tuple): 375 | location, limit = location_limit 376 | else: 377 | location, limit = location_limit, None 378 | 379 | for cls in classes: 380 | path = os.path.join( 381 | root_dir, 382 | "spawrious224", 383 | f"{0 if not dataset_type.lower().startswith('o2o') else ind}/{location}/{cls}", 384 | ) 385 | 386 | # If the path does not exist or there are no relevant images, return False 387 | if not os.path.exists(path) or not any( 388 | img.endswith((".png", ".jpg", ".jpeg")) 389 | for img in os.listdir(path) 390 | ): 391 | return False 392 | 393 | # If all the required images are present, return True 394 | return True 395 | 396 | 397 | def get_spawrious_dataset(root_dir: str, dataset_name: str = "entire_dataset"): 398 | """ 399 | Returns the dataset as a torch dataset, and downloads dataset if dataset is not already available. 
400 | 
401 |     By default, the entire dataset is downloaded; it is required for the m2m and domain-adaptation experiments.
402 |     """
403 |     root_dir = root_dir.split("/spawrious224/")[
404 |         0
405 |     ]  # in case people pass in the wrong root_dir
406 |     assert dataset_name.lower() in {
407 |         "o2o_easy",
408 |         "o2o_medium",
409 |         "o2o_hard",
410 |         "m2m_easy",
411 |         "m2m_medium",
412 |         "m2m_hard",
413 |         "m2m",
414 |         "entire_dataset",
415 |     }, f"Invalid dataset type: {dataset_name}"
416 |     _download_dataset_if_not_available(dataset_name, root_dir)
417 |     # TODO: get m2m to use entire dataset, not half of it
418 |     return SpawriousBenchmark(dataset_name, root_dir, augment=True)
419 | 
--------------------------------------------------------------------------------
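As a quick smoke test of the torch loader above (a minimal sketch, not part of the repo; the backbone name and data directory are placeholder assumptions), note that `set_model_name` must be called before `get_spawrious_dataset`, since `SpawriousBenchmark._prepare_data_lists` derives its input transforms from the timm model's data config:

```python
import torch
from spawrious.torch import get_spawrious_dataset, set_model_name

set_model_name("resnet18")  # assumption: any valid timm model name works here
spawrious = get_spawrious_dataset(dataset_name="o2o_easy", root_dir="./data/")
train_set = spawrious.get_train_dataset()

# each item is an (image, class_label, location_label) triple
img, class_label, location_label = train_set[0]
print(img.shape, class_label.item(), location_label.item())
```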