├── requirements.txt ├── data ├── id-1.pt ├── id-2.pt ├── id-3.pt └── id-4.pt ├── my_classes.py └── pytorch_script.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | -------------------------------------------------------------------------------- /data/id-1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shervinea/pytorch-data-generator/HEAD/data/id-1.pt -------------------------------------------------------------------------------- /data/id-2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shervinea/pytorch-data-generator/HEAD/data/id-2.pt -------------------------------------------------------------------------------- /data/id-3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shervinea/pytorch-data-generator/HEAD/data/id-3.pt -------------------------------------------------------------------------------- /data/id-4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shervinea/pytorch-data-generator/HEAD/data/id-4.pt -------------------------------------------------------------------------------- /my_classes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | 4 | class Dataset(data.Dataset): 5 | 'Characterizes a dataset for PyTorch' 6 | def __init__(self, list_IDs, labels): 7 | 'Initialization' 8 | self.labels = labels 9 | self.list_IDs = list_IDs 10 | 11 | def __len__(self): 12 | 'Denotes the total number of samples' 13 | return len(self.list_IDs) 14 | 15 | def __getitem__(self, index): 16 | 'Generates one sample of data' 17 | # Select sample 18 | ID = self.list_IDs[index] 19 | 20 | # Load data and get label 21 | X = torch.load('data/' + ID + '.pt') 22 | y = self.labels[ID] 23 | 24 | return X, y 25 | -------------------------------------------------------------------------------- /pytorch_script.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | 4 | from my_classes import Dataset 5 | 6 | 7 | # CUDA for PyTorch 8 | use_cuda = torch.cuda.is_available() 9 | device = torch.device("cuda:0" if use_cuda else "cpu") 10 | cudnn.benchmark = True 11 | 12 | # Parameters 13 | params = {'batch_size': 1, 14 | 'shuffle': True, 15 | 'num_workers': 6} 16 | max_epochs = 100 17 | 18 | # Datasets 19 | partition = {'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']} # IDs 20 | labels = {'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1} # Labels 21 | 22 | # Generators 23 | training_set = Dataset(partition['train'], labels) 24 | training_generator = data.DataLoader(training_set, **params) 25 | 26 | validation_set = Dataset(partition['validation'], labels) 27 | validation_generator = data.DataLoader(validation_set, **params) 28 | 29 | # Loop over epochs 30 | for epoch in range(max_epochs): 31 | # Training 32 | for local_batch, local_labels in training_generator: 33 | # Transfer to GPU 34 | local_batch, local_labels = local_batch.to(device), local_labels.to(device) 35 | 36 | # Model computations 37 | [...] 38 | 39 | # Validation 40 | with torch.set_grad_enabled(False): 41 | for local_batch, local_labels in validation_generator: 42 | # Transfer to GPU 43 | local_batch, local_labels = local_batch.to(device), local_labels.to(device) 44 | 45 | # Model computations 46 | [...] 47 | --------------------------------------------------------------------------------