├── imagenetv2_pytorch ├── __init__.py └── ImageNetV2_dataset.py ├── setup.py └── readme.md /imagenetv2_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .ImageNetV2_dataset import ImageNetV2Dataset, ImageNetValDataset 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='imagenetv2_pytorch', 6 | version='0.1', 7 | description='imagenetv2 datasets for PyTorch', 8 | author='Benjamin Recht, Rebecca Roelofs, Ludwig Schmidt, Vaishaal Shankar', 9 | author_email='vaishaal@gmail.com', 10 | packages=find_packages() 11 | ) 12 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # PyTorch Dataloader for ImageNetV2 2 | 3 | 4 | First, install: 5 | 6 | ``` 7 | pip install git+https://github.com/modestyachts/ImageNetV2_pytorch 8 | ``` 9 | 10 | Usage: 11 | 12 | ``` 13 | from imagenetv2_pytorch import ImageNetV2Dataset 14 | from torch.utils.data import DataLoader 15 | 16 | dataset = ImageNetV2Dataset("matched-frequency") # supports matched-frequency, threshold-0.7, top-images variants 17 | dataloader = DataLoader(dataset, batch_size=32) # use whatever batch size you wish 18 | # feed into PyTorch code 19 | ``` 20 | -------------------------------------------------------------------------------- /imagenetv2_pytorch/ImageNetV2_dataset.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import tarfile 3 | import requests 4 | import shutil 5 | 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision.datasets import ImageFolder 10 | 11 | URLS = 
{"matched-frequency": "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-matched-frequency.tar.gz",
        "threshold-0.7": "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-threshold0.7.tar.gz",
        "top-images": "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-top-images.tar.gz",
        "val": "https://imagenet2val.s3.amazonaws.com/imagenet_validation.tar.gz"}

# Top-level directory that each archive in URLS extracts to.
FNAMES = {"matched-frequency": "imagenetv2-matched-frequency-format-val",
          "threshold-0.7": "imagenetv2-threshold0.7-format-val",
          "top-images": "imagenetv2-top-images-format-val",
          "val": "imagenet_validation"}

# Expected number of images on disk; any other count is treated as a missing
# or incomplete dataset and triggers a (re-)download / extraction.
V2_DATASET_SIZE = 10000
VAL_DATASET_SIZE = 50000

# Variants served by ImageNetV2Dataset. "val" (the original ImageNet
# validation set, 50k ".JPEG" files) is only handled by ImageNetValDataset.
_V2_VARIANTS = tuple(name for name in URLS if name != "val")

# 1 MiB per chunk; 1 KiB chunks add needless per-iteration overhead on large archives.
_DOWNLOAD_CHUNK_SIZE = 1 << 20
# Seconds passed to requests so a stalled connection fails instead of hanging forever.
_DOWNLOAD_TIMEOUT = 60


