├── .gitignore
├── jigsaw.py
├── readme.md
├── rotation.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
--------------------------------------------------------------------------------
/jigsaw.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import DatasetFolder
from torchvision.models.resnet import resnet50

from utils import (AverageMeter, Logger, Memory, ModelCheckpoint,
                   NoiseContrastiveEstimator, Progbar, pil_loader)

device = torch.device('cuda:0')
data_dir = '/media/dysk/datasets/isic_challenge_2017/train'
negative_nb = 1000  # number of negative examples in NCE
lr = 0.001
checkpoint_dir = 'jigsaw_models'
log_filename = 'pretraining_log_jigsaw'


class JigsawLoader(DatasetFolder):
    def __init__(self, root_dir):
        super(JigsawLoader, self).__init__(root_dir, pil_loader, extensions=('jpg',))
        self.root_dir = root_dir
        self.color_transform = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4)
        self.flips = [torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.RandomVerticalFlip()]
        self.normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'original': image tensor, 'patches': list of 9 patch tensors, 'index': index}
        """
        path, _ = self.samples[index]
        original = self.loader(path)
        image = torchvision.transforms.Resize((224, 224))(original)
        sample = torchvision.transforms.RandomCrop((255, 255))(original)

        # split the 255x255 crop into a 3x3 grid of 85x85 cells
        crop_areas = [(i * 85, j * 85, (i + 1) * 85, (j + 1) * 85) for i in range(3) for j in range(3)]
        samples = [sample.crop(crop_area) for crop_area in crop_areas]
        samples = [torchvision.transforms.RandomCrop((64, 64))(patch) for patch in samples]
        # augmentation - color jitter
        image = self.color_transform(image)
        samples = [self.color_transform(patch) for patch in samples]
        # augmentation - flips
        image = self.flips[0](image)
        image = self.flips[1](image)
        # to tensor
        image = torchvision.transforms.functional.to_tensor(image)
        samples = [torchvision.transforms.functional.to_tensor(patch) for patch in samples]
        # normalize
        image = self.normalize(image)
        samples = [self.normalize(patch) for patch in samples]
        random.shuffle(samples)

        return {'original': image, 'patches': samples, 'index': index}


dataset = JigsawLoader(data_dir)
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32, num_workers=32)
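
# Illustrative sanity check of the geometry above, commented out because it
# needs data_dir to exist: each item should yield one 224x224 image and nine
# shuffled 64x64 patches (255 = 3 * 85, each 85x85 cell cropped to 64x64).
# item = dataset[0]
# assert item['original'].shape == (3, 224, 224)
# assert len(item['patches']) == 9
# assert all(patch.shape == (3, 64, 64) for patch in item['patches'])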


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.network = resnet50()
        self.network = torch.nn.Sequential(*list(self.network.children())[:-1])
        self.projection_original_features = nn.Linear(2048, 128)
        self.connect_patches_feature = nn.Linear(1152, 128)

    def forward_once(self, x):
        return self.network(x)

    def return_reduced_image_features(self, original):
        original_features = self.forward_once(original)
        original_features = original_features.view(-1, 2048)
        original_features = self.projection_original_features(original_features)
        return original_features

    def return_reduced_image_patches_features(self, original, patches):
        original_features = self.return_reduced_image_features(original)

        patches_features = []
        for patch in patches:
            patch_features = self.return_reduced_image_features(patch)
            patches_features.append(patch_features)

        # 9 patches x 128 features each = 1152 features per image
        patches_features = torch.cat(patches_features, dim=1)
        patches_features = self.connect_patches_feature(patches_features)
        return original_features, patches_features

    def forward(self, images=None, patches=None, mode=0):
        '''
        mode 0: get 128d feature for the image,
        mode 1: get 128d features for the image and its patches
        '''
        if mode == 0:
            return self.return_reduced_image_features(images)
        if mode == 1:
            return self.return_reduced_image_patches_features(images, patches)


net = Network().to(device)
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

memory = Memory(size=len(dataset), weight=0.5, device=device)
memory.initialize(net, train_loader)


checkpoint = ModelCheckpoint(mode='min', directory=checkpoint_dir)
noise_contrastive_estimator = NoiseContrastiveEstimator(device)
logger = Logger(log_filename)

loss_weight = 0.5

for epoch in range(1000):
    print('\nEpoch: {}'.format(epoch))
    memory.update_weighted_count()
    train_loss = AverageMeter('train_loss')
    bar = Progbar(len(train_loader), stateful_metrics=['train_loss', 'valid_loss'])

    for step, batch in enumerate(train_loader):

        # prepare batch
        images = batch['original'].to(device)
        patches = [element.to(device) for element in batch['patches']]
        index = batch['index']
        representations = memory.return_representations(index).to(device).detach()
        # zero grad
        optimizer.zero_grad()

        # forward, loss, backward, step
        output = net(images=images, patches=patches, mode=1)

        loss_1 = noise_contrastive_estimator(representations, output[1], index, memory, negative_nb=negative_nb)
        loss_2 = noise_contrastive_estimator(representations, output[0], index, memory, negative_nb=negative_nb)
        loss = loss_weight * loss_1 + (1 - loss_weight) * loss_2

        loss.backward()
        optimizer.step()

        # update representation memory
        memory.update(index, output[0].detach().cpu().numpy())

        # update metric and bar
        train_loss.update(loss.item(), images.shape[0])
        bar.update(step, values=[('train_loss', train_loss.return_avg())])
        logger.update(epoch, train_loss.return_avg())

    # save model if improved
    checkpoint.save_model(net, train_loss.return_avg(), epoch)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Self-Supervised Learning of Pretext-Invariant Representations - implementation

My implementation of the paper "Self-Supervised Learning of Pretext-Invariant Representations" by Ishan Misra and Laurens van der Maaten (https://arxiv.org/abs/1912.01991).

The main objective of the algorithm is to learn representations that are invariant to image transformations. This is achieved with a loss function that minimizes the distance between the representations of an original image and its transformed version while, at the same time, maximizing the distance to the representations of other images from the dataset.
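
The core of this objective is a noise-contrastive estimation (NCE) loss. Below is a minimal per-sample sketch mirroring `NoiseContrastiveEstimator` in utils.py; the function name is mine, and the 128-d features and memory-bank negatives are assumed to be produced as in this repo:

```python
import torch
import torch.nn.functional as F

def nce_loss(z_image, z_transformed, negatives, temp=0.07):
    # z_image, z_transformed: (128,) features of an image and its transformed view
    # negatives: (N, 128) memory-bank features of other images
    pos = F.cosine_similarity(z_image[None, :], z_transformed[None, :]) / temp  # (1,)
    neg = F.cosine_similarity(z_transformed[None, :], negatives) / temp         # (N,)
    logits = torch.cat([pos, neg])[None, :]                                     # (1, N+1)
    # cross-entropy with target 0: the transformed view should be closest
    # to its own image and far from the negatives
    return F.cross_entropy(logits, torch.zeros(1, dtype=torch.long))
```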

The implementation contains two pretext tasks:
- Jigsaw transformation pretext task (jigsaw.py)
- Rotation pretext task (rotation.py)

I tested the code on a small dataset of skin lesions (2000 images). My initial experiments showed promising results for self-supervised pretraining on small datasets (both the self-supervised pretraining and the downstream-task training were performed on the same small dataset; a sketch of how the pretrained weights can be reused follows the list):
- training ResNet50 from scratch (AUC around 0.55)
- training ResNet50 initialized with weights obtained by self-supervised pretraining (AUC around 0.7)
- training ResNet50 initialized with weights of a network pre-trained on ImageNet (AUC around 0.8)
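
To reuse the pretrained weights downstream, the backbone parameters have to be extracted from the saved `Network` state dict. A hedged sketch follows; the checkpoint path and the two-class head are illustrative, and the key remapping relies on `Network` wrapping the children of `resnet50()` (minus the final fc layer) in an `nn.Sequential`, which makes the backbone keys numeric:

```python
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet50

state = torch.load('jigsaw_models/epoch_42', map_location='cpu')  # hypothetical checkpoint

# children of resnet50() in order, minus the final fc layer
names = ['conv1', 'bn1', 'relu', 'maxpool',
         'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']
backbone_state = {}
for key, value in state.items():
    if not key.startswith('network.'):
        continue  # skip the projection heads
    idx, rest = key[len('network.'):].split('.', 1)
    backbone_state['{}.{}'.format(names[int(idx)], rest)] = value

model = resnet50()
model.fc = nn.Linear(2048, 2)  # hypothetical binary lesion classifier
model.load_state_dict(backbone_state, strict=False)  # fc stays randomly initialized
```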
36 | """ 37 | path, _ = self.samples[index] 38 | original = self.loader(path) 39 | image = torchvision.transforms.Resize((300, 300))(original) 40 | image = torchvision.transforms.RandomCrop((224, 224))(image) 41 | 42 | rotation = torchvision.transforms.Resize((224, 224))(image) 43 | # augmentation - collor jitter 44 | image = self.color_transform(image) 45 | rotation = self.color_transform(rotation) 46 | # augmentation - flips 47 | image = self.flips[0](image) 48 | image = self.flips[1](image) 49 | # augmentation - rotation 50 | angles = [90, 180, 270] 51 | angle = random.choice(angles) 52 | rotation = torchvision.transforms.functional.rotate(rotation, angle) 53 | 54 | # to tensor 55 | image = torchvision.transforms.functional.to_tensor(image) 56 | rotation = torchvision.transforms.functional.to_tensor(rotation) 57 | # normalize 58 | image = self.normalize(image) 59 | rotation = self.normalize(rotation) 60 | 61 | return {'original': image, 'rotation': rotation, 'index': index} 62 | 63 | 64 | dataset = RotationLoader(data_dir) 65 | train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32, num_workers=32) 66 | 67 | 68 | class Network(nn.Module): 69 | def __init__(self): 70 | super(Network, self).__init__() 71 | self.network = resnet50() 72 | self.network = torch.nn.Sequential(*list(self.network.children())[:-1]) 73 | self.projection_original_features = nn.Linear(2048, 128) 74 | 75 | def forward_once(self, x): 76 | return self.network(x) 77 | 78 | def return_reduced_image_features(self, original): 79 | features = self.forward_once(original) 80 | features = features.view(-1, 2048) 81 | features = self.projection_original_features(features) 82 | return features 83 | 84 | def forward(self, images=None, rotation=None, mode=0): 85 | ''' 86 | mode 0: get 128d feature for image, 87 | mode 1: get 128d feature for image and rotation 88 | 89 | ''' 90 | if mode == 0: 91 | return self.return_reduced_image_features(images) 92 | if mode == 1: 93 | image_features = self.return_reduced_image_features(images) 94 | rotation_features = self.return_reduced_image_features(rotation) 95 | return image_features, rotation_features 96 | 97 | 98 | net = Network().to(device) 99 | 100 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9) 101 | 102 | memory = Memory(size=len(dataset), weight=0.5, device=device) 103 | memory.initialize(net, train_loader) 104 | 105 | 106 | checkpoint = ModelCheckpoint(mode='min', directory=checkpoint_dir) 107 | noise_contrastive_estimator = NoiseContrastiveEstimator(device) 108 | logger = Logger(log_filename) 109 | 110 | loss_weight = 0.5 111 | 112 | for epoch in range(1000): 113 | print('\nEpoch: {}'.format(epoch)) 114 | memory.update_weighted_count() 115 | train_loss = AverageMeter('train_loss') 116 | bar = Progbar(len(train_loader), stateful_metrics=['train_loss', 'valid_loss']) 117 | 118 | for step, batch in enumerate(train_loader): 119 | 120 | # prepare batch 121 | images = batch['original'].to(device) 122 | rotation = batch['rotation'].to(device) 123 | index = batch['index'] 124 | representations = memory.return_representations(index).to(device).detach() 125 | # zero grad 126 | optimizer.zero_grad() 127 | 128 | #forward, loss, backward, step 129 | output = net(images=images, rotation=rotation, mode=1) 130 | 131 | loss_1 = noise_contrastive_estimator(representations, output[1], index, memory, negative_nb=negative_nb) 132 | loss_2 = noise_contrastive_estimator(representations, output[0], index, memory, negative_nb=negative_nb) 133 | loss = loss_weight * 


net = Network().to(device)

optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

memory = Memory(size=len(dataset), weight=0.5, device=device)
memory.initialize(net, train_loader)


checkpoint = ModelCheckpoint(mode='min', directory=checkpoint_dir)
noise_contrastive_estimator = NoiseContrastiveEstimator(device)
logger = Logger(log_filename)

loss_weight = 0.5

for epoch in range(1000):
    print('\nEpoch: {}'.format(epoch))
    memory.update_weighted_count()
    train_loss = AverageMeter('train_loss')
    bar = Progbar(len(train_loader), stateful_metrics=['train_loss', 'valid_loss'])

    for step, batch in enumerate(train_loader):

        # prepare batch
        images = batch['original'].to(device)
        rotation = batch['rotation'].to(device)
        index = batch['index']
        representations = memory.return_representations(index).to(device).detach()
        # zero grad
        optimizer.zero_grad()

        # forward, loss, backward, step
        output = net(images=images, rotation=rotation, mode=1)

        loss_1 = noise_contrastive_estimator(representations, output[1], index, memory, negative_nb=negative_nb)
        loss_2 = noise_contrastive_estimator(representations, output[0], index, memory, negative_nb=negative_nb)
        loss = loss_weight * loss_1 + (1 - loss_weight) * loss_2

        loss.backward()
        optimizer.step()

        # update representation memory
        memory.update(index, output[0].detach().cpu().numpy())

        # update metric and bar
        train_loss.update(loss.item(), images.shape[0])
        bar.update(step, values=[('train_loss', train_loss.return_avg())])
        logger.update(epoch, train_loss.return_avg())

    # save model if improved
    checkpoint.save_model(net, train_loss.return_avg(), epoch)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import collections
import datetime
import os
import random
import shutil
import sys
import time

import numpy as np
import torch
from PIL import Image


class AverageMeter(object):
    """Computes and stores the average and current value.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def return_avg(self):
        return self.avg
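
# Illustrative AverageMeter usage:
# meter = AverageMeter('loss')
# meter.update(0.5, n=32)
# meter.update(0.3, n=32)
# meter.return_avg()  # (0.5 * 32 + 0.3 * 32) / 64 = 0.4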
92 | """ 93 | values = values or [] 94 | for k, v in values: 95 | if k not in self.stateful_metrics: 96 | if k not in self._values: 97 | self._values[k] = [v * (current - self._seen_so_far), 98 | current - self._seen_so_far] 99 | else: 100 | self._values[k][0] += v * (current - self._seen_so_far) 101 | self._values[k][1] += (current - self._seen_so_far) 102 | else: 103 | # Stateful metrics output a numeric value. This representation 104 | # means "take an average from a single value" but keeps the 105 | # numeric formatting. 106 | self._values[k] = [v, 1] 107 | self._seen_so_far = current 108 | 109 | now = time.time() 110 | info = ' - %.0fs' % (now - self._start) 111 | if self.verbose == 1: 112 | if (now - self._last_update < self.interval and 113 | self.target is not None and current < self.target): 114 | return 115 | 116 | prev_total_width = self._total_width 117 | if self._dynamic_display: 118 | sys.stdout.write('\b' * prev_total_width) 119 | sys.stdout.write('\r') 120 | else: 121 | sys.stdout.write('\n') 122 | 123 | if self.target is not None: 124 | numdigits = int(np.floor(np.log10(self.target))) + 1 125 | barstr = '%%%dd/%d [' % (numdigits, self.target) 126 | bar = barstr % current 127 | prog = float(current) / self.target 128 | prog_width = int(self.width * prog) 129 | if prog_width > 0: 130 | bar += ('=' * (prog_width - 1)) 131 | if current < self.target: 132 | bar += '>' 133 | else: 134 | bar += '=' 135 | bar += ('.' * (self.width - prog_width)) 136 | bar += ']' 137 | else: 138 | bar = '%7d/Unknown' % current 139 | 140 | self._total_width = len(bar) 141 | sys.stdout.write(bar) 142 | 143 | if current: 144 | time_per_unit = (now - self._start) / current 145 | else: 146 | time_per_unit = 0 147 | if self.target is not None and current < self.target: 148 | eta = time_per_unit * (self.target - current) 149 | if eta > 3600: 150 | eta_format = ('%d:%02d:%02d' % 151 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 152 | elif eta > 60: 153 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 154 | else: 155 | eta_format = '%ds' % eta 156 | 157 | info = ' - ETA: %s' % eta_format 158 | else: 159 | if time_per_unit >= 1: 160 | info += ' %.0fs/step' % time_per_unit 161 | elif time_per_unit >= 1e-3: 162 | info += ' %.0fms/step' % (time_per_unit * 1e3) 163 | else: 164 | info += ' %.0fus/step' % (time_per_unit * 1e6) 165 | 166 | for k in self._values: 167 | info += ' - %s:' % k 168 | if isinstance(self._values[k], list): 169 | avg = np.mean( 170 | self._values[k][0] / max(1, self._values[k][1])) 171 | if abs(avg) > 1e-3: 172 | info += ' %.4f' % avg 173 | else: 174 | info += ' %.4e' % avg 175 | else: 176 | info += ' %s' % self._values[k] 177 | 178 | self._total_width += len(info) 179 | if prev_total_width > self._total_width: 180 | info += (' ' * (prev_total_width - self._total_width)) 181 | 182 | if self.target is not None and current >= self.target: 183 | info += '\n' 184 | 185 | sys.stdout.write(info) 186 | sys.stdout.flush() 187 | 188 | elif self.verbose == 2: 189 | if self.target is None or current >= self.target: 190 | for k in self._values: 191 | info += ' - %s:' % k 192 | avg = np.mean( 193 | self._values[k][0] / max(1, self._values[k][1])) 194 | if avg > 1e-3: 195 | info += ' %.4f' % avg 196 | else: 197 | info += ' %.4e' % avg 198 | info += '\n' 199 | 200 | sys.stdout.write(info) 201 | sys.stdout.flush() 202 | 203 | self._last_update = now 204 | 205 | def add(self, n, values=None): 206 | self.update(self._seen_so_far + n, values) 207 | 208 | 209 | class Memory(object): 210 | def __init__(self, 


class ModelCheckpoint():
    def __init__(self, mode, directory):
        self.directory = directory
        if mode == 'min':
            self.best = np.inf
            self.monitor_op = np.less
        elif mode == 'max':
            self.best = 0
            self.monitor_op = np.greater
        else:
            raise ValueError('Mode should be either \'min\' or \'max\'')
        if os.path.isdir(self.directory):
            shutil.rmtree(self.directory)
        os.mkdir(self.directory)

    def save_model(self, model, current_value, epoch):
        if self.monitor_op(current_value, self.best):
            print('\nSave model, best value {:.3f}, epoch: {}'.format(current_value, epoch))
            self.best = current_value
            torch.save(model.state_dict(), os.path.join(self.directory, 'epoch_{}'.format(epoch)))


class NoiseContrastiveEstimator():
    def __init__(self, device):
        self.device = device
        self.temp = 0.07
        self.cos = torch.nn.CosineSimilarity()
        self.criterion = torch.nn.CrossEntropyLoss()

    def __call__(self, original_features, modified_features, index, memory, negative_nb=1000):
        loss = 0
        for i in range(original_features.shape[0]):
            # negatives: memory-bank representations of other images
            negative = memory.return_random(size=negative_nb, index=[index[i]])
            negative = torch.Tensor(negative).to(self.device).detach()

            # similarity between the image and its modified version (positive pair)
            image_to_modification_similarity = self.cos(original_features[None, i, :], modified_features[None, i, :]) / self.temp
            # similarities between the modified version and the negatives
            matrix_of_similarity = self.cos(modified_features[None, i, :], negative) / self.temp

            similarities = torch.cat((image_to_modification_similarity, matrix_of_similarity))
            loss += self.criterion(similarities[None, :], torch.tensor([0]).to(self.device))
        return loss / original_features.shape[0]
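
# Illustrative call, with shapes as used in jigsaw.py and rotation.py:
# nce = NoiseContrastiveEstimator(device)
# loss = nce(representations,  # (B, 128) memory-bank features for the batch
#            features,         # (B, 128) features from the network
#            index, memory, negative_nb=1000)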


def pil_loader(path):
    # open path as a file to avoid ResourceWarning
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class Logger:
    def __init__(self, file_name):
        self.file_name = file_name
        with open('{}.csv'.format(self.file_name), 'w') as file:
            file.write('Epoch,Loss,Time\n')

    def update(self, epoch, loss):
        now = datetime.datetime.now()
        with open('{}.csv'.format(self.file_name), 'a') as file:
            file.write('{},{:.4f},{}\n'.format(epoch, loss, now))
--------------------------------------------------------------------------------