├── src
│   ├── downloader.py
│   └── dataloader.py
└── README.md

/src/downloader.py:
--------------------------------------------------------------------------------
import argparse
import os
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm


class CLDatasets:
    """
    A class for downloading continual learning datasets from Google Cloud Storage.
    """

    def __init__(self, dataset: str, directory: str, unzip: bool = True):
        """
        Initialize the CLDatasets object.

        Args:
            dataset (str): The name of the dataset to download.
                One of 'CGLM', 'CLOC', or 'ImageNet2K'.
            directory (str): The directory where the dataset will be saved.
            unzip (bool): Whether to extract the downloaded zip files.

        Raises:
            ValueError: If the dataset name is not recognized.
        """
        if dataset not in ('CGLM', 'CLOC', 'ImageNet2K'):
            raise ValueError(
                f"Dataset '{dataset}' not found! "
                "Choose one of 'CGLM', 'CLOC', or 'ImageNet2K'.")

        self.dataset = dataset
        self.directory = directory

        os.makedirs(self.directory, exist_ok=True)

        print("Dataset selected:", dataset)
        self.download_dataset()

        if unzip:
            self.unzip_data_files(
                os.path.join(self.directory, self.dataset, 'data'))

    def download_dataset(self):
        """
        Download the dataset files from Google Cloud Storage.
        """
        print("Dataset files are being downloaded...")
        start_time = time.time()
        download_command = f"gsutil -m cp -r gs://cl-datasets/{self.dataset} {self.directory}/"
        os.system(download_command)
        elapsed_time = time.time() - start_time
        print(f"Elapsed time: {elapsed_time:.2f} seconds")

    def unzip_data_files(self, directory: str) -> None:
        """
        Extract the contents of the zip files in a directory into folders
        named after each archive, then delete the archives.

        Args:
            directory: The path to the directory containing the zip files.

        Returns:
            None
        """
        zip_files = [file for file in os.listdir(directory)
                     if file.endswith('.zip')]

        def extract_single_zip(zip_file: str) -> None:
            zip_path = os.path.join(directory, zip_file)
            output_dir = os.path.join(directory, os.path.splitext(zip_file)[0])

            os.makedirs(output_dir, exist_ok=True)

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(output_dir)

        # Extract archives in parallel, advancing the progress bar as each finishes.
        with ThreadPoolExecutor() as executor, tqdm(total=len(zip_files)) as pbar:
            futures_list = []
            for zip_file in zip_files:
                future = executor.submit(extract_single_zip, zip_file)
                future.add_done_callback(lambda _: pbar.update(1))
                futures_list.append(future)

            # Wait for all tasks to complete and surface any extraction errors.
            for future in futures_list:
                future.result()

        # Remove the zip files once extraction has succeeded.
        for zip_file in zip_files:
            os.remove(os.path.join(directory, zip_file))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Download datasets from Google Cloud Storage.')
    parser.add_argument('--dataset', type=str, default='CGLM',
                        help='The name of the dataset to download.')
    parser.add_argument('--directory', type=str, default='/data/cl_datasets/files/CGLM/',
                        help='The directory where the dataset will be saved.')
    parser.add_argument('--unzip', action='store_true',
                        help='Whether to unzip the downloaded files.')

    args = parser.parse_args()

    gcp_cl_datasets = CLDatasets(
        dataset=args.dataset,
        directory=args.directory,
        unzip=args.unzip)

/src/dataloader.py:
--------------------------------------------------------------------------------
import os
from typing import Callable, Optional, Tuple

import h5py
from PIL import Image


class BaseDataClass:
    """Base class for a data class."""

    def __init__(self, dataset: str, directory: str):
        """
        Initialize the BaseDataClass.

        Args:
            dataset (str): Name of the dataset.
            directory (str): Path to the directory containing the data.

        Raises:
            FileNotFoundError: If the 'order_files' or 'data' directory is not found.
        """
        self.dataset = dataset
        self.directory = directory

        # Check that the 'order_files' and 'data' directories exist in the directory.
        if not os.path.exists(os.path.join(self.directory, 'order_files')):
            raise FileNotFoundError("order_files directory not found!")

        if not os.path.exists(os.path.join(self.directory, 'data')):
            raise FileNotFoundError("data directory not found!")

        print(f"Found 'order_files' and 'data' directories for {self.dataset}!")

    def __getitem__(self, index):
        """
        Get an item from the data class.

        Args:
            index: Index of the item to retrieve.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError

    def __len__(self):
        """
        Get the length of the data class.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError


class H5Dataset(BaseDataClass):
    def __init__(self, dataset: str, directory: str, partition: str,
                 transform: Optional[Callable] = None):
        """
        Initialize the H5Dataset.

        Args:
            dataset (str): Dataset name.
            directory (str): Directory path.
            partition (str): One of 'train', 'test', or 'pretrain' (all datasets),
                or 'pretest', 'preval', 'cls_inc', 'data_inc' (ImageNet2K only).
            transform (callable, optional): Transform to apply to the samples.
                Defaults to None.

        Raises:
            FileNotFoundError: If any of the required HDF5 files is not found.
        """
        super().__init__(dataset=dataset, directory=directory)
        self.directory = directory
        # The HDF5 files are kept open so entries can be read lazily per index.
        self.image_paths = h5py.File(
            f"{directory}/order_files/{partition}_image_paths.hdf5", "r")["store_list"]
        self.labels = h5py.File(
            f"{directory}/order_files/{partition}_labels.hdf5", "r")["store_list"]
        self.transform = transform

        assert len(self.image_paths) == len(self.labels)

    def __getitem__(self, index: int) -> Tuple[Image.Image, int]:
        """
        Get an item from the H5Dataset.

        Args:
            index (int): Index of the item to retrieve.

        Returns:
            tuple: A tuple containing the sample and its label.
        """
        img_path = os.path.join(
            self.directory, 'data',
            self.image_paths[index].decode("utf-8").strip())
        label = int(self.labels[index])
        sample = pil_loader(img_path)

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, label

    def __len__(self) -> int:
        """
        Get the length of the H5Dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.image_paths)


def pil_loader(path: str) -> Image.Image:
    # Open the path as a file to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


if __name__ == "__main__":
    BaseDataClass(dataset='ImageNet2K',
                  directory='/data/cl_datasets/files/ImageNet2K/')

    dataset = H5Dataset(
        dataset='ImageNet2K',
        directory="/data/cl_datasets/files/ImageNet2K/",
        partition='data_inc')
    print(len(dataset))
    dataset[1][0].show()

/README.md:
--------------------------------------------------------------------------------
# Continual Learning Datasets 📚

Welcome to the Continual Learning Datasets repository! Here, we aim to make large-scale continual learning datasets easily accessible to everyone. This repository provides a convenient way to download three large-scale, diverse datasets: CLOC, CGLM, and ImageNet2K. Feel free to explore, experiment, and contribute!
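
## Quick Start 🚀

A minimal usage sketch (paths are illustrative and follow the defaults in `src/downloader.py`; downloading requires `gsutil` on your `PATH`, plus the `tqdm`, `h5py`, and `Pillow` Python packages):

```python
from src.downloader import CLDatasets
from src.dataloader import H5Dataset

# Download gs://cl-datasets/CGLM into the target directory and extract the zips.
CLDatasets(dataset='CGLM', directory='/data/cl_datasets/files/CGLM/', unzip=True)

# Point H5Dataset at the folder that contains the 'order_files' and 'data'
# directories; 'train' is one of the partitions listed in its docstring.
dataset = H5Dataset(dataset='CGLM',
                    directory='/data/cl_datasets/files/CGLM/CGLM/',
                    partition='train')
print(len(dataset))        # number of (image, label) pairs
image, label = dataset[0]  # a PIL image and its integer label
```

Alternatively, `src/downloader.py` can be run directly as a script; see its `argparse` options for the dataset, target directory, and unzip flags.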