def _download(url: str, tar_path: pathlib.Path) -> None:
    """Stream *url* to *tar_path* while showing a tqdm progress bar.

    Bytes are written to a sibling ``<name>.part`` file, which is renamed to
    *tar_path* only after the transfer succeeds. A failed or interrupted
    download therefore never leaves a truncated archive behind that a later
    run would find via ``tar_path.exists()`` and try to extract.

    Raises:
        requests.HTTPError: the server answered with an error status.
        AssertionError: the number of bytes received differs from the
            advertised ``content-length``. Raised explicitly (not via
            ``assert``) so the check survives ``python -O``; the type matches
            what the previous assert-based check raised.
    """
    partial_path = tar_path.with_name(tar_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            received = 0
            with open(partial_path, "wb") as f, tqdm(
                total=total_size_in_bytes, unit="iB", unit_scale=True
            ) as progress_bar:
                for data in response.iter_content(_DOWNLOAD_CHUNK_SIZE):
                    f.write(data)
                    received += len(data)
                    progress_bar.update(len(data))
        # content-length of 0 means the server did not advertise a size.
        if total_size_in_bytes != 0 and received != total_size_in_bytes:
            raise AssertionError(f"Downloading from {url} failed")
        partial_path.replace(tar_path)
    finally:
        # Still present only if something above failed; never keep a partial archive.
        if partial_path.exists():
            partial_path.unlink()


def _download_and_extract(url: str, tar_path: pathlib.Path, root: pathlib.Path,
                          extracted_name: str, dataset_root: pathlib.Path,
                          display_name: str) -> None:
    """Populate *dataset_root* from the archive at *url*.

    Downloads the archive to *tar_path* unless that file already exists. It then
    extracts the archive into *root* and moves the extracted top-level
    directory ``root / extracted_name`` to *dataset_root*. The move is skipped
    when the two paths are the same.
    """
    if not tar_path.exists():
        print(f"Dataset {display_name} not found on disk, downloading....")
        _download(url, tar_path)
    print("Extracting....")
    with tarfile.open(tar_path) as tar:
        # The archive comes from the network: use the stdlib "data" filter
        # (rejects absolute paths / links escaping *root*) where available.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(root, filter="data")
        else:
            tar.extractall(root)
    extracted_root = root / extracted_name
    if extracted_root != dataset_root:
        # NOTE(review): if dataset_root already exists (e.g. an incomplete
        # earlier copy), shutil.move nests the extracted folder inside it.
        shutil.move(str(extracted_root), str(dataset_root))


class ImageNetValDataset(Dataset):
    """The original ImageNet validation set, downloaded on first use.

    Images are served by torchvision's ``ImageFolder`` from
    ``<location>/imagenet_validation``. If that folder is missing or does not
    hold exactly ``VAL_DATASET_SIZE`` ``.JPEG`` files, the archive
    ``<location>/imagenet_validation.tar.gz`` is downloaded if absent, then
    extracted.

    Args:
        transform: optional callable applied to each image before it is returned.
        location: directory holding the archive and the extracted folder.
    """

    def __init__(self, transform=None, location="."):
        root = pathlib.Path(location)
        self.dataset_root = root / "imagenet_validation"
        self.tar_root = root / "imagenet_validation.tar.gz"
        self.transform = transform
        self.fnames = sorted(self.dataset_root.glob("**/*.JPEG"))
        if not self.dataset_root.exists() or len(self.fnames) != VAL_DATASET_SIZE:
            _download_and_extract(URLS["val"], self.tar_root, root, FNAMES["val"],
                                  self.dataset_root, "imagenet-val")
            # Refresh so the attribute reflects what is on disk after extraction.
            self.fnames = sorted(self.dataset_root.glob("**/*.JPEG"))
        self.dataset = ImageFolder(self.dataset_root)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        img, label = self.dataset[i]
        if self.transform is not None:
            img = self.transform(img)
        return img, label


class ImageNetV2Dataset(Dataset):
    """One ImageNetV2 variant, downloaded on first use.

    Images are the ``.jpeg`` files under ``<location>/ImageNetV2-<variant>``;
    each image's label is the integer name of its parent directory. If the
    folder is missing or does not hold exactly ``V2_DATASET_SIZE`` images, the
    archive ``<location>/ImageNetV2-<variant>.tar.gz`` is downloaded if absent,
    then extracted. Files are sorted so sample indices are stable across machines.

    Args:
        variant: one of "matched-frequency", "threshold-0.7", "top-images".
        transform: optional callable applied to each image before it is returned.
        location: directory holding the archive and the extracted folder.

    Raises:
        AssertionError: *variant* is not a V2 variant. Raised explicitly (not
            via ``assert``) so it survives ``python -O``.
    """

    def __init__(self, variant="matched-frequency", transform=None, location="."):
        # Validate before building any paths from *variant*.
        if variant not in _V2_VARIANTS:
            raise AssertionError(f"unknown V2 Variant: {variant}")
        root = pathlib.Path(location)
        self.dataset_root = root / f"ImageNetV2-{variant}"
        self.tar_root = root / f"ImageNetV2-{variant}.tar.gz"
        self.transform = transform
        self.fnames = sorted(self.dataset_root.glob("**/*.jpeg"))
        if not self.dataset_root.exists() or len(self.fnames) != V2_DATASET_SIZE:
            _download_and_extract(URLS[variant], self.tar_root, root, FNAMES[variant],
                                  self.dataset_root, variant)
            self.fnames = sorted(self.dataset_root.glob("**/*.jpeg"))

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, i):
        path = self.fnames[i]
        # Decode eagerly and close the file handle. Convert to RGB to match
        # the ImageFolder-based ImageNetValDataset.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        label = int(path.parent.name)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
--------------------------------------------------------------------------------