"""PyTorch dataset for CUB-200-2011.

Project page: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
"""
import os

import pandas as pd
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset


class Cub2011(Dataset):
    """Caltech-UCSD Birds-200-2011 image classification dataset.

    Expects (or downloads and extracts) ``CUB_200_2011.tgz`` under ``root`` so
    that ``root/CUB_200_2011`` contains ``images.txt``,
    ``image_class_labels.txt``, ``train_test_split.txt`` and an ``images/``
    directory.

    Args:
        root: Directory holding (or receiving) the extracted archive; ``~`` is
            expanded.
        train: If True, keep rows whose ``is_training_img`` flag is 1;
            otherwise keep rows whose flag is 0.
        transform: Optional callable applied to each loaded image.
        loader: Callable mapping an image path to an image object. Defaults
            to torchvision's ``default_loader``.
        download: If True, download and extract the archive unless the files
            already pass the integrity check.

    Raises:
        RuntimeError: If the metadata or any image of the selected split is
            missing after the optional download.
    """

    base_folder = 'CUB_200_2011/images'
    # NOTE(review): the Caltech site has been reorganised over time; confirm
    # this URL still resolves before relying on download=True.
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        # Honour the caller-supplied loader (it used to be ignored in favour
        # of default_loader).
        self.loader = loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        """Read the CUB metadata files into ``self.data`` for the chosen split.

        ``self.data`` gets the columns ``img_id``, ``filepath``, ``target`` and
        ``is_training_img``, restricted to rows whose ``is_training_img`` is 1
        (train split) or 0 (test split).
        """
        meta_dir = os.path.join(self.root, 'CUB_200_2011')
        images = pd.read_csv(os.path.join(meta_dir, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(meta_dir, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(meta_dir, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        data = data.merge(train_test_split, on='img_id')

        wanted_flag = 1 if self.train else 0
        self.data = data[data.is_training_img == wanted_flag]

    def _check_integrity(self):
        """Return True if the metadata loads and every image of the split exists."""
        try:
            self._load_metadata()
        except (OSError, ValueError):
            # OSError: metadata file missing/unreadable. ValueError covers
            # pandas ParserError/EmptyDataError and decode errors, i.e. a
            # corrupted file. Anything else is a real bug and should surface.
            return False

        image_dir = os.path.join(self.root, self.base_folder)
        return all(os.path.isfile(os.path.join(image_dir, filepath))
                   for filepath in self.data['filepath'])

    def _download(self):
        """Download and extract the archive unless the dataset already verifies."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        archive_path = os.path.join(self.root, self.filename)
        # Use the 'data' extraction filter where the interpreter provides it
        # (Python 3.12+ and security backports) to reject absolute paths and
        # path traversal in archive members.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
        with tarfile.open(archive_path, "r:gz") as tar:
            tar.extractall(path=self.root, **extract_kwargs)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th sample of the split.

        ``target`` is shifted from the 1-based class ids in
        ``image_class_labels.txt`` to 0-based.
        """
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target