├── LICENSE
├── README.md
├── aircraft.py
├── cars.py
├── cub2011.py
├── dogs.py
├── inat2017.py
├── nabirds.py
└── tiny_imagenet.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 lvyilin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch FGVC Dataset
2 |
3 | This repo contains some unofficial PyTorch dataset APIs (mainly for the Fine-Grained Visual Categorization task)
4 | that automatically download the data (except for the large-scale datasets), extract the archives, and prepare the data.
5 |
6 | ## Supported Datasets
7 | - [x] CUB-200-2011
8 | - [x] Stanford Dogs
9 | - [x] Stanford Cars
10 | - [x] FGVC Aircraft
11 | - [x] NABirds
12 | - [x] Tiny ImageNet
13 | - [x] iNaturalist 2017
14 | - [ ] Oxford 102 Flowers
15 | - [ ] Oxford-IIIT Pets
16 | - [ ] Food-101
17 |
18 | ## Usage
19 | The code was tested with:
20 | - pytorch==1.4.0
21 | - torchvision==0.4.1
22 |
23 | Use them the same way you use `torchvision.datasets`:
24 | ```python
25 | train_dataset = Cub2011('./cub2011', train=True, download=False)
26 | test_dataset = Cub2011('./cub2011', train=False, download=False)
27 | ```
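
Since every dataset here subclasses `torchvision.datasets.VisionDataset`, it plugs straight into a `DataLoader`. A minimal sketch (the transform sizes and normalization values below are the usual ImageNet choices, not something this repo prescribes):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from cub2011 import Cub2011

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

train_dataset = Cub2011('./cub2011', train=True, transform=train_transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

for images, targets in train_loader:
    pass  # images: (B, 3, 224, 224) float tensor; targets: (B,) int64 tensor
```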
28 | ## Contributing
29 | Feel free to open an issue or PR.
30 |
31 | ## License
32 | MIT
--------------------------------------------------------------------------------
/aircraft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.folder import default_loader
5 | from torchvision.datasets.utils import download_url
6 | from torchvision.datasets.utils import extract_archive
7 |
8 |
9 | class Aircraft(VisionDataset):
10 |     """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 | class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
17 |         transform (callable, optional): A function/transform that takes in a PIL image
18 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
19 | target_transform (callable, optional): A function/transform that takes in the
20 | target and transforms it.
21 | download (bool, optional): If true, downloads the dataset from the internet and
22 | puts it in root directory. If dataset is already downloaded, it is not
23 | downloaded again.
24 | """
25 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
26 | class_types = ('variant', 'family', 'manufacturer')
27 | splits = ('train', 'val', 'trainval', 'test')
28 | img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')
29 |
30 | def __init__(self, root, train=True, class_type='variant', transform=None,
31 | target_transform=None, download=False):
32 | super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)
33 | split = 'trainval' if train else 'test'
34 | if split not in self.splits:
35 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
36 | split, ', '.join(self.splits),
37 | ))
38 | if class_type not in self.class_types:
39 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
40 | class_type, ', '.join(self.class_types),
41 | ))
42 |
43 | self.class_type = class_type
44 | self.split = split
45 | self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
46 | 'images_%s_%s.txt' % (self.class_type, self.split))
47 |
48 | if download:
49 | self.download()
50 |
51 | (image_ids, targets, classes, class_to_idx) = self.find_classes()
52 | samples = self.make_dataset(image_ids, targets)
53 |
54 | self.loader = default_loader
55 |
56 | self.samples = samples
57 | self.classes = classes
58 | self.class_to_idx = class_to_idx
59 |
60 | def __getitem__(self, index):
61 | path, target = self.samples[index]
62 | sample = self.loader(path)
63 | if self.transform is not None:
64 | sample = self.transform(sample)
65 | if self.target_transform is not None:
66 | target = self.target_transform(target)
67 | return sample, target
68 |
69 | def __len__(self):
70 | return len(self.samples)
71 |
72 | def _check_exists(self):
73 | return os.path.exists(os.path.join(self.root, self.img_folder)) and \
74 | os.path.exists(self.classes_file)
75 |
76 | def download(self):
77 | if self._check_exists():
78 | return
79 |
80 |         # download the archive to <root>/fgvc-aircraft-2013b.tar.gz
81 | print('Downloading %s...' % self.url)
82 | tar_name = self.url.rpartition('/')[-1]
83 | download_url(self.url, root=self.root, filename=tar_name)
84 | tar_path = os.path.join(self.root, tar_name)
85 | print('Extracting %s...' % tar_path)
86 | extract_archive(tar_path)
87 | print('Done!')
88 |
89 | def find_classes(self):
90 | # read classes file, separating out image IDs and class names
91 | image_ids = []
92 | targets = []
93 | with open(self.classes_file, 'r') as f:
94 | for line in f:
95 |                 split_line = line.strip().split(' ')  # strip the newline so class names come out clean
96 | image_ids.append(split_line[0])
97 | targets.append(' '.join(split_line[1:]))
98 |
99 | # index class names
100 | classes = np.unique(targets)
101 | class_to_idx = {classes[i]: i for i in range(len(classes))}
102 | targets = [class_to_idx[c] for c in targets]
103 |
104 | return image_ids, targets, classes, class_to_idx
105 |
106 | def make_dataset(self, image_ids, targets):
107 | assert (len(image_ids) == len(targets))
108 | images = []
109 | for i in range(len(image_ids)):
110 | item = (os.path.join(self.root, self.img_folder,
111 | '%s.jpg' % image_ids[i]), targets[i])
112 | images.append(item)
113 | return images
114 |
115 |
116 | if __name__ == '__main__':
117 | train_dataset = Aircraft('./aircraft', train=True, download=False)
118 | test_dataset = Aircraft('./aircraft', train=False, download=False)
119 |
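
The `class_type` argument selects which of the three FGVC-Aircraft label granularities is used. A small usage sketch, assuming the archive has already been extracted under `./aircraft`:

```python
from aircraft import Aircraft

# label by airframe family instead of the default (finest) variant level
family_train = Aircraft('./aircraft', train=True, class_type='family')
print(len(family_train.classes))         # number of distinct family labels
path, target = family_train.samples[0]   # (image path, integer label)
print(family_train.classes[target])      # recover the family name
```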
--------------------------------------------------------------------------------
/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.io as sio
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.folder import default_loader
5 | from torchvision.datasets.utils import download_url
6 | from torchvision.datasets.utils import extract_archive
7 |
8 |
9 | class Cars(VisionDataset):
10 |     """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | file_list = {
25 | 'imgs': ('http://imagenet.stanford.edu/internal/car196/car_ims.tgz', 'car_ims.tgz'),
26 | 'annos': ('http://imagenet.stanford.edu/internal/car196/cars_annos.mat', 'cars_annos.mat')
27 | }
28 |
29 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
30 | super(Cars, self).__init__(root, transform=transform, target_transform=target_transform)
31 |
32 | self.loader = default_loader
33 | self.train = train
34 |
35 | if self._check_exists():
36 | print('Files already downloaded and verified.')
37 | elif download:
38 | self._download()
39 | else:
40 | raise RuntimeError(
41 | 'Dataset not found. You can use download=True to download it.')
42 |
43 | loaded_mat = sio.loadmat(os.path.join(self.root, self.file_list['annos'][1]))
44 | loaded_mat = loaded_mat['annotations'][0]
45 | self.samples = []
46 | for item in loaded_mat:
47 |             if self.train != bool(item[-1][0]):  # last field is the test flag; keep samples from the requested split
48 |                 path = str(item[0][0])  # image path relative to root
49 |                 label = int(item[-2][0]) - 1  # class ids in the .mat are 1-based; shift to 0-based
50 | self.samples.append((path, label))
51 |
52 | def __getitem__(self, index):
53 | path, target = self.samples[index]
54 | path = os.path.join(self.root, path)
55 |
56 | image = self.loader(path)
57 | if self.transform is not None:
58 | image = self.transform(image)
59 | if self.target_transform is not None:
60 | target = self.target_transform(target)
61 | return image, target
62 |
63 | def __len__(self):
64 | return len(self.samples)
65 |
66 | def _check_exists(self):
67 | return (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
68 | and os.path.exists(os.path.join(self.root, self.file_list['annos'][1])))
69 |
70 | def _download(self):
71 | print('Downloading...')
72 | for url, filename in self.file_list.values():
73 | download_url(url, root=self.root, filename=filename)
74 | print('Extracting...')
75 | archive = os.path.join(self.root, self.file_list['imgs'][1])
76 | extract_archive(archive)
77 |
78 |
79 | if __name__ == '__main__':
80 | train_dataset = Cars('./cars', train=True, download=False)
81 | test_dataset = Cars('./cars', train=False, download=False)
82 |
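
`cars_annos.mat` also carries the human-readable class names, which `Cars` itself does not expose. A hedged sketch of reading them alongside the dataset (assumes the files are already in `./cars`):

```python
import scipy.io as sio

from cars import Cars

train_dataset = Cars('./cars', train=True, download=False)

# the same .mat file holds a 'class_names' array ordered by the 1-based
# class ids, so after the -1 shift in Cars the labels index it directly
mat = sio.loadmat('./cars/cars_annos.mat')
class_names = [str(c[0]) for c in mat['class_names'][0]]

path, label = train_dataset.samples[0]
print(class_names[label])
```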
--------------------------------------------------------------------------------
/cub2011.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import download_file_from_google_drive
7 |
8 |
9 | class Cub2011(VisionDataset):
10 |     """`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | base_folder = 'CUB_200_2011/images'
25 | # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
26 | file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
27 | filename = 'CUB_200_2011.tgz'
28 | tgz_md5 = '97eceeb196236b17998738112f37df78'
29 |
30 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
31 | super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)
32 |
33 | self.loader = default_loader
34 | self.train = train
35 | if download:
36 | self._download()
37 |
38 | if not self._check_integrity():
39 | raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')
40 |
41 | def _load_metadata(self):
42 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
43 | names=['img_id', 'filepath'])
44 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
45 | sep=' ', names=['img_id', 'target'])
46 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
47 | sep=' ', names=['img_id', 'is_training_img'])
48 |
49 | data = images.merge(image_class_labels, on='img_id')
50 | self.data = data.merge(train_test_split, on='img_id')
51 |
52 | class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
53 | sep=' ', names=['class_name'], usecols=[1])
54 | self.class_names = class_names['class_name'].to_list()
55 | if self.train:
56 | self.data = self.data[self.data.is_training_img == 1]
57 | else:
58 | self.data = self.data[self.data.is_training_img == 0]
59 |
60 | def _check_integrity(self):
61 | try:
62 | self._load_metadata()
63 | except Exception:
64 | return False
65 |
66 | for index, row in self.data.iterrows():
67 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
68 | if not os.path.isfile(filepath):
69 |                 print('Missing file: ' + filepath)
70 | return False
71 | return True
72 |
73 | def _download(self):
74 | import tarfile
75 |
76 | if self._check_integrity():
77 | print('Files already downloaded and verified')
78 | return
79 |
80 | download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)
81 |
82 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
83 | tar.extractall(path=self.root)
84 |
85 | def __len__(self):
86 | return len(self.data)
87 |
88 | def __getitem__(self, idx):
89 | sample = self.data.iloc[idx]
90 | path = os.path.join(self.root, self.base_folder, sample.filepath)
91 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
92 | img = self.loader(path)
93 |
94 | if self.transform is not None:
95 | img = self.transform(img)
96 | if self.target_transform is not None:
97 | target = self.target_transform(target)
98 | return img, target
99 |
100 |
101 | if __name__ == '__main__':
102 | train_dataset = Cub2011('./cub2011', train=True, download=False)
103 | test_dataset = Cub2011('./cub2011', train=False, download=False)
104 |
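
`Cub2011` keeps the species names from `classes.txt` in `class_names`, and `__getitem__` shifts targets to be 0-based, so a target indexes that list directly. A minimal sketch, assuming the data is already in place:

```python
from cub2011 import Cub2011

test_dataset = Cub2011('./cub2011', train=False, download=False)
img, target = test_dataset[0]
print(test_dataset.class_names[target])   # e.g. '001.Black_footed_Albatross'
```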
--------------------------------------------------------------------------------
/dogs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.io
3 | from os.path import join
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import download_url, list_dir
7 |
8 |
9 | class Dogs(VisionDataset):
10 |     """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
25 |
26 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
27 | super(Dogs, self).__init__(root, transform=transform, target_transform=target_transform)
28 |
29 | self.loader = default_loader
30 | self.train = train
31 |
32 | if download:
33 | self.download()
34 |
35 | split = self.load_split()
36 |
37 | self.images_folder = join(self.root, 'Images')
38 | self.annotations_folder = join(self.root, 'Annotation')
39 | self._breeds = list_dir(self.images_folder)
40 |
41 | self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]
42 |
43 | self._flat_breed_images = self._breed_images
44 |
45 | def __len__(self):
46 | return len(self._flat_breed_images)
47 |
48 | def __getitem__(self, index):
49 | image_name, target = self._flat_breed_images[index]
50 | image_path = join(self.images_folder, image_name)
51 | image = self.loader(image_path)
52 |
53 | if self.transform is not None:
54 | image = self.transform(image)
55 | if self.target_transform is not None:
56 | target = self.target_transform(target)
57 | return image, target
58 |
59 | def download(self):
60 | import tarfile
61 |
62 | if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
63 | if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
64 | print('Files already downloaded and verified')
65 | return
66 |
67 | for filename in ['images', 'annotation', 'lists']:
68 | tar_filename = filename + '.tar'
69 | url = self.download_url_prefix + '/' + tar_filename
70 | download_url(url, self.root, tar_filename, None)
71 | print('Extracting downloaded file: ' + join(self.root, tar_filename))
72 | with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
73 | tar_file.extractall(self.root)
74 | os.remove(join(self.root, tar_filename))
75 |
76 |     def load_split(self):
77 |         # the image list and the labels live in the same .mat file, so load it once
78 |         split_file = 'train_list.mat' if self.train else 'test_list.mat'
79 |         mat = scipy.io.loadmat(join(self.root, split_file))
80 |         split = mat['annotation_list']
81 |         labels = mat['labels']
82 |
83 |         split = [item[0][0] for item in split]
84 |         # labels in the .mat files are 1-based; shift to 0-based
85 |         labels = [item[0] - 1 for item in labels]
86 |         return list(zip(split, labels))
87 |
88 |     def stats(self):
89 |         counts = {}
90 |         for image_name, target_class in self._flat_breed_images:
91 |             if target_class not in counts:
92 |                 counts[target_class] = 1
93 |             else:
94 |                 counts[target_class] += 1
95 |
96 |         print("%d samples spanning %d classes (avg %f per class)" % (
97 |             len(self._flat_breed_images), len(counts),
98 |             float(len(self._flat_breed_images)) / float(len(counts))))
99 |
100 |         return counts
102 |
103 |
104 | if __name__ == '__main__':
105 | train_dataset = Dogs('./dogs', train=True, download=False)
106 | test_dataset = Dogs('./dogs', train=False, download=False)
107 |
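
The `stats()` helper summarizes the class balance of the loaded split. A short sketch (with `download=True` the image, annotation, and list archives are fetched on first use):

```python
from dogs import Dogs

train_dataset = Dogs('./dogs', train=True, download=True)
# prints the sample/class counts; the official train split should report
# 12000 samples over 120 breeds (avg 100 per class)
counts = train_dataset.stats()
```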
--------------------------------------------------------------------------------
/inat2017.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.folder import default_loader
5 | from torchvision.datasets.utils import check_integrity, extract_archive
6 | from torchvision.datasets.utils import download_url, verify_str_arg
7 |
8 |
9 | class INat2017(VisionDataset):
10 |     """`iNaturalist 2017 <https://github.com/visipedia/inat_comp>`_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 |         split (string, optional): The dataset split, supports ``train`` or ``val``.
15 |         transform (callable, optional): A function/transform that takes in a PIL image
16 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
17 | target_transform (callable, optional): A function/transform that takes in the
18 | target and transforms it.
19 | download (bool, optional): If true, downloads the dataset from the internet and
20 | puts it in root directory. If dataset is already downloaded, it is not
21 | downloaded again.
22 | """
23 | base_folder = 'train_val_images/'
24 | file_list = {
25 | 'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
26 | 'train_val_images.tar.gz',
27 | '7c784ea5e424efaec655bd392f87301f'),
28 | 'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
29 | 'train_val2017.zip',
30 | '444c835f6459867ad69fcb36478786e7')
31 | }
32 |
33 | def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
34 | super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
35 | self.loader = default_loader
36 | self.split = verify_str_arg(split, "split", ("train", "val",))
37 |
38 | if self._check_exists():
39 | print('Files already downloaded and verified.')
40 | elif download:
41 | if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
42 | and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
43 | print('Downloading...')
44 | self._download()
45 | print('Extracting...')
46 | extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
47 | extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
48 | else:
49 | raise RuntimeError(
50 | 'Dataset not found. You can use download=True to download it.')
51 | anno_filename = split + '2017.json'
52 | with open(os.path.join(self.root, anno_filename), 'r') as fp:
53 | all_annos = json.load(fp)
54 |
55 | self.annos = all_annos['annotations']
56 | self.images = all_annos['images']
57 |
58 | def __getitem__(self, index):
59 | path = os.path.join(self.root, self.images[index]['file_name'])
60 |         target = self.annos[index]['category_id']  # the json's annotation list is index-aligned with its image list
61 |
62 | image = self.loader(path)
63 | if self.transform is not None:
64 | image = self.transform(image)
65 | if self.target_transform is not None:
66 | target = self.target_transform(target)
67 |
68 | return image, target
69 |
70 | def __len__(self):
71 | return len(self.images)
72 |
73 | def _check_exists(self):
74 | return os.path.exists(os.path.join(self.root, self.base_folder))
75 |
76 | def _download(self):
77 | for url, filename, md5 in self.file_list.values():
78 | download_url(url, root=self.root, filename=filename)
79 | if not check_integrity(os.path.join(self.root, filename), md5):
80 | raise RuntimeError("File not found or corrupted.")
81 |
82 |
83 | if __name__ == '__main__':
84 | train_dataset = INat2017('./inat2017', split='train', download=False)
85 | test_dataset = INat2017('./inat2017', split='val', download=False)
86 |
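
Targets come straight from the COCO-style json's `category_id`. A quick sketch of recovering the size of the label space (assumes the archives are already extracted under `./inat2017`):

```python
from inat2017 import INat2017

train_dataset = INat2017('./inat2017', split='train', download=False)
num_classes = len({anno['category_id'] for anno in train_dataset.annos})
print(num_classes)   # iNaturalist 2017 has 5,089 categories
```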
--------------------------------------------------------------------------------
/nabirds.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import warnings
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import check_integrity, extract_archive
7 |
8 |
9 | class NABirds(VisionDataset):
10 |     """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | base_folder = 'nabirds/images'
25 | filename = 'nabirds.tar.gz'
26 | md5 = 'df21a9e4db349a14e2b08adfd45873bd'
27 |
28 | def __init__(self, root, train=True, transform=None, target_transform=None, download=None):
29 | super(NABirds, self).__init__(root, transform=transform, target_transform=target_transform)
30 | if download is True:
31 | msg = ("The dataset is no longer publicly accessible. You need to "
32 | "download the archives externally and place them in the root "
33 | "directory.")
34 | raise RuntimeError(msg)
35 | elif download is False:
36 | msg = ("The use of the download flag is deprecated, since the dataset "
37 | "is no longer publicly accessible.")
38 | warnings.warn(msg, RuntimeWarning)
39 |
40 | dataset_path = os.path.join(root, 'nabirds')
41 | if not os.path.isdir(dataset_path):
42 | if not check_integrity(os.path.join(root, self.filename), self.md5):
43 | raise RuntimeError('Dataset not found or corrupted.')
44 | extract_archive(os.path.join(root, self.filename))
45 | self.loader = default_loader
46 | self.train = train
47 |
48 | image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
49 | sep=' ', names=['img_id', 'filepath'])
50 | image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
51 | sep=' ', names=['img_id', 'target'])
52 | # Since the raw labels are non-continuous, map them to new ones
53 | self.label_map = get_continuous_class_map(image_class_labels['target'])
54 | train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
55 | sep=' ', names=['img_id', 'is_training_img'])
56 | data = image_paths.merge(image_class_labels, on='img_id')
57 | self.data = data.merge(train_test_split, on='img_id')
58 | # Load in the train / test split
59 | if self.train:
60 | self.data = self.data[self.data.is_training_img == 1]
61 | else:
62 | self.data = self.data[self.data.is_training_img == 0]
63 |
64 | # Load in the class data
65 | self.class_names = load_class_names(dataset_path)
66 | self.class_hierarchy = load_hierarchy(dataset_path)
67 |
68 | def __len__(self):
69 | return len(self.data)
70 |
71 | def __getitem__(self, idx):
72 | sample = self.data.iloc[idx]
73 | path = os.path.join(self.root, self.base_folder, sample.filepath)
74 | target = self.label_map[sample.target]
75 | img = self.loader(path)
76 |
77 | if self.transform is not None:
78 | img = self.transform(img)
79 | if self.target_transform is not None:
80 | target = self.target_transform(target)
81 | return img, target
82 |
83 |
84 | def get_continuous_class_map(class_labels):
85 |     label_set = sorted(set(class_labels))  # sort so the mapping is deterministic across runs
86 |     return {k: i for i, k in enumerate(label_set)}
87 |
88 |
89 | def load_class_names(dataset_path=''):
90 | names = {}
91 |
92 | with open(os.path.join(dataset_path, 'classes.txt')) as f:
93 | for line in f:
94 | pieces = line.strip().split()
95 | class_id = pieces[0]
96 | names[class_id] = ' '.join(pieces[1:])
97 |
98 | return names
99 |
100 |
101 | def load_hierarchy(dataset_path=''):
102 | parents = {}
103 |
104 | with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
105 | for line in f:
106 | pieces = line.strip().split()
107 | child_id, parent_id = pieces
108 | parents[child_id] = parent_id
109 |
110 | return parents
111 |
112 |
113 | if __name__ == '__main__':
114 | train_dataset = NABirds('./nabirds', train=True, download=False)
115 | test_dataset = NABirds('./nabirds', train=False, download=False)
116 |
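
Because the raw NABirds class ids are non-contiguous, `label_map` remaps them to `0..N-1`; going back to a readable name therefore needs the inverse map. A hedged sketch (assumes `nabirds.tar.gz` or the extracted `nabirds/` directory is already in `./nabirds`):

```python
from nabirds import NABirds

train_dataset = NABirds('./nabirds', train=True)
img, target = train_dataset[0]

# invert the raw-id -> contiguous-id map, then look the name up by raw id
inv_map = {v: k for k, v in train_dataset.label_map.items()}
print(train_dataset.class_names[str(inv_map[target])])
```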
--------------------------------------------------------------------------------
/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.folder import default_loader
5 | from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg
10 |
11 |
12 | class TinyImageNet(VisionDataset):
13 |     """`Tiny ImageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.
14 |
15 | Args:
16 | root (string): Root directory of the dataset.
17 |         split (string, optional): The dataset split, supports ``train`` or ``val``.
18 |         transform (callable, optional): A function/transform that takes in a PIL image
19 |             and returns a transformed version, e.g., ``transforms.RandomCrop``.
20 | target_transform (callable, optional): A function/transform that takes in the
21 | target and transforms it.
22 | download (bool, optional): If true, downloads the dataset from the internet and
23 | puts it in root directory. If dataset is already downloaded, it is not
24 | downloaded again.
25 | """
26 | base_folder = 'tiny-imagenet-200/'
27 | url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
28 | filename = 'tiny-imagenet-200.zip'
29 | md5 = '90528d7ca1a48142e341f4ef8d21d0de'
30 |
31 | def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
32 | super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)
33 |
34 | self.dataset_path = os.path.join(root, self.base_folder)
35 | self.loader = default_loader
36 | self.split = verify_str_arg(split, "split", ("train", "val",))
37 |
38 | if self._check_integrity():
39 | print('Files already downloaded and verified.')
40 | elif download:
41 | self._download()
42 | else:
43 | raise RuntimeError(
44 | 'Dataset not found. You can use download=True to download it.')
45 | if not os.path.isdir(self.dataset_path):
46 | print('Extracting...')
47 | extract_archive(os.path.join(root, self.filename))
48 |
49 | _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))
50 |
51 | self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)
52 |
53 | def _download(self):
54 | print('Downloading...')
55 | download_url(self.url, root=self.root, filename=self.filename)
56 | print('Extracting...')
57 | extract_archive(os.path.join(self.root, self.filename))
58 |
59 | def _check_integrity(self):
60 | return check_integrity(os.path.join(self.root, self.filename), self.md5)
61 |
62 | def __getitem__(self, index):
63 | img_path, target = self.data[index]
64 | image = self.loader(img_path)
65 |
66 | if self.transform is not None:
67 | image = self.transform(image)
68 | if self.target_transform is not None:
69 | target = self.target_transform(target)
70 |
71 | return image, target
72 |
73 | def __len__(self):
74 | return len(self.data)
75 |
76 |
77 | def find_classes(class_file):
78 | with open(class_file) as r:
79 | classes = list(map(lambda s: s.strip(), r.readlines()))
80 |
81 | classes.sort()
82 | class_to_idx = {classes[i]: i for i in range(len(classes))}
83 |
84 | return classes, class_to_idx
85 |
86 |
87 | def make_dataset(root, base_folder, dirname, class_to_idx):
88 | images = []
89 | dir_path = os.path.join(root, base_folder, dirname)
90 |
91 | if dirname == 'train':
92 | for fname in sorted(os.listdir(dir_path)):
93 | cls_fpath = os.path.join(dir_path, fname)
94 | if os.path.isdir(cls_fpath):
95 | cls_imgs_path = os.path.join(cls_fpath, 'images')
96 | for imgname in sorted(os.listdir(cls_imgs_path)):
97 | path = os.path.join(cls_imgs_path, imgname)
98 | item = (path, class_to_idx[fname])
99 | images.append(item)
100 | else:
101 | imgs_path = os.path.join(dir_path, 'images')
102 | imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')
103 |
104 | with open(imgs_annotations) as r:
105 | data_info = map(lambda s: s.split('\t'), r.readlines())
106 |
107 | cls_map = {line_data[0]: line_data[1] for line_data in data_info}
108 |
109 | for imgname in sorted(os.listdir(imgs_path)):
110 | path = os.path.join(imgs_path, imgname)
111 | item = (path, class_to_idx[cls_map[imgname]])
112 | images.append(item)
113 |
114 | return images
115 |
116 |
117 | if __name__ == '__main__':
118 | train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)
119 | test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)
120 |
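
Tiny ImageNet images are 64×64, so the 224-pixel transforms used for the other datasets do not apply directly. A minimal training-transform sketch (the crop padding and flip are common choices, not something this repo prescribes):

```python
from torchvision import transforms

from tiny_imagenet import TinyImageNet

train_transform = transforms.Compose([
    transforms.RandomCrop(64, padding=4),   # images are already 64x64
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
train_dataset = TinyImageNet('./tiny-imagenet', split='train',
                             transform=train_transform, download=True)
image, target = train_dataset[0]            # image: (3, 64, 64) float tensor
```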
--------------------------------------------------------------------------------