├── LICENSE
├── README.md
├── aircraft.py
├── cars.py
├── cub2011.py
├── dogs.py
├── inat2017.py
├── nabirds.py
└── tiny_imagenet.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 lvyilin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch FGVC Dataset

This repo contains some unofficial PyTorch dataset APIs (mainly for the Fine-Grained Visual Categorization task),
which can automatically download the data (except for the large-scale datasets), extract the archives, and prepare it for use.

## Supported Datasets
- [x] CUB-200-2011
- [x] Stanford Dogs
- [x] Stanford Cars
- [x] FGVC Aircraft
- [x] NABirds
- [x] Tiny ImageNet
- [x] iNaturalist 2017
- [ ] Oxford 102 Flowers
- [ ] Oxford-IIIT Pets
- [ ] Food-101

## Usage
The code was tested on
- pytorch==1.4.0
- torchvision==0.4.1

Use them the same way you use `torchvision.datasets`.
```python
train_dataset = Cub2011('./cub2011', train=True, download=False)
test_dataset = Cub2011('./cub2011', train=False, download=False)
```
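Each dataset returns `(PIL image, int target)` pairs, so it plugs straight into the usual `torch.utils.data` pipeline. A minimal sketch (the transform choices below are illustrative, not part of this repo):
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fine-grained models often use larger inputs
    transforms.ToTensor(),
])
train_dataset = Cub2011('./cub2011', train=True, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
```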
## Contributing
Feel free to open an issue or PR.

## License
MIT
--------------------------------------------------------------------------------
/aircraft.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive


class Aircraft(VisionDataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')
    img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')

    def __init__(self, root, train=True, class_type='variant', transform=None,
                 target_transform=None, download=False):
        super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)
        split = 'trainval' if train else 'test'
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))

        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = self.find_classes()
        samples = self.make_dataset(image_ids, targets)

        self.loader = default_loader

        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.img_folder)) and \
            os.path.exists(self.classes_file)

    def download(self):
        if self._check_exists():
            return

        # download the archive to <root>/fgvc-aircraft-2013b.tar.gz
        print('Downloading %s...' % self.url)
        tar_name = self.url.rpartition('/')[-1]
        download_url(self.url, root=self.root, filename=tar_name)
        tar_path = os.path.join(self.root, tar_name)
        print('Extracting %s...' % tar_path)
        extract_archive(tar_path)
        print('Done!')

    def find_classes(self):
        # read the classes file, separating out image IDs and class names
        image_ids = []
        targets = []
        with open(self.classes_file, 'r') as f:
            for line in f:
                split_line = line.split(' ')
                image_ids.append(split_line[0])
                # strip the trailing newline so class names are clean
                targets.append(' '.join(split_line[1:]).strip())

        # index class names
        classes = np.unique(targets)
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        targets = [class_to_idx[c] for c in targets]

        return image_ids, targets, classes, class_to_idx

    def make_dataset(self, image_ids, targets):
        assert len(image_ids) == len(targets)
        images = []
        for image_id, target in zip(image_ids, targets):
            images.append((os.path.join(self.root, self.img_folder,
                                        '%s.jpg' % image_id), target))
        return images


if __name__ == '__main__':
    train_dataset = Aircraft('./aircraft', train=True, download=False)
    test_dataset = Aircraft('./aircraft', train=False, download=False)
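    # A hedged sketch of the coarser label granularities, assuming the archive
    # is already extracted under ./aircraft: the same split can also be labelled
    # by 'family' or 'manufacturer' instead of the default 'variant'.
    family_dataset = Aircraft('./aircraft', train=True, class_type='family', download=False)
    print('%d family classes' % len(family_dataset.classes))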
--------------------------------------------------------------------------------
/cars.py:
--------------------------------------------------------------------------------
import os
import scipy.io as sio
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive


class Cars(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    file_list = {
        'imgs': ('http://imagenet.stanford.edu/internal/car196/car_ims.tgz', 'car_ims.tgz'),
        'annos': ('http://imagenet.stanford.edu/internal/car196/cars_annos.mat', 'cars_annos.mat')
    }

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Cars, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            self._download()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')

        loaded_mat = sio.loadmat(os.path.join(self.root, self.file_list['annos'][1]))
        loaded_mat = loaded_mat['annotations'][0]
        self.samples = []
        for item in loaded_mat:
            # the last annotation field is the test flag, so keep the items whose
            # flag disagrees with self.train (train=True keeps test == 0)
            if self.train != bool(item[-1][0]):
                path = str(item[0][0])
                label = int(item[-2][0]) - 1
                self.samples.append((path, label))

    def __getitem__(self, index):
        path, target = self.samples[index]
        path = os.path.join(self.root, path)

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.samples)

    def _check_exists(self):
        return (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                and os.path.exists(os.path.join(self.root, self.file_list['annos'][1])))

    def _download(self):
        print('Downloading...')
        for url, filename in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
        print('Extracting...')
        archive = os.path.join(self.root, self.file_list['imgs'][1])
        extract_archive(archive)


if __name__ == '__main__':
    train_dataset = Cars('./cars', train=True, download=False)
    test_dataset = Cars('./cars', train=False, download=False)
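    # Hedged sanity check, assuming car_ims and cars_annos.mat are already in
    # ./cars: samples are (relative path, 0-based label) pairs, decoded lazily.
    image, label = train_dataset[0]
    print(image.size, label)  # PIL (width, height) and an int in [0, 195]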
--------------------------------------------------------------------------------
/cub2011.py:
--------------------------------------------------------------------------------
import os

import pandas as pd
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_file_from_google_drive


class Cub2011(VisionDataset):
    """`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'CUB_200_2011/images'
    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train
        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')

    def _load_metadata(self):
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
                                  sep=' ', names=['class_name'], usecols=[1])
        self.class_names = class_names['class_name'].to_list()
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        try:
            self._load_metadata()
        except Exception:
            return False

        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


if __name__ == '__main__':
    train_dataset = Cub2011('./cub2011', train=True, download=False)
    test_dataset = Cub2011('./cub2011', train=False, download=False)
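    # Hedged sketch: class_names (filled in by _load_metadata) maps the 0-based
    # target straight back to a species name.
    img, target = train_dataset[0]
    print(train_dataset.class_names[target])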
--------------------------------------------------------------------------------
/dogs.py:
--------------------------------------------------------------------------------
import os
import scipy.io
from os.path import join
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir


class Dogs(VisionDataset):
    """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Dogs, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]

        self._flat_breed_images = self._breed_images

    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        image_name, target = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        image = self.loader(image_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def download(self):
        import tarfile

        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
                print('Files already downloaded and verified')
                return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + join(self.root, tar_filename))
            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
                tar_file.extractall(self.root)
            os.remove(join(self.root, tar_filename))

    def load_split(self):
        # load the split .mat once and read both the file list and the labels from it
        split_file = 'train_list.mat' if self.train else 'test_list.mat'
        mat = scipy.io.loadmat(join(self.root, split_file))
        split = [item[0][0] for item in mat['annotation_list']]
        labels = [item[0] - 1 for item in mat['labels']]
        return list(zip(split, labels))

    def stats(self):
        counts = {}
        for image_name, target_class in self._flat_breed_images:
            counts[target_class] = counts.get(target_class, 0) + 1

        print("%d samples spanning %d classes (avg %.2f per class)"
              % (len(self._flat_breed_images), len(counts),
                 len(self._flat_breed_images) / float(len(counts))))

        return counts


if __name__ == '__main__':
    train_dataset = Dogs('./dogs', train=True, download=False)
    test_dataset = Dogs('./dogs', train=False, download=False)
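    # stats() is provided by the class itself; a quick hedged check that all
    # 120 breeds show up in the training split.
    counts = train_dataset.stats()
    assert len(counts) == 120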
--------------------------------------------------------------------------------
/inat2017.py:
--------------------------------------------------------------------------------
import json
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import check_integrity, extract_archive
from torchvision.datasets.utils import download_url, verify_str_arg


class INat2017(VisionDataset):
    """`iNaturalist 2017 <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'train_val_images/'
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                    and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
                print('Downloading...')
                self._download()
            print('Extracting...')
            extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
            extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        anno_filename = split + '2017.json'
        with open(os.path.join(self.root, anno_filename), 'r') as fp:
            all_annos = json.load(fp)

        self.annos = all_annos['annotations']
        self.images = all_annos['images']

    def __getitem__(self, index):
        path = os.path.join(self.root, self.images[index]['file_name'])
        target = self.annos[index]['category_id']

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.images)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _download(self):
        for url, filename, md5 in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")


if __name__ == '__main__':
    train_dataset = INat2017('./inat2017', split='train', download=False)
    test_dataset = INat2017('./inat2017', split='val', download=False)
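    # Hedged sketch, assuming the archives are already extracted: category ids
    # come straight from the COCO-style annotation file loaded in __init__.
    num_classes = len({anno['category_id'] for anno in train_dataset.annos})
    print('%d categories' % num_classes)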
--------------------------------------------------------------------------------
/nabirds.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import warnings
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import check_integrity, extract_archive


class NABirds(VisionDataset):
    """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'nabirds/images'
    filename = 'nabirds.tar.gz'
    md5 = 'df21a9e4db349a14e2b08adfd45873bd'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=None):
        super(NABirds, self).__init__(root, transform=transform, target_transform=target_transform)
        if download is True:
            msg = ("The dataset is no longer publicly accessible. You need to "
                   "download the archives externally and place them in the root "
                   "directory.")
            raise RuntimeError(msg)
        elif download is False:
            msg = ("The use of the download flag is deprecated, since the dataset "
                   "is no longer publicly accessible.")
            warnings.warn(msg, RuntimeWarning)

        dataset_path = os.path.join(root, 'nabirds')
        if not os.path.isdir(dataset_path):
            if not check_integrity(os.path.join(root, self.filename), self.md5):
                raise RuntimeError('Dataset not found or corrupted.')
            extract_archive(os.path.join(root, self.filename))
        self.loader = default_loader
        self.train = train

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        # Since the raw labels are non-continuous, map them to new ones
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        # Load in the train / test split
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

        # Load in the class data
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


def get_continuous_class_map(class_labels):
    # sort the raw ids so the continuous relabelling is deterministic
    label_set = set(class_labels)
    return {k: i for i, k in enumerate(sorted(label_set))}


def load_class_names(dataset_path=''):
    names = {}

    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])

    return names


def load_hierarchy(dataset_path=''):
    parents = {}

    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            child_id, parent_id = pieces
            parents[child_id] = parent_id

    return parents


if __name__ == '__main__':
    train_dataset = NABirds('./nabirds', train=True, download=False)
    test_dataset = NABirds('./nabirds', train=False, download=False)
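    # Hedged sketch: walk one class up the child -> parent hierarchy loaded
    # above; the loop is bounded in case the root maps to itself.
    node = next(iter(train_dataset.class_hierarchy))
    for _ in range(32):
        parent = train_dataset.class_hierarchy.get(node)
        if parent is None or parent == node:
            break
        node = parent
    print('top-level ancestor id:', node)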
--------------------------------------------------------------------------------
/tiny_imagenet.py:
--------------------------------------------------------------------------------
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg


class TinyImageNet(VisionDataset):
    """`Tiny ImageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'tiny-imagenet-200/'
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    filename = 'tiny-imagenet-200.zip'
    md5 = '90528d7ca1a48142e341f4ef8d21d0de'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)

        self.dataset_path = os.path.join(root, self.base_folder)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_integrity():
            print('Files already downloaded and verified.')
        elif download:
            self._download()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        if not os.path.isdir(self.dataset_path):
            print('Extracting...')
            extract_archive(os.path.join(root, self.filename))

        _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))

        self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)
    def _download(self):
        print('Downloading...')
        download_url(self.url, root=self.root, filename=self.filename)
        print('Extracting...')
        extract_archive(os.path.join(self.root, self.filename))

    def _check_integrity(self):
        return check_integrity(os.path.join(self.root, self.filename), self.md5)

    def __getitem__(self, index):
        img_path, target = self.data[index]
        image = self.loader(img_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.data)
def find_classes(class_file):
    with open(class_file) as r:
        classes = list(map(lambda s: s.strip(), r.readlines()))

    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}

    return classes, class_to_idx


def make_dataset(root, base_folder, dirname, class_to_idx):
    images = []
    dir_path = os.path.join(root, base_folder, dirname)

    if dirname == 'train':
        for fname in sorted(os.listdir(dir_path)):
            cls_fpath = os.path.join(dir_path, fname)
            if os.path.isdir(cls_fpath):
                cls_imgs_path = os.path.join(cls_fpath, 'images')
                for imgname in sorted(os.listdir(cls_imgs_path)):
                    path = os.path.join(cls_imgs_path, imgname)
                    item = (path, class_to_idx[fname])
                    images.append(item)
    else:
        imgs_path = os.path.join(dir_path, 'images')
        imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')

        with open(imgs_annotations) as r:
            data_info = map(lambda s: s.split('\t'), r.readlines())
            # consume the lazy map while the file is still open
            cls_map = {line_data[0]: line_data[1] for line_data in data_info}

        for imgname in sorted(os.listdir(imgs_path)):
            path = os.path.join(imgs_path, imgname)
            item = (path, class_to_idx[cls_map[imgname]])
            images.append(item)

    return images


if __name__ == '__main__':
    train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)
    test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)
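    # Hedged sanity check, assuming the archive is extracted under
    # ./tiny-imagenet: wnids.txt lists the 200 synsets, and both splits index
    # into that same label space via find_classes above.
    classes, class_to_idx = find_classes(os.path.join('./tiny-imagenet', 'tiny-imagenet-200', 'wnids.txt'))
    print('%d classes' % len(classes))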
--------------------------------------------------------------------------------