├── SimCLR
│   ├── models
│   │   ├── __init__.py
│   │   ├── projector.py
│   │   ├── resnet.py
│   │   └── imagenet_resnet.py
│   ├── warmup_scheduler
│   │   ├── __init__.py
│   │   ├── run.py
│   │   └── scheduler.py
│   ├── loss.py
│   ├── argument.py
│   ├── model_loader.py
│   ├── train_contrastive.py
│   ├── data_loader.py
│   └── utils.py
├── requirements.txt
├── OpenCoS
│   ├── splits_img
│   │   ├── bird_cls_idx.npy
│   │   ├── bird_cls_test.npy
│   │   ├── dog_cls_idx.npy
│   │   ├── dog_cls_test.npy
│   │   ├── dog_cls_train.npy
│   │   ├── food_cls_idx.npy
│   │   ├── food_cls_test.npy
│   │   ├── bird_cls_train.npy
│   │   ├── food_cls_train.npy
│   │   ├── insect_cls_idx.npy
│   │   ├── insect_cls_test.npy
│   │   ├── primate_cls_idx.npy
│   │   ├── produce_cls_idx.npy
│   │   ├── reptile_cls_idx.npy
│   │   ├── scenery_cls_idx.npy
│   │   ├── dog_cls_25pc_train.npy
│   │   ├── insect_cls_train.npy
│   │   ├── primate_cls_test.npy
│   │   ├── primate_cls_train.npy
│   │   ├── produce_cls_test.npy
│   │   ├── produce_cls_train.npy
│   │   ├── reptile_cls_test.npy
│   │   ├── reptile_cls_train.npy
│   │   ├── scenery_cls_test.npy
│   │   ├── scenery_cls_train.npy
│   │   ├── bird_cls_25pc_train.npy
│   │   ├── food_cls_25pc_train.npy
│   │   ├── aquatic_animal_cls_idx.npy
│   │   ├── aquatic_animal_cls_test.npy
│   │   ├── insect_cls_25pc_train.npy
│   │   ├── primate_cls_25pc_train.npy
│   │   ├── produce_cls_25pc_train.npy
│   │   ├── reptile_cls_25pc_train.npy
│   │   ├── scenery_cls_25pc_train.npy
│   │   ├── aquatic_animal_cls_train.npy
│   │   └── aquatic_animal_cls_25pc_train.npy
│   ├── splits
│   │   ├── cifar100_4pc_label_idx.npy
│   │   ├── cifar10_25pc_label_idx.npy
│   │   ├── cifar10_4pc_label_idx.npy
│   │   ├── svhn_unlabel_train_idx.npy
│   │   ├── tiny_unlabel_train_idx.npy
│   │   ├── cifar100_100pc_label_idx.npy
│   │   ├── cifar100_25pc_label_idx.npy
│   │   ├── cifar10_400pc_label_idx.npy
│   │   ├── cifar10_animal_test_idx.npy
│   │   ├── cifar100_unlabel_train_idx.npy
│   │   ├── cifar10_unlabel_train_idx.npy
│   │   ├── cifar10_animal_25pc_label_idx.npy
│   │   ├── cifar10_animal_4pc_label_idx.npy
│   │   └── cifar10_notanimal_unlabel_idx.npy
│   ├── models
│   │   ├── __init__.py
│   │   ├── wide_resnet.py
│   │   ├── wide_resnet_auxbn.py
│   │   ├── cifar_resnet.py
│   │   ├── cifar_resnet_auxbn.py
│   │   ├── resnet_auxbn.py
│   │   ├── resnet.py
│   │   ├── aux_batchnorm.py
│   │   └── meta_resnet.py
│   ├── imbalanced.py
│   ├── utils.py
│   ├── randaugment.py
│   ├── datasets.py
│   ├── train.py
│   └── train_fixmatch.py
└── README.md
--------------------------------------------------------------------------------
/SimCLR/models/__init__.py:
--------------------------------------------------------------------------------
from .resnet import *
--------------------------------------------------------------------------------
/SimCLR/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------

from warmup_scheduler.scheduler import GradualWarmupScheduler
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.4.0
torchvision==0.5.0
torchlars==0.1.2
diffdist==0.1
tensorboardX==2.0
--------------------------------------------------------------------------------
/OpenCoS/splits_img/bird_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/bird_cls_idx.npy
/OpenCoS/splits_img/bird_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/bird_cls_test.npy
/OpenCoS/splits_img/dog_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/dog_cls_idx.npy
/OpenCoS/splits_img/dog_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/dog_cls_test.npy
/OpenCoS/splits_img/dog_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/dog_cls_train.npy
/OpenCoS/splits_img/food_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/food_cls_idx.npy
/OpenCoS/splits_img/food_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/food_cls_test.npy
/OpenCoS/splits_img/bird_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/bird_cls_train.npy
/OpenCoS/splits_img/food_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/food_cls_train.npy
/OpenCoS/splits_img/insect_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/insect_cls_idx.npy
/OpenCoS/splits_img/insect_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/insect_cls_test.npy
/OpenCoS/splits_img/primate_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/primate_cls_idx.npy
/OpenCoS/splits_img/produce_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/produce_cls_idx.npy
/OpenCoS/splits_img/reptile_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/reptile_cls_idx.npy
/OpenCoS/splits_img/scenery_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/scenery_cls_idx.npy
/OpenCoS/splits/cifar100_4pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar100_4pc_label_idx.npy
/OpenCoS/splits/cifar10_25pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_25pc_label_idx.npy
/OpenCoS/splits/cifar10_4pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_4pc_label_idx.npy
/OpenCoS/splits/svhn_unlabel_train_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/svhn_unlabel_train_idx.npy
/OpenCoS/splits/tiny_unlabel_train_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/tiny_unlabel_train_idx.npy
/OpenCoS/splits_img/dog_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/dog_cls_25pc_train.npy
/OpenCoS/splits_img/insect_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/insect_cls_train.npy
/OpenCoS/splits_img/primate_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/primate_cls_test.npy
/OpenCoS/splits_img/primate_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/primate_cls_train.npy
/OpenCoS/splits_img/produce_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/produce_cls_test.npy
/OpenCoS/splits_img/produce_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/produce_cls_train.npy
/OpenCoS/splits_img/reptile_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/reptile_cls_test.npy
/OpenCoS/splits_img/reptile_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/reptile_cls_train.npy
/OpenCoS/splits_img/scenery_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/scenery_cls_test.npy
/OpenCoS/splits_img/scenery_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/scenery_cls_train.npy
/OpenCoS/splits/cifar100_100pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar100_100pc_label_idx.npy
/OpenCoS/splits/cifar100_25pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar100_25pc_label_idx.npy
/OpenCoS/splits/cifar10_400pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_400pc_label_idx.npy
/OpenCoS/splits/cifar10_animal_test_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_animal_test_idx.npy
/OpenCoS/splits_img/bird_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/bird_cls_25pc_train.npy
/OpenCoS/splits_img/food_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/food_cls_25pc_train.npy
/OpenCoS/splits/cifar100_unlabel_train_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar100_unlabel_train_idx.npy
/OpenCoS/splits/cifar10_unlabel_train_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_unlabel_train_idx.npy
/OpenCoS/splits_img/aquatic_animal_cls_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/aquatic_animal_cls_idx.npy
/OpenCoS/splits_img/aquatic_animal_cls_test.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/aquatic_animal_cls_test.npy
/OpenCoS/splits_img/insect_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/insect_cls_25pc_train.npy
/OpenCoS/splits_img/primate_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/primate_cls_25pc_train.npy
/OpenCoS/splits_img/produce_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/produce_cls_25pc_train.npy
/OpenCoS/splits_img/reptile_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/reptile_cls_25pc_train.npy
/OpenCoS/splits_img/scenery_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/scenery_cls_25pc_train.npy
/OpenCoS/splits/cifar10_animal_25pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_animal_25pc_label_idx.npy
/OpenCoS/splits/cifar10_animal_4pc_label_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_animal_4pc_label_idx.npy
/OpenCoS/splits/cifar10_notanimal_unlabel_idx.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits/cifar10_notanimal_unlabel_idx.npy
/OpenCoS/splits_img/aquatic_animal_cls_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/aquatic_animal_cls_train.npy
/OpenCoS/splits_img/aquatic_animal_cls_25pc_train.npy: https://raw.githubusercontent.com/alinlab/OpenCoS/HEAD/OpenCoS/splits_img/aquatic_animal_cls_25pc_train.npy
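These .npy split files hold plain integer index arrays that carve the base datasets into labeled/unlabeled/test subsets. A minimal sketch of how such a split is typically consumed (illustrative only; the actual loading presumably lives in OpenCoS/datasets.py, which is not shown here):

import numpy as np

# Each split file stores an index array selecting a subset of the base
# dataset; the printed shape/dtype are assumptions, not documented facts.
label_idx = np.load('OpenCoS/splits/cifar10_4pc_label_idx.npy')
print(label_idx.shape, label_idx.dtype, label_idx[:5])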
--------------------------------------------------------------------------------
/SimCLR/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
import torch

from warmup_scheduler import GradualWarmupScheduler


if __name__ == '__main__':
    v = torch.zeros(10)
    optim = torch.optim.SGD([v], lr=0.01)
    scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=10)

    for epoch in range(1, 20):
        scheduler.step(epoch)

        print(epoch, optim.param_groups[0]['lr'])
--------------------------------------------------------------------------------
/SimCLR/models/projector.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class Projector(nn.Module):
    def __init__(self, expansion=4):
        super(Projector, self).__init__()

        if expansion == 0:
            self.linear_1 = nn.Linear(128, 128)
            self.linear_2 = nn.Linear(128, 128)
        else:
            self.linear_1 = nn.Linear(512*expansion, 2048)
            self.linear_2 = nn.Linear(2048, 2048)

    def forward(self, x, internal_output_list=False):
        output = self.linear_1(x)
        output = F.relu(output)

        output = self.linear_2(output)

        return output
--------------------------------------------------------------------------------
/OpenCoS/models/__init__.py:
--------------------------------------------------------------------------------
from .resnet import *
from .cifar_resnet import *
from .resnet_auxbn import *
from .cifar_resnet_auxbn import *
from .wide_resnet import *
from .wide_resnet_auxbn import *
from .meta_resnet import *

def load_model(name, num_classes=10, pretrained=False, divide=False, **kwargs):
    model_dict = globals()
    if 'wide' in name or 'Wide' in name:
        if divide:
            model = model_dict[name](28, 2, num_classes=num_classes, divide=divide, **kwargs)
        else:
            model = model_dict[name](28, 2, num_classes=num_classes, **kwargs)
    else:
        if divide:
            model = model_dict[name](pretrained=pretrained, num_classes=num_classes, divide=divide, **kwargs)
        else:
            model = model_dict[name](pretrained=pretrained, num_classes=num_classes, **kwargs)
    return model
--------------------------------------------------------------------------------
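A minimal sketch of calling the load_model factory above (illustrative usage, assuming OpenCoS/ is on PYTHONPATH; model names are the symbols re-exported by this package):

from models import load_model

# 'wide' in the name routes to the wide_resnet(28, 2, ...) branch above,
# i.e. a WRN-28-2 head for the requested number of classes.
net = load_model('wide_resnet', num_classes=10)

# The auxiliary-BN variant takes the extra divide flag.
net_aux = load_model('wide_resnet_auxbn', num_classes=10, divide=True)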
/SimCLR/loss.py:
--------------------------------------------------------------------------------
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.distributed as dist

import numpy as np
import diffdist.functional as distops


def pairwise_similarity(outputs, temperature=0.5):
    '''
    Compute pairwise similarity and return the matrix
    input: aggregated outputs & temperature for scaling
    return: pairwise cosine similarity
    '''
    outputs_1, outputs_2 = outputs.chunk(2)
    gather_t_1 = [torch.empty_like(outputs_1) for _ in range(dist.get_world_size())]
    gather_t_2 = [torch.empty_like(outputs_2) for _ in range(dist.get_world_size())]
    gather_t_1 = distops.all_gather(gather_t_1, outputs_1)
    gather_t_2 = distops.all_gather(gather_t_2, outputs_2)
    outputs_1 = torch.cat(gather_t_1)
    outputs_2 = torch.cat(gather_t_2)
    outputs = torch.cat([outputs_1, outputs_2])

    B = outputs.shape[0]

    outputs_norm = outputs / (outputs.norm(dim=1).view(B, 1) + 1e-8)
    similarity_matrix = (1. / temperature) * torch.mm(outputs_norm, outputs_norm.transpose(0, 1))

    return similarity_matrix


def NT_xent(similarity_matrix):
    '''
    Compute NT_xent loss
    input: pairwise-similarity matrix
    return: NT xent loss
    '''

    N2 = len(similarity_matrix)
    N = int(len(similarity_matrix) / 2)

    # Removing diagonal #
    similarity_matrix_exp = torch.exp(similarity_matrix)
    similarity_matrix_exp = similarity_matrix_exp * (1 - torch.eye(N2, N2)).cuda()

    NT_xent_loss = - torch.log(similarity_matrix_exp / (torch.sum(similarity_matrix_exp, dim=1).view(N2, 1) + 1e-8) + 1e-8)
    NT_xent_loss_total = (1. / float(N2)) * torch.sum(torch.diag(NT_xent_loss[0:N, N:]) + torch.diag(NT_xent_loss[N:, 0:N]))

    return NT_xent_loss_total
--------------------------------------------------------------------------------
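pairwise_similarity above requires an initialized process group, so a single-process restatement of the same NT-Xent computation is handy as a sanity check. The sketch below mirrors the math above (cosine-normalize, exponentiate, mask the diagonal, average the two positive diagonals); it is illustrative only and all names are made up:

import torch

def nt_xent_local(z1, z2, temperature=0.5):
    # z1, z2: the two augmented views, shape (N, d); no distributed gather.
    z = torch.cat([z1, z2])                       # (2N, d)
    z = z / (z.norm(dim=1, keepdim=True) + 1e-8)  # cosine-normalize rows
    sim = z @ z.t() / temperature                 # (2N, 2N) similarities
    n2 = sim.size(0)
    n = n2 // 2
    exp = torch.exp(sim) * (1 - torch.eye(n2))    # zero out self-similarity
    log_prob = -torch.log(exp / (exp.sum(dim=1, keepdim=True) + 1e-8) + 1e-8)
    # positives live on the off-diagonal blocks (i, i+N) and (i+N, i)
    return (log_prob[:n, n:].diag() + log_prob[n:, :n].diag()).sum() / n2

z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
print(nt_xent_local(z1, z2).item())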
/OpenCoS/imbalanced.py:
--------------------------------------------------------------------------------
import torch
import torch.utils.data
import torchvision


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for an imbalanced dataset
    Arguments:
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
        callback_get_label func: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):

        # if indices is not provided,
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices

        # define custom callback
        self.callback_get_label = callback_get_label

        # if num_samples is not provided,
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1

        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        elif isinstance(dataset, torchvision.datasets.MNIST):
            return dataset.train_labels[idx].item()
        elif isinstance(dataset, torchvision.datasets.ImageFolder):
            return dataset.imgs[idx][1]
        elif isinstance(dataset, torch.utils.data.Subset):
            # map the subset position back through Subset.indices before
            # reading the parent's labels; indexing the parent with the raw
            # subset position would silently return the wrong label
            return dataset.dataset.imgs[dataset.indices[idx]][1]
        else:
            raise NotImplementedError

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
--------------------------------------------------------------------------------
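A minimal usage sketch for the sampler above (illustrative only; the dataset path is a placeholder):

import torch
import torchvision
import torchvision.transforms as T

# Rebalance an ImageFolder: each sample is drawn with probability
# proportional to 1 / (size of its class), so classes come out roughly
# uniform per epoch.
dataset = torchvision.datasets.ImageFolder('/data/some_imagefolder',
                                           transform=T.ToTensor())
loader = torch.utils.data.DataLoader(
    dataset,
    sampler=ImbalancedDatasetSampler(dataset),
    batch_size=64)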
/SimCLR/argument.py:
--------------------------------------------------------------------------------
import argparse

def parser():

    parser = argparse.ArgumentParser(description='PyTorch Contrastive Learning of Visual Representation')
    parser.add_argument('--train_type', default='contrastive_learning', type=str, help='standard')
    parser.add_argument('--lr', default=1.5, type=float, help='learning rate, LearningRate = 0.3 × BatchSize/256 for ImageNet, 0.5,1.0,1.5 for CIFAR')
    parser.add_argument('--lr_multiplier', default=1.0, type=float, help='learning rate multiplier, 5,10,15 -> 0.5,1.0,1.5 for CIFAR')
    parser.add_argument('--dataset', default='cifar-10', type=str, help='cifar-10/cifar-100/lsun/imagenet-resize/svhn')
    parser.add_argument('--dataroot', default='/data', type=str, help='PATH TO dataset cifar-10, cifar-100, svhn')
    parser.add_argument('--tinyroot', default='/data/tinyimagenet/tiny-imagenet-200/train/', type=str, help='PATH TO tinyimagenet dataset')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--model', default="ResNet50", type=str,
                        help='model type (default: ResNet50)')
    parser.add_argument('--name', default='', type=str, help='name of run')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--batch-size', default=128, type=int, help='batch size / multi-gpu setting: batch per gpu')
    parser.add_argument('--epoch', default=1000, type=int,
                        help='total epochs to run')
    parser.add_argument('--no-augment', dest='augment', action='store_false',
                        help='use standard augmentation (default: True)')
    parser.add_argument('--decay', default=1e-6, type=float, help='weight decay')

    ##### arguments for data augmentation #####
    parser.add_argument('--color_jitter_strength', default=0.5, type=float, help='0.5 for CIFAR, 1.0 for ImageNet')
    parser.add_argument('--temperature', default=0.5, type=float, help='temperature for pairwise-similarity')

    ##### arguments for linear evaluation #####
    parser.add_argument('--multinomial_l2_regul', default=0.5, type=float, help='regularization for multinomial logistic regression')

    ##### arguments for distributed parallel ####
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpu', type=int, default=4)

    parser.add_argument('--ooc_data', type=str, default=None)

    args = parser.parse_args()

    return args

def print_args(args):
    for k, v in vars(args).items():
        print('{:<16} : {}'.format(k, v))
--------------------------------------------------------------------------------
/SimCLR/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
            if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: the target learning rate is reached gradually at total_epoch
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
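run.py above exercises the warmup alone; the sketch below chains it into cosine annealing the same way train_contrastive.py wires the two together (illustrative values only):

import torch
from warmup_scheduler import GradualWarmupScheduler

# multiplier=1.0: lr ramps linearly 0 -> base lr over 10 epochs, then the
# after_scheduler (cosine annealing) takes over.
p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.1)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)
sched = GradualWarmupScheduler(opt, multiplier=1.0, total_epoch=10,
                               after_scheduler=cosine)
for epoch in range(1, 101):
    sched.step(epoch)
    if epoch % 10 == 0:
        print(epoch, opt.param_groups[0]['lr'])  # ramp, then cosine decay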
/OpenCoS/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os, logging
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def set_logging_defaults(logdir, args):
    if os.path.isdir(logdir):
        res = input('"{}" exists. Overwrite [Y/n]? '.format(logdir))
        if res != 'Y':
            raise Exception('"{}" exists.'.format(logdir))
    else:
        os.makedirs(logdir)

    # set basic configuration for logging
    logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s",
                        level=logging.INFO,
                        handlers=[logging.FileHandler(os.path.join(logdir, 'log.txt')),
                                  logging.StreamHandler(os.sys.stdout)])

    # log cmdline arguments
    logger = logging.getLogger('main')
    logger.info(' '.join(os.sys.argv))
    logger.info(args)

_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 16.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
--------------------------------------------------------------------------------
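For reference, format_time keeps at most the two leading non-zero fields (days/hours/minutes/seconds/milliseconds); a quick illustration, assuming OpenCoS/ on PYTHONPATH:

from utils import format_time  # OpenCoS/utils.py

print(format_time(3725.004))  # '1h2m'  (seconds and millis are dropped)
print(format_time(0.256))     # '256ms'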
/OpenCoS/models/wide_resnet.py:
--------------------------------------------------------------------------------
'''Wide-ResNet in PyTorch.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] Sergey Zagoruyko, Nikos Komodakis
    Wide Residual Networks. arXiv:1605.07146
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


__all__ = ['WideResNet', 'wide_resnet']


class WideBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(WideBasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        o1 = F.leaky_relu(self.bn1(x), 0.1)
        y = self.conv1(o1)
        o2 = F.leaky_relu(self.bn2(y), 0.1)
        z = self.conv2(o2)
        if len(self.shortcut) == 0:
            return z + x
        else:
            return z + self.shortcut(o1)


class WideResNet(nn.Module):
    """ WRN28-width with leaky relu (negative slope is 0.1)"""
    def __init__(self, block, depth, width, num_classes):
        super(WideResNet, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        widths = [int(v * width) for v in (16, 32, 64)]

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, widths[0], n, stride=1)
        self.layer2 = self._make_layer(block, widths[1], n, stride=2)
        self.layer3 = self._make_layer(block, widths[2], n, stride=2)
        self.bn1 = nn.BatchNorm2d(widths[2])
        self.linear = nn.Linear(widths[2]*block.expansion, num_classes)
        self.linear_rot = nn.Linear(widths[2]*block.expansion, 4)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 0)
                nn.init.constant_(m.running_var, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, feature=False, blocks=False, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        out = F.leaky_relu(self.bn1(f3), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)
        if blocks:
            return out, [f1, f2, f3]
        elif feature:
            return out, out4
        else:
            return out

    def feature(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        out = F.leaky_relu(self.bn1(f3), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        return out4

    def blocks(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        return [f1, f2, f3]

    def rot(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        out = F.leaky_relu(self.bn1(f3), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        out_rot = self.linear_rot(out4)
        return out_rot

    def forward_rot(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        out = F.leaky_relu(self.bn1(f3), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)
        out_rot = self.linear_rot(out4)
        return out, out_rot


def wide_resnet(depth, width, num_classes=10):
    return WideResNet(WideBasicBlock, depth, width, num_classes)


#if __name__ == "__main__":
#    net = wide_resnet(28, 2)
#    net.cuda()
#    x = torch.randn(2,3,32,32).cuda()
#    print (net)
#    print (net(x).size())
--------------------------------------------------------------------------------
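A quick shape check for WRN-28-2 on CIFAR-sized inputs (illustrative; assumes OpenCoS/ on PYTHONPATH). With width 2 the stage widths are [32, 64, 128], so the penultimate feature is 128-dimensional:

import torch
from models.wide_resnet import wide_resnet

net = wide_resnet(28, 2, num_classes=10)
x = torch.randn(2, 3, 32, 32)
logits, feats = net(x, feature=True)
print(logits.shape, feats.shape)  # torch.Size([2, 10]) torch.Size([2, 128])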
/SimCLR/model_loader.py:
--------------------------------------------------------------------------------
from models.resnet import PreResNet18, ResNet18, ResNet34, ResNet50
from models.imagenet_resnet import resnet50
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_model(args):

    if args.dataset == 'cifar-10':
        num_classes = 10
    elif args.dataset == 'cifar-100':
        num_classes = 100
    else:
        raise NotImplementedError

    if 'contrastive' in args.train_type:
        contrastive_learning = True
    else:
        contrastive_learning = False

    if args.model == 'PreResNet18':
        model = PreResNet18(num_classes, contrastive_learning)
    elif args.model == 'ResNet18':
        model = ResNet18(num_classes, contrastive_learning)
    elif args.model == 'ResNet34':
        model = ResNet34(num_classes, contrastive_learning)
    elif args.model == 'ResNet50':
        model = ResNet50(num_classes, contrastive_learning)
    elif args.model == 'resnet50':
        model = resnet50()
    elif args.model == 'wide_resnet':
        model = wide_resnet(num_classes=num_classes, contrastive_learning=contrastive_learning)

    return model


class WideBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(WideBasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        o1 = F.leaky_relu(self.bn1(x), 0.1)
        y = self.conv1(o1)
        o2 = F.leaky_relu(self.bn2(y), 0.1)
        z = self.conv2(o2)
        if len(self.shortcut) == 0:
            return z + x
        else:
            return z + self.shortcut(o1)


class WideResNet(nn.Module):
    """ WRN28-width with leaky relu (negative slope is 0.1)"""
    def __init__(self, block, depth, width, num_classes, contrastive_learning=False):
        super(WideResNet, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        widths = [int(v * width) for v in (16, 32, 64)]

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, widths[0], n, stride=1)
        self.layer2 = self._make_layer(block, widths[1], n, stride=2)
        self.layer3 = self._make_layer(block, widths[2], n, stride=2)
        self.bn1 = nn.BatchNorm2d(widths[2])

        self.contrastive_learning = contrastive_learning

        # in contrastive pre-training the classifier head is unused;
        # the projector is attached externally instead
        if not contrastive_learning:
            self.linear = nn.Linear(widths[2]*block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 0)
                nn.init.constant_(m.running_var, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, feature=False, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        out = F.leaky_relu(self.bn1(f3), 0.1)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        if not self.contrastive_learning:
            out = self.linear(out)
        return out


def wide_resnet(depth=28, width=2, num_classes=10, contrastive_learning=False):
    # honor the requested depth/width instead of hard-coding 28 and 2
    return WideResNet(WideBasicBlock, depth, width, num_classes, contrastive_learning)
--------------------------------------------------------------------------------
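get_model only reads three fields off the parsed arguments, so a bare namespace is enough for a smoke test (illustrative; field names mirror argument.py, and SimCLR/ is assumed on PYTHONPATH):

from types import SimpleNamespace
import model_loader

args = SimpleNamespace(dataset='cifar-10',
                       train_type='contrastive_learning',
                       model='ResNet50')
model = model_loader.get_model(args)  # ResNet50 backbone in contrastive mode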
/OpenCoS/models/wide_resnet_auxbn.py:
--------------------------------------------------------------------------------
'''Wide-ResNet in PyTorch.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] Sergey Zagoruyko, Nikos Komodakis
    Wide Residual Networks. arXiv:1605.07146
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from .aux_batchnorm import BatchNorm2d
from .resnet_auxbn import mySequential

__all__ = ['WideResNet_AuxBN', 'wide_resnet_auxbn']

class WideBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, divide=False):
        super(WideBasicBlock, self).__init__()
        self.bn1 = BatchNorm2d(in_planes, divide=divide)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, divide=divide)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x, aux=False):
        o1 = F.leaky_relu(self.bn1(x, aux), 0.1)
        y = self.conv1(o1)
        o2 = F.leaky_relu(self.bn2(y, aux), 0.1)
        z = self.conv2(o2)
        if len(self.shortcut) == 0:
            return z + x
        else:
            return z + self.shortcut(o1)


class WideResNet_AuxBN(nn.Module):
    """ WRN28-width with leaky relu (negative slope is 0.1)"""
    def __init__(self, block, depth, width, num_classes, divide=False):
        super(WideResNet_AuxBN, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        widths = [int(v * width) for v in (16, 32, 64)]

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, widths[0], n, stride=1, divide=divide)
        self.layer2 = self._make_layer(block, widths[1], n, stride=2, divide=divide)
        self.layer3 = self._make_layer(block, widths[2], n, stride=2, divide=divide)
        self.bn1 = BatchNorm2d(widths[2], divide=divide)
        self.linear = nn.Linear(widths[2]*block.expansion, num_classes)
        self.linear_rot = nn.Linear(widths[2]*block.expansion, 4)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, BatchNorm2d):
                nn.init.uniform_(m.bn.weight)
                nn.init.constant_(m.bn.bias, 0)
                nn.init.constant_(m.bn.running_mean, 0)
                nn.init.constant_(m.bn.running_var, 1)
                nn.init.constant_(m.bn_aux.running_mean, 0)
                nn.init.constant_(m.bn_aux.running_var, 1)

    def _make_layer(self, block, planes, num_blocks, stride, divide):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, divide))
            self.in_planes = planes * block.expansion
        return mySequential(*layers)

    def forward(self, x, feature=False, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0, aux)
        f2 = self.layer2(f1, aux)
        f3 = self.layer3(f2, aux)
        out = F.leaky_relu(self.bn1(f3, aux), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)

        if feature:
            return out, out4
        else:
            return out

    def feature(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0, aux)
        f2 = self.layer2(f1, aux)
        f3 = self.layer3(f2, aux)
        out = F.leaky_relu(self.bn1(f3, aux), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        return out4


    def rot(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0, aux)
        f2 = self.layer2(f1, aux)
        f3 = self.layer3(f2, aux)
        out = F.leaky_relu(self.bn1(f3, aux), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        out_rot = self.linear_rot(out4)

        return out_rot


    def forward_rot(self, x, aux=False):
        f0 = self.conv1(x)
        f1 = self.layer1(f0, aux)
        f2 = self.layer2(f1, aux)
        f3 = self.layer3(f2, aux)
        out = F.leaky_relu(self.bn1(f3, aux), 0.1)
        out = F.avg_pool2d(out, 8)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)
        out_rot = self.linear_rot(out4)

        return out, out_rot


def wide_resnet_auxbn(depth, width, num_classes=10, divide=False):
    return WideResNet_AuxBN(WideBasicBlock, depth, width, num_classes, divide=divide)


#if __name__ == "__main__":
#    net = wide_resnet(28, 2)
#    net.cuda()
#    x = torch.randn(2,3,32,32).cuda()
#    print (net)
#    print (net(x).size())
--------------------------------------------------------------------------------
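The aux flag threads through every BatchNorm2d wrapper, so the same convolution weights can normalize two data streams with separate running statistics (main vs. auxiliary BN). A minimal forward sketch (illustrative; assumes OpenCoS/ on PYTHONPATH):

import torch
from models.wide_resnet_auxbn import wide_resnet_auxbn

net = wide_resnet_auxbn(28, 2, num_classes=10, divide=False)
x = torch.randn(2, 3, 32, 32)
out_main = net(x, aux=False)  # normalized with the main BN statistics
out_aux = net(x, aux=True)    # same weights, auxiliary BN statistics
print(out_main.shape, out_aux.shape)  # torch.Size([2, 10]) twice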
/OpenCoS/models/cifar_resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

__all__ = ['CIFAR_ResNet', 'CIFAR_ResNet18', 'CIFAR_ResNet34', 'CIFAR_ResNet50']


def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=groups, bias=False)

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class CIFAR_ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, bias=True):
        super(CIFAR_ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes, bias=bias)
        self.linear_rot = nn.Linear(512*block.expansion, 4, bias=bias)


    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, feature=False, aux=False):
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = F.relu(out)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = F.avg_pool2d(out, 4)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)
        if feature:
            return out, out4
        else:
            return out

    def feature(self, x, aux=False):
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = F.relu(out)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = F.avg_pool2d(out, 4)
        out4 = out.view(out.size(0), -1)
        return out4

    def rot(self, x, aux=False):
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = F.relu(out)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = F.avg_pool2d(out, 4)
        out4 = out.view(out.size(0), -1)
        out_rot = self.linear_rot(out4)
        return out_rot


    def forward_rot(self, x, aux=False):
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = F.relu(out)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = F.avg_pool2d(out, 4)
        out4 = out.view(out.size(0), -1)
        out = self.linear(out4)
        out_rot = self.linear_rot(out4)
        return out, out_rot


def CIFAR_ResNet10(pretrained=False, **kwargs):
    return CIFAR_ResNet(PreActBlock, [1,1,1,1], **kwargs)

def CIFAR_ResNet18(pretrained=False, **kwargs):
    return CIFAR_ResNet(PreActBlock, [2,2,2,2], **kwargs)

def CIFAR_ResNet34(pretrained=False, **kwargs):
    return CIFAR_ResNet(PreActBlock, [3,4,6,3], **kwargs)

def CIFAR_ResNet50(pretrained=False, **kwargs):
    return CIFAR_ResNet(Bottleneck, [3,4,6,3], **kwargs)
--------------------------------------------------------------------------------
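The 4-way linear_rot head matches the usual rotation-prediction auxiliary task; a sketch of building its inputs and targets (illustrative, not taken from the repo's training code; assumes OpenCoS/ on PYTHONPATH):

import torch
from models.cifar_resnet import CIFAR_ResNet18

net = CIFAR_ResNet18(num_classes=10)
x = torch.randn(4, 3, 32, 32)
# stack the 0/90/180/270-degree rotations and label each by its rotation id
rotated = torch.cat([torch.rot90(x, k, dims=(2, 3)) for k in range(4)])
targets = torch.arange(4).repeat_interleave(x.size(0))
logits_rot = net.rot(rotated)
print(logits_rot.shape, targets.shape)  # torch.Size([16, 4]) torch.Size([16])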
/SimCLR/train_contrastive.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3 -u

from __future__ import print_function

import argparse
import csv
import os
import json
import copy

import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import data_loader
import model_loader
import models
from models.projector import Projector

from argument import parser, print_args
from utils import progress_bar, checkpoint

from loss import pairwise_similarity, NT_xent

# Download packages from following git #
# "pip install torchlars" or git from https://github.com/kakaobrain/torchlars, version 0.1.2
# git from https://github.com/ildoonet/pytorch-gradual-warmup-lr #
from torchlars import LARS
from warmup_scheduler import GradualWarmupScheduler

args = parser()
if args.local_rank == 0:
    print_args(args)

start_epoch = 0  # start from epoch 0 or last checkpoint epoch

if args.seed != 0:
    torch.manual_seed(args.seed)

world_size = args.ngpu
torch.distributed.init_process_group(
    'nccl',
    init_method='env://',
    world_size=world_size,
    rank=args.local_rank,
)

# Data
if args.local_rank == 0:
    print('==> Preparing data..')
trainloader, traindst, testloader, testdst, train_sampler = data_loader.get_dataset(args)
if args.local_rank == 0:
    print('Number of training data: ', len(traindst))

# Model
if args.local_rank == 0:
    print('==> Building model..')
torch.cuda.set_device(args.local_rank)
model = model_loader.get_model(args)
if args.model == 'wide_resnet':
    projector = Projector(expansion=0)
else:
    projector = Projector(expansion=4)

# Log and saving checkpoint information #
ngpus_per_node = torch.cuda.device_count()  # must be set before its first use just below
if not os.path.isdir('results') and args.local_rank % ngpus_per_node == 0:
    os.mkdir('results')
args.name += (args.train_type + '_' + args.model + '_' + args.dataset)
loginfo = 'results/log_' + args.name + '_' + str(args.seed)
logname = (loginfo + '.csv')

if args.local_rank == 0:
    print('Training info...')
    print(loginfo)

# Model upload to GPU #
model.cuda()
projector.cuda()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=True,
)
projector = torch.nn.parallel.DistributedDataParallel(
    projector,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=True,
)

print(torch.cuda.device_count())
cudnn.benchmark = True
print('Using CUDA..')

# Aggregating model parameter & projection parameter #
model_params = []
model_params += model.parameters()
model_params += projector.parameters()

# LARS optimizer from KAKAO-BRAIN github
# "pip install torchlars" or git from https://github.com/kakaobrain/torchlars
base_optimizer = optim.SGD(model_params, lr=args.lr, momentum=0.9, weight_decay=args.decay)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

# Cosine learning rate annealing (SGDR) & Learning rate warmup #
# git from https://github.com/ildoonet/pytorch-gradual-warmup-lr #
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=args.lr_multiplier, total_epoch=10, after_scheduler=scheduler_cosine)


def train(epoch):
    print('\nEpoch: %d' % epoch)

    scheduler_warmup.step()
    model.train()
    projector.train()
    train_sampler.set_epoch(epoch)

    train_loss = 0
    reg_loss = 0

    for batch_idx, ((inputs_1, inputs_2), targets) in enumerate(trainloader):
        inputs_1, inputs_2 = inputs_1.cuda(), inputs_2.cuda()
        inputs = torch.cat((inputs_1, inputs_2))

        outputs = projector(model(inputs))

        similarity = pairwise_similarity(outputs, temperature=args.temperature)
        loss = NT_xent(similarity)

        train_loss += loss.data

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar(batch_idx, len(trainloader),
                     'Loss: %.3f | Reg: %.5f'
                     % (train_loss / (batch_idx + 1), reg_loss / (batch_idx + 1)))

    # average over the number of batches (batch_idx is zero-based, so add 1)
    return (train_loss / (batch_idx + 1), reg_loss / (batch_idx + 1))


def test(epoch):
    model.eval()
    projector.eval()

    test_loss = 0

    # Save at the last epoch #
    if epoch == start_epoch + args.epoch - 1 and args.local_rank % ngpus_per_node == 0:
        checkpoint(model, test_loss, epoch, args, optimizer)
        checkpoint(projector, test_loss, epoch, args, optimizer, save_name_add='_projector')
    # Save at every 100 epoch #
    elif epoch > 1 and epoch % 100 == 0 and args.local_rank % ngpus_per_node == 0:
        checkpoint(model, test_loss, epoch, args, optimizer, save_name_add='_epoch_' + str(epoch))
        checkpoint(projector, test_loss, epoch, args, optimizer, save_name_add=('_projector_epoch_' + str(epoch)))

    return (test_loss)


##### Log file #####
if args.local_rank % ngpus_per_node == 0:
    if os.path.exists(logname):
        os.remove(logname)
    with open(logname, 'w') as logfile:
        logwriter = csv.writer(logfile, delimiter=',')
        logwriter.writerow(['epoch', 'train loss', 'reg loss'])


##### Training #####
for epoch in range(start_epoch, args.epoch):
    train_loss, reg_loss = train(epoch)
    _ = test(epoch)

    if args.local_rank % ngpus_per_node == 0:
        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow([epoch, train_loss.item(), reg_loss])
--------------------------------------------------------------------------------
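The script initializes NCCL through the env:// rendezvous and consumes --local_rank, which matches the torch.distributed.launch helper; a typical invocation would be (assumption: 4 GPUs on one node, torch==1.4's launcher):

# python -m torch.distributed.launch --nproc_per_node=4 \
#     SimCLR/train_contrastive.py --ngpu 4 --batch-size 128 \
#     --model ResNet50 --dataset cifar-10
#
# torch.distributed.launch sets the env:// rendezvous variables and passes
# --local_rank to each worker process, which argument.py already accepts.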
F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | out = self.bn3(self.conv3(out)) 61 | out += self.shortcut(x) 62 | out = F.relu(out) 63 | return out 64 | 65 | class CIFAR_ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10, bias=True): 67 | super(CIFAR_ResNet, self).__init__() 68 | self.in_planes = 64 69 | self.conv1 = conv3x3(3,64) 70 | self.bn1 = nn.BatchNorm2d(64) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes, bias=bias) 76 | self.linear_rot = nn.Linear(512*block.expansion, 4, bias=bias) 77 | 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x, feature=False, aux=False): 88 | out = x 89 | out = self.conv1(out) 90 | out = self.bn1(out) 91 | out = F.relu(out) 92 | out1 = self.layer1(out) 93 | out2 = self.layer2(out1) 94 | out3 = self.layer3(out2) 95 | out = self.layer4(out3) 96 | out = F.avg_pool2d(out, 4) 97 | out4 = out.view(out.size(0), -1) 98 | out = self.linear(out4) 99 | if feature: 100 | return out, out4 101 | else: 102 | return out 103 | 104 | def feature(self, x, aux=False): 105 | out = x 106 | out = self.conv1(out) 107 | out = self.bn1(out) 108 | out = F.relu(out) 109 | out1 = self.layer1(out) 110 | out2 = self.layer2(out1) 111 | out3 = self.layer3(out2) 112 | out = self.layer4(out3) 113 | out = F.avg_pool2d(out, 4) 114 | out4 = out.view(out.size(0), -1) 115 | return out4 116 | 117 | def rot(self, x, aux=False): 118 | out = x 119 | out = self.conv1(out) 120 | out = self.bn1(out) 121 | out = F.relu(out) 122 | out1 = self.layer1(out) 123 | out2 = self.layer2(out1) 124 | out3 = self.layer3(out2) 125 | out = self.layer4(out3) 126 | out = F.avg_pool2d(out, 4) 127 | out4 = out.view(out.size(0), -1) 128 | out_rot = self.linear_rot(out4) 129 | return out_rot 130 | 131 | 132 | def forward_rot(self, x, aux=False): 133 | out = x 134 | out = self.conv1(out) 135 | out = self.bn1(out) 136 | out = F.relu(out) 137 | out1 = self.layer1(out) 138 | out2 = self.layer2(out1) 139 | out3 = self.layer3(out2) 140 | out = self.layer4(out3) 141 | out = F.avg_pool2d(out, 4) 142 | out4 = out.view(out.size(0), -1) 143 | out = self.linear(out4) 144 | out_rot = self.linear_rot(out4) 145 | return out, out_rot 146 | 147 | 148 | def CIFAR_ResNet10(pretrained=False, **kwargs): 149 | return CIFAR_ResNet(PreActBlock, [1,1,1,1], **kwargs) 150 | 151 | def CIFAR_ResNet18(pretrained=False, **kwargs): 152 | return CIFAR_ResNet(PreActBlock, [2,2,2,2], **kwargs) 153 | 154 | def CIFAR_ResNet34(pretrained=False, **kwargs): 155 | return CIFAR_ResNet(PreActBlock, [3,4,6,3], **kwargs) 156 | 157 | def CIFAR_ResNet50(pretrained=False, **kwargs): 158 | return CIFAR_ResNet(Bottleneck, [3,4,6,3], **kwargs) 159 | -------------------------------------------------------------------------------- /SimCLR/train_contrastive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | 
import csv 7 | import os 8 | import json 9 | import copy 10 | 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | 21 | import data_loader 22 | import model_loader 23 | import models 24 | from models.projector import Projector 25 | 26 | from argument import parser, print_args 27 | from utils import progress_bar, checkpoint 28 | 29 | from loss import pairwise_similarity,NT_xent 30 | 31 | # Download packages from following git # 32 | # "pip install torchlars" or git from https://github.com/kakaobrain/torchlars, version 0.1.2 33 | # git from https://github.com/ildoonet/pytorch-gradual-warmup-lr # 34 | from torchlars import LARS 35 | from warmup_scheduler import GradualWarmupScheduler 36 | 37 | args = parser() 38 | if args.local_rank == 0: 39 | print_args(args) 40 | 41 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 42 | 43 | if args.seed != 0: 44 | torch.manual_seed(args.seed) 45 | 46 | world_size = args.ngpu 47 | torch.distributed.init_process_group( 48 | 'nccl', 49 | init_method='env://', 50 | world_size=world_size, 51 | rank=args.local_rank, 52 | ) 53 | 54 | # Data 55 | if args.local_rank == 0: 56 | print('==> Preparing data..') 57 | trainloader, traindst, testloader, testdst ,train_sampler = data_loader.get_dataset(args) 58 | if args.local_rank == 0: 59 | print('Number of training data: ', len(traindst)) 60 | 61 | # Model 62 | if args.local_rank == 0: 63 | print('==> Building model..') 64 | torch.cuda.set_device(args.local_rank) 65 | model = model_loader.get_model(args) 66 | if args.model == 'wide_resnet': 67 | projector = Projector(expansion=0) 68 | else: 69 | projector = Projector(expansion=4) 70 | 71 | # Log and saving checkpoint information # 72 | if not os.path.isdir('results') and args.local_rank % ngpus_per_node == 0: 73 | os.mkdir('results') 74 | args.name += (args.train_type + '_' +args.model + '_' + args.dataset) 75 | loginfo = 'results/log_' + args.name + '_' + str(args.seed) 76 | logname = (loginfo+ '.csv') 77 | 78 | if args.local_rank == 0: 79 | print ('Training info...') 80 | print (loginfo) 81 | 82 | # Model upload to GPU # 83 | model.cuda() 84 | projector.cuda() 85 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 86 | model = torch.nn.parallel.DistributedDataParallel( 87 | model, 88 | device_ids=[args.local_rank], 89 | output_device=args.local_rank, 90 | find_unused_parameters=True, 91 | ) 92 | projector = torch.nn.parallel.DistributedDataParallel( 93 | projector, 94 | device_ids=[args.local_rank], 95 | output_device=args.local_rank, 96 | find_unused_parameters=True, 97 | ) 98 | 99 | ngpus_per_node = torch.cuda.device_count() 100 | print(torch.cuda.device_count()) 101 | cudnn.benchmark = True 102 | print('Using CUDA..') 103 | 104 | # Aggregating model parameter & projection parameter # 105 | model_params = [] 106 | model_params += model.parameters() 107 | model_params += projector.parameters() 108 | 109 | # LARS optimizer from KAKAO-BRAIN github 110 | # "pip install torchlars" or git from https://github.com/kakaobrain/torchlars 111 | base_optimizer = optim.SGD(model_params, lr=args.lr, momentum=0.9, weight_decay=args.decay) 112 | optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001) 113 | 114 | # Cosine learning rate annealing (SGDR) & Learning rate warmup # 115 | 
# git from https://github.com/ildoonet/pytorch-gradual-warmup-lr # 116 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch) 117 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=args.lr_multiplier, total_epoch=10, after_scheduler=scheduler_cosine) 118 | 119 | def train(epoch): 120 | print('\nEpoch: %d' % epoch) 121 | 122 | scheduler_warmup.step() 123 | model.train() 124 | projector.train() 125 | train_sampler.set_epoch(epoch) 126 | 127 | train_loss = 0 128 | reg_loss = 0 129 | 130 | for batch_idx, ((inputs_1, inputs_2), targets) in enumerate(trainloader): 131 | inputs_1, inputs_2 = inputs_1.cuda(), inputs_2.cuda() 132 | inputs = torch.cat((inputs_1, inputs_2)) 133 | 134 | outputs = projector(model(inputs)) 135 | 136 | similarity = pairwise_similarity(outputs, temperature=args.temperature) 137 | loss = NT_xent(similarity) 138 | 139 | train_loss += loss.data 140 | 141 | 142 | optimizer.zero_grad() 143 | loss.backward() 144 | optimizer.step() 145 | 146 | progress_bar(batch_idx, len(trainloader), 147 | 'Loss: %.3f | Reg: %.5f' 148 | % (train_loss/(batch_idx+1), reg_loss/(batch_idx+1))) 149 | 150 | return (train_loss/(batch_idx+1), reg_loss/(batch_idx+1)) 151 | 152 | 153 | def test(epoch): 154 | model.eval() 155 | projector.eval() 156 | 157 | test_loss = 0 158 | 159 | # Save at the last epoch # 160 | if epoch == start_epoch + args.epoch - 1 and args.local_rank % ngpus_per_node == 0: 161 | checkpoint(model, test_loss, epoch, args, optimizer) 162 | checkpoint(projector, test_loss, epoch, args, optimizer, save_name_add='_projector') 163 | # Save every 100 epochs # 164 | elif epoch > 1 and epoch % 100 == 0 and args.local_rank % ngpus_per_node == 0: 165 | checkpoint(model, test_loss, epoch, args, optimizer, save_name_add='_epoch_'+str(epoch)) 166 | checkpoint(projector, test_loss, epoch, args, optimizer, save_name_add=('_projector_epoch_' + str(epoch))) 167 | 168 | return test_loss 169 | 170 | 171 | ##### Log file ##### 172 | if args.local_rank % ngpus_per_node == 0: 173 | if os.path.exists(logname): 174 | os.remove(logname) 175 | with open(logname, 'w') as logfile: 176 | logwriter = csv.writer(logfile, delimiter=',') 177 | logwriter.writerow(['epoch', 'train loss', 'reg loss']) 178 | 179 | 180 | ##### Training ##### 181 | for epoch in range(start_epoch, args.epoch): 182 | train_loss, reg_loss = train(epoch) 183 | _ = test(epoch) 184 | 185 | if args.local_rank % ngpus_per_node == 0: 186 | with open(logname, 'a') as logfile: 187 | logwriter = csv.writer(logfile, delimiter=',') 188 | logwriter.writerow([epoch, train_loss.item(), reg_loss]) 189 | 190 | 191 | 
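# Shape sketch (assuming a per-GPU batch of B image pairs): `inputs` is (2B, 3, 32, 32),
# `outputs` is (2B, d) after the projector, `pairwise_similarity` is expected to return a
# (2B x 2B) similarity matrix, and NT_xent reduces it to the scalar contrastive loss.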
-------------------------------------------------------------------------------- /OpenCoS/randaugment.py: -------------------------------------------------------------------------------- 1 | # code in this file is adapted from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | # https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py 6 | import logging 7 | import random 8 | 9 | import numpy as np 10 | import PIL 11 | import PIL.ImageOps 12 | import PIL.ImageEnhance 13 | import PIL.ImageDraw 14 | from PIL import Image 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | PARAMETER_MAX = 10 19 | 20 | 21 | def AutoContrast(img, **kwarg): 22 | return PIL.ImageOps.autocontrast(img) 23 | 24 | 25 | def Brightness(img, v, max_v, bias=0): 26 | v = _float_parameter(v, max_v) + bias 27 | return PIL.ImageEnhance.Brightness(img).enhance(v) 28 | 29 | 30 | def Color(img, v, max_v, bias=0): 31 | v = _float_parameter(v, max_v) + bias 32 | return PIL.ImageEnhance.Color(img).enhance(v) 33 | 34 | 35 | def Contrast(img, v, max_v, bias=0): 36 | v = _float_parameter(v, max_v) + bias 37 | return PIL.ImageEnhance.Contrast(img).enhance(v) 38 | 39 | 40 | def Cutout(img, v, max_v, bias=0): 41 | if v == 0: 42 | return img 43 | v = _float_parameter(v, max_v) + bias 44 | v = int(v * min(img.size)) 45 | return CutoutAbs(img, v) 46 | 47 | 48 | def CutoutAbs(img, v, **kwarg): 49 | w, h = img.size 50 | x0 = np.random.uniform(0, w) 51 | y0 = np.random.uniform(0, h) 52 | x0 = int(max(0, x0 - v / 2.)) 53 | y0 = int(max(0, y0 - v / 2.)) 54 | x1 = int(min(w, x0 + v)) 55 | y1 = int(min(h, y0 + v)) 56 | xy = (x0, y0, x1, y1) 57 | # gray 58 | color = (127, 127, 127) 59 | img = img.copy() 60 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 61 | return img 62 | 63 | 64 | def Equalize(img, **kwarg): 65 | return PIL.ImageOps.equalize(img) 66 | 67 | 68 | def Identity(img, **kwarg): 69 | return img 70 | 71 | 72 | def Invert(img, **kwarg): 73 | return PIL.ImageOps.invert(img) 74 | 75 | 76 | def Posterize(img, v, max_v, bias=0): 77 | v = _int_parameter(v, max_v) + bias 78 | return PIL.ImageOps.posterize(img, v) 79 | 80 | 81 | def Rotate(img, v, max_v, bias=0): 82 | v = _int_parameter(v, max_v) + bias 83 | if random.random() < 0.5: 84 | v = -v 85 | return img.rotate(v) 86 | 87 | 88 | def Sharpness(img, v, max_v, bias=0): 89 | v = _float_parameter(v, max_v) + bias 90 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 91 | 92 | 93 | def ShearX(img, v, max_v, bias=0): 94 | v = _float_parameter(v, max_v) + bias 95 | if random.random() < 0.5: 96 | v = -v 97 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 98 | 99 | 100 | def ShearY(img, v, max_v, bias=0): 101 | v = _float_parameter(v, max_v) + bias 102 | if random.random() < 0.5: 103 | v = -v 104 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 105 | 106 | 107 | def Solarize(img, v, max_v, bias=0): 108 | v = _int_parameter(v, max_v) + bias 109 | return PIL.ImageOps.solarize(img, 256 - v) 110 | 111 | 112 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 113 | v = _int_parameter(v, max_v) + bias 114 | if random.random() < 0.5: 115 | v = -v 116 | img_np = np.array(img).astype(int) 117 | img_np = img_np + v 118 | img_np = np.clip(img_np, 0, 255) 119 | img_np = img_np.astype(np.uint8) 120 | img = Image.fromarray(img_np) 121 | return PIL.ImageOps.solarize(img, threshold) 122 | 123 | 124 | def TranslateX(img, v, max_v, bias=0): 125 | v = _float_parameter(v, max_v) + bias 126 | if random.random() < 0.5: 127 | v = -v 128 | v = int(v * img.size[0]) 129 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 130 | 131 | 132 | def TranslateY(img, v, max_v, bias=0): 133 | v = _float_parameter(v, max_v) + bias 134 | if random.random() < 0.5: 135 | v = -v 136 | v = int(v * img.size[1]) 137 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 138 | 139 | 140 | def _float_parameter(v, max_v): 141 | return float(v) * max_v / PARAMETER_MAX 142 | 143 | 144 | def _int_parameter(v, max_v): 145 | return int(v * max_v / PARAMETER_MAX) 146 | 147 | 148 | def fixmatch_augment_pool(): 149 | # FixMatch paper 150 | augs = [(AutoContrast, None, None), 151 | (Brightness, 0.9, 0.05),
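 # each entry is (op, max_v, bias): a magnitude v in [1, PARAMETER_MAX] is rescaled to roughly v * max_v / PARAMETER_MAX + bias by the helpers above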
152 | (Color, 0.9, 0.05), 153 | (Contrast, 0.9, 0.05), 154 | (Equalize, None, None), 155 | (Identity, None, None), 156 | (Posterize, 4, 4), 157 | (Rotate, 30, 0), 158 | (Sharpness, 0.9, 0.05), 159 | (ShearX, 0.3, 0), 160 | (ShearY, 0.3, 0), 161 | (Solarize, 256, 0), 162 | (TranslateX, 0.3, 0), 163 | (TranslateY, 0.3, 0)] 164 | return augs 165 | 166 | 167 | def my_augment_pool(): 168 | # Test 169 | augs = [(AutoContrast, None, None), 170 | (Brightness, 1.8, 0.1), 171 | (Color, 1.8, 0.1), 172 | (Contrast, 1.8, 0.1), 173 | (Cutout, 0.2, 0), 174 | (Equalize, None, None), 175 | (Invert, None, None), 176 | (Posterize, 4, 4), 177 | (Rotate, 30, 0), 178 | (Sharpness, 1.8, 0.1), 179 | (ShearX, 0.3, 0), 180 | (ShearY, 0.3, 0), 181 | (Solarize, 256, 0), 182 | (SolarizeAdd, 110, 0), 183 | (TranslateX, 0.45, 0), 184 | (TranslateY, 0.45, 0)] 185 | return augs 186 | 187 | 188 | class RandAugmentPC(object): 189 | def __init__(self, n, m): 190 | assert n >= 1 191 | assert 1 <= m <= 10 192 | self.n = n 193 | self.m = m 194 | self.augment_pool = my_augment_pool() 195 | 196 | def __call__(self, img): 197 | ops = random.choices(self.augment_pool, k=self.n) 198 | for op, max_v, bias in ops: 199 | prob = np.random.uniform(0.2, 0.8) 200 | if random.random() + prob >= 1: 201 | img = op(img, v=self.m, max_v=max_v, bias=bias) 202 | img = CutoutAbs(img, 16) 203 | return img 204 | 205 | 206 | class RandAugmentMC(object): 207 | def __init__(self, n, m): 208 | assert n >= 1 209 | assert 1 <= m <= 10 210 | self.n = n 211 | self.m = m 212 | self.augment_pool = fixmatch_augment_pool() 213 | 214 | def __call__(self, img): 215 | ops = random.choices(self.augment_pool, k=self.n) 216 | for op, max_v, bias in ops: 217 | v = np.random.randint(1, self.m) 218 | if random.random() < 0.5: 219 | img = op(img, v=v, max_v=max_v, bias=bias) 220 | img = CutoutAbs(img, 16) 221 | return img 222 | -------------------------------------------------------------------------------- /SimCLR/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision import transforms 4 | import torchvision.datasets as datasets 5 | from torch.utils.data.distributed import DistributedSampler 6 | from collections import defaultdict 7 | import math 8 | import random 9 | import numpy as np 10 | 11 | # Setup Augmentations 12 | 13 | def get_dataset(args): 14 | 15 | ### color augmentation ### 16 | color_jitter = transforms.ColorJitter(0.8*args.color_jitter_strength, 0.8*args.color_jitter_strength, 0.8*args.color_jitter_strength, 0.2*args.color_jitter_strength) 17 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 18 | rnd_gray = transforms.RandomGrayscale(p=0.2) 19 | 20 | if 'contrastive' in args.train_type: 21 | contrastive_learning = True 22 | else: 23 | contrastive_learning = False 24 | 25 | if 'linear_eval' in args.train_type or 'multinomial' in args.train_type: 26 | linear_eval = True 27 | else: 28 | linear_eval = False 29 | 30 | if contrastive_learning: 31 | transform_train = transforms.Compose([ 32 | rnd_color_jitter, 33 | rnd_gray, 34 | transforms.RandomHorizontalFlip(), 35 | transforms.RandomResizedCrop(32), 36 | transforms.ToTensor(), 37 | ]) 38 | 39 | transform_test = transform_train 40 | 41 | elif linear_eval: 42 | transform_train = transforms.Compose([ 43 | transforms.Resize(32), 44 | transforms.ToTensor(), 45 | ]) 46 | 47 | transform_test = transform_train 48 | 49 | else: 50 | transform_train = transforms.Compose([ 51 | transforms.RandomCrop(32, 
padding=4), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.ToTensor(), 54 | ]) 55 | 56 | transform_test = transforms.Compose([ 57 | transforms.Resize(32), 58 | transforms.ToTensor(), 59 | ]) 60 | 61 | class TransformTwice: 62 | def __init__(self, transform): 63 | self.transform = transform 64 | 65 | def __call__(self, inp): 66 | out1 = self.transform(inp) 67 | out2 = self.transform(inp) 68 | return out1, out2 69 | 70 | if args.dataset == 'cifar-10': 71 | 72 | if contrastive_learning: 73 | train_dst = datasets.CIFAR10(root=args.dataroot, train=True, download=True,transform=TransformTwice(transform_train)) 74 | else: 75 | train_dst = datasets.CIFAR10(root=args.dataroot, train=True, download=True,transform=(transform_train)) 76 | val_dst = datasets.CIFAR10(root=args.dataroot, train=False, download=True,transform=transform_test) 77 | 78 | if args.ooc_data == 'svhn': 79 | ooc_dst = datasets.SVHN(root=args.dataroot, split='train', download=True, transform=TransformTwice(transform_train)) 80 | ooc_index = np.load(os.path.join('../OpenCoS/splits', 'svhn_unlabel_train_idx.npy')).astype(np.int64) 81 | elif args.ooc_data == 'tiny': 82 | ooc_dst = datasets.ImageFolder(root=args.tinyroot, transform=TransformTwice(transform_train)) 83 | ooc_index = np.load(os.path.join('../OpenCoS/splits', 'tiny_unlabel_train_idx.npy')).astype(np.int64) 84 | 85 | if args.ooc_data in ['svhn', 'tiny']: 86 | cifar_label = np.load(os.path.join('../OpenCoS/splits', 'cifar10_400pc_label_idx.npy')).astype(np.int64) 87 | cifar_unlabel = np.load(os.path.join('../OpenCoS/splits', 'cifar10_unlabel_train_idx.npy')).astype(np.int64) 88 | cifar_unlabel = cifar_unlabel[:50000 - len(cifar_label) - 40000] 89 | 90 | cifar_index = np.concatenate((cifar_label, cifar_unlabel)) 91 | train_dst = torch.utils.data.Subset(train_dst, cifar_index) 92 | 93 | ooc_index = ooc_index[:40000] 94 | ooc_dst = torch.utils.data.Subset(ooc_dst, ooc_index) 95 | 96 | train_dst = torch.utils.data.ConcatDataset([train_dst, ooc_dst]) 97 | 98 | if args.dataset == 'cifar-100': 99 | 100 | if contrastive_learning: 101 | train_dst = datasets.CIFAR100(root=args.dataroot, train=True, download=True,transform=TransformTwice(transform_train)) 102 | else: 103 | train_dst = datasets.CIFAR100(root=args.dataroot, train=True, download=True,transform=(transform_train)) 104 | val_dst = datasets.CIFAR100(root=args.dataroot, train=False, download=True,transform=transform_test) 105 | 106 | if args.ooc_data == 'svhn': 107 | ooc_dst = datasets.SVHN(root=args.dataroot, split='train', download=True, transform=TransformTwice(transform_train)) 108 | ooc_index = np.load(os.path.join('../OpenCoS/splits', 'svhn_unlabel_train_idx.npy')).astype(np.int64) 109 | elif args.ooc_data == 'tiny': 110 | ooc_dst = datasets.ImageFolder(root=args.tinyroot, transform=TransformTwice(transform_train)) 111 | ooc_index = np.load(os.path.join('../OpenCoS/splits', 'tiny_unlabel_train_idx.npy')).astype(np.int64) 112 | 113 | if args.ooc_data in ['svhn', 'tiny']: 114 | cifar_label = np.load(os.path.join('../OpenCoS/splits', 'cifar100_100pc_label_idx.npy')).astype(np.int64) 115 | cifar_unlabel = np.load(os.path.join('../OpenCoS/splits', 'cifar100_unlabel_train_idx.npy')).astype(np.int64) 116 | cifar_unlabel = cifar_unlabel[:50000 - len(cifar_label) - 40000] 117 | 118 | cifar_index = np.concatenate((cifar_label, cifar_unlabel)) 119 | train_dst = torch.utils.data.Subset(train_dst, cifar_index) 120 | 121 | ooc_index = ooc_index[:40000] 122 | ooc_dst = torch.utils.data.Subset(ooc_dst, ooc_index) 123 | 124 | 
train_dst = torch.utils.data.ConcatDataset([train_dst, ooc_dst]) # 10,000 in-class CIFAR images + 40,000 out-of-class images 125 | 126 | if contrastive_learning: 127 | train_sampler = torch.utils.data.distributed.DistributedSampler( 128 | train_dst, 129 | num_replicas=args.ngpu, 130 | rank=args.local_rank, 131 | ) 132 | train_loader = torch.utils.data.DataLoader(train_dst, batch_size=args.batch_size, num_workers=4, 133 | pin_memory=True, 134 | shuffle=(train_sampler is None), 135 | sampler=train_sampler, 136 | ) 137 | 138 | val_loader = torch.utils.data.DataLoader(val_dst, batch_size=100, num_workers=4, 139 | pin_memory=True, 140 | shuffle=False, 141 | ) 142 | 143 | return train_loader, train_dst, val_loader, val_dst, train_sampler 144 | else: 145 | train_loader = torch.utils.data.DataLoader(train_dst, 146 | batch_size=args.batch_size, 147 | shuffle=True, num_workers=4) 148 | 149 | val_loader = torch.utils.data.DataLoader(val_dst, batch_size=100, 150 | shuffle=False, num_workers=4) 151 | 152 | return train_loader, train_dst, val_loader, val_dst 153 | -------------------------------------------------------------------------------- /SimCLR/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar that mimics xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | import torch.nn.functional as F 15 | 16 | import numpy as np 17 | 18 | from itertools import chain 19 | 20 | 21 | def get_mean_and_std(dataset): 22 | '''Compute the mean and std value of dataset.''' 23 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 24 | mean = torch.zeros(3) 25 | std = torch.zeros(3) 26 | print('==> Computing mean and std..') 27 | for inputs, targets in dataloader: 28 | for i in range(3): 29 | mean[i] += inputs[:,i,:,:].mean() 30 | std[i] += inputs[:,i,:,:].std() 31 | mean.div_(len(dataset)) 32 | std.div_(len(dataset)) 33 | return mean, std 34 | 35 | def init_params(net): 36 | '''Init layer parameters.''' 37 | for m in net.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | init.kaiming_normal_(m.weight, mode='fan_out') 40 | if m.bias is not None: 41 | init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | init.constant_(m.weight, 1) 44 | init.constant_(m.bias, 0) 45 | elif isinstance(m, nn.Linear): 46 | init.normal_(m.weight, std=1e-3) 47 | if m.bias is not None: 48 | init.constant_(m.bias, 0) 49 | 50 | 51 | _, term_width = os.popen('stty size', 'r').read().split() # requires a real terminal; fails if stdout is redirected 52 | term_width = int(term_width) 53 | 54 | TOTAL_BAR_LENGTH = 86. 55 | last_time = time.time() 56 | begin_time = last_time 57 | def progress_bar(current, total, msg=None): 58 | global last_time, begin_time 59 | if current == 0: 60 | begin_time = time.time() # Reset for new bar.
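# cur_len/rest_len below split the bar into completed '=' and pending '.' segments.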
61 | 62 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 63 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 64 | 65 | sys.stdout.write(' [') 66 | for i in range(cur_len): 67 | sys.stdout.write('=') 68 | sys.stdout.write('>') 69 | for i in range(rest_len): 70 | sys.stdout.write('.') 71 | sys.stdout.write(']') 72 | 73 | cur_time = time.time() 74 | step_time = cur_time - last_time 75 | last_time = cur_time 76 | tot_time = cur_time - begin_time 77 | 78 | L = [] 79 | L.append(' Step: %s' % format_time(step_time)) 80 | L.append(' | Tot: %s' % format_time(tot_time)) 81 | if msg: 82 | L.append(' | ' + msg) 83 | 84 | msg = ''.join(L) 85 | sys.stdout.write(msg) 86 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 87 | sys.stdout.write(' ') 88 | 89 | # Go back to the center of the bar. 90 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 91 | sys.stdout.write('\b') 92 | sys.stdout.write(' %d/%d ' % (current+1, total)) 93 | 94 | if current < total-1: 95 | sys.stdout.write('\r') 96 | else: 97 | sys.stdout.write('\n') 98 | sys.stdout.flush() 99 | 100 | def format_time(seconds): 101 | days = int(seconds / 3600/24) 102 | seconds = seconds - days*3600*24 103 | hours = int(seconds / 3600) 104 | seconds = seconds - hours*3600 105 | minutes = int(seconds / 60) 106 | seconds = seconds - minutes*60 107 | secondsf = int(seconds) 108 | seconds = seconds - secondsf 109 | millis = int(seconds*1000) 110 | 111 | f = '' 112 | i = 1 113 | if days > 0: 114 | f += str(days) + 'D' 115 | i += 1 116 | if hours > 0 and i <= 2: 117 | f += str(hours) + 'h' 118 | i += 1 119 | if minutes > 0 and i <= 2: 120 | f += str(minutes) + 'm' 121 | i += 1 122 | if secondsf > 0 and i <= 2: 123 | f += str(secondsf) + 's' 124 | i += 1 125 | if millis > 0 and i <= 2: 126 | f += str(millis) + 'ms' 127 | i += 1 128 | if f == '': 129 | f = '0ms' 130 | return f 131 | 132 | def checkpoint(model, acc, epoch, args, optimizer, save_name_add=''): 133 | # Save checkpoint. 134 | print('Saving..') 135 | state = { 136 | 'epoch': epoch, 137 | 'acc': acc, 138 | 'model': model.state_dict(), 139 | 'optimizer_state': optimizer.state_dict(), 140 | 'rng_state': torch.get_rng_state() 141 | } 142 | 143 | save_name = './checkpoint/ckpt.t7' + args.name + '_' + str(args.seed) 144 | save_name += save_name_add 145 | 146 | if not os.path.isdir('checkpoint'): 147 | os.mkdir('checkpoint') 148 | torch.save(state, save_name) 149 | 150 | def learning_rate_warmup(optimizer, epoch, args): 151 | """Learning rate warmup for the first 10 epochs""" 152 | 153 | lr = args.lr 154 | lr /= 10 155 | lr *= (epoch+1) 156 | 157 | for param_group in optimizer.param_groups: 158 | param_group['lr'] = lr 159 | 160 | def adjust_learning_rate(optimizer, epoch, args): 161 | """Decrease the learning rate at 50% and 75% of the total epochs""" 162 | 163 | lr = args.lr 164 | if epoch >= args.epoch * 0.5: 165 | lr /= 10 166 | if epoch >= args.epoch * 0.75: 167 | lr /= 10 168 | 169 | for param_group in optimizer.param_groups: 170 | param_group['lr'] = lr 171 | 172 | 173 | def one_hot(ids, n_class): 174 | # --------------------- 175 | # author:ke1th 176 | # source:CSDN 177 | # article:https://blog.csdn.net/u012436149/article/details/77017832 178 | """ 179 | ids: (list, ndarray) shape:[batch_size] 180 | out_tensor:FloatTensor shape:[batch_size, depth] 181 | """ 182 | 183 | assert len(ids.shape) == 1, 'the ids should be 1-D' 184 | # ids = torch.LongTensor(ids).view(-1,1) 185 | 186 | out_tensor = torch.zeros(len(ids), n_class) 187 | 188 | out_tensor.scatter_(1, ids.cpu().unsqueeze(1), 1.)
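 # scatter_ writes 1. into column ids[i] of each row i, yielding a one-hot matrix of shape [batch_size, n_class]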
189 | 190 | return out_tensor 191 | 192 | class LabelDict: 193 | def __init__(self, dataset='cifar-10'): 194 | self.dataset = dataset 195 | if dataset == 'cifar-10': 196 | self.label_dict = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 197 | 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 198 | 8: 'ship', 9: 'truck'} 199 | 200 | self.class_dict = {v: k for k, v in self.label_dict.items()} 201 | 202 | def label2class(self, label): 203 | assert label in self.label_dict, 'the label %d is not in %s' % (label, self.dataset) 204 | return self.label_dict[label] 205 | 206 | def class2label(self, _class): 207 | assert isinstance(_class, str) 208 | assert _class in self.class_dict, 'the class %s is not in %s' % (_class, self.dataset) 209 | return self.class_dict[_class] 210 | 211 | def get_highest_incorrect_predict(outputs, targets): 212 | _, sorted_prediction = torch.topk(outputs.data, k=2, dim=1) 213 | 214 | ### if the top-1 prediction is correct, take the runner-up; otherwise take the top-1 ### 215 | is_correct = (sorted_prediction[:, 0] == targets).type(torch.cuda.LongTensor) 216 | highest_incorrect_predict = (is_correct * sorted_prediction[:, 1] + (1 - is_correct) * sorted_prediction[:, 0]).detach() 217 | 218 | return highest_incorrect_predict 219 | 220 | -------------------------------------------------------------------------------- /OpenCoS/models/cifar_resnet_auxbn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn.functional as F 4 | from .aux_batchnorm import BatchNorm2d 5 | 6 | __all__ = ['CIFAR_ResNet_AuxBN', 'CIFAR_ResNet18_AuxBN', 'CIFAR_ResNet34_AuxBN', 'CIFAR_ResNet10_AuxBN', 'CIFAR_ResNet50_AuxBN'] 7 | 8 | class mySequential(nn.Sequential): # nn.Sequential variant that threads the extra aux flag through submodules 9 | def forward(self, *inputs): 10 | for module in self._modules.values(): 11 | if isinstance(module, nn.Conv2d): 12 | if type(inputs) == tuple: 13 | inputs = (module(inputs[0]),) + inputs[1:] 14 | else: 15 | inputs = module(inputs) 16 | else: 17 | if type(inputs) == tuple: 18 | inputs = (module(*inputs),) + inputs[1:] 19 | else: 20 | inputs = module(inputs) 21 | 22 | if type(inputs) == tuple: 23 | return inputs[0] 24 | else: 25 | return inputs 26 | 27 | 28 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=1, groups=groups, bias=False) 32 | 33 | class PreActBlock(nn.Module): 34 | '''Pre-activation version of the BasicBlock.''' 35 | expansion = 1 36 | 37 | def __init__(self, in_planes, planes, stride=1, divide=False): 38 | super(PreActBlock, self).__init__() 39 | self.bn1 = BatchNorm2d(in_planes, divide=divide) 40 | self.conv1 = conv3x3(in_planes, planes, stride) 41 | self.bn2 = BatchNorm2d(planes, divide=divide) 42 | self.conv2 = conv3x3(planes, planes) 43 | 44 | self.shortcut = nn.Sequential() 45 | if stride != 1 or in_planes != self.expansion*planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 48 | ) 49 | 50 | def forward(self, x, aux=False): 51 | out = F.relu(self.bn1(x, aux)) 52 | shortcut = self.shortcut(out) 53 | out = self.conv1(out) 54 | out = self.conv2(F.relu(self.bn2(out, aux))) 55 | out += shortcut 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, in_planes, planes, stride=1, divide=False): 63 | super(Bottleneck, self).__init__() 64 | 
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 65 | self.bn1 = BatchNorm2d(planes, divide=divide) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 67 | self.bn2 = BatchNorm2d(planes, divide=divide) 68 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 69 | self.bn3 = BatchNorm2d(self.expansion*planes, divide=divide) 70 | 71 | self.shortcut = mySequential() 72 | if stride != 1 or in_planes != self.expansion*planes: 73 | self.shortcut = mySequential( 74 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 75 | BatchNorm2d(self.expansion*planes, divide=divide) 76 | ) 77 | 78 | def forward(self, x, aux=False): 79 | out = F.relu(self.bn1(self.conv1(x),aux)) 80 | out = F.relu(self.bn2(self.conv2(out),aux)) 81 | out = self.bn3(self.conv3(out),aux) 82 | out += self.shortcut(x,aux) 83 | out = F.relu(out) 84 | return out 85 | 86 | 87 | 88 | 89 | class CIFAR_ResNet_AuxBN(nn.Module): 90 | def __init__(self, block, num_blocks, num_classes=10, bias=True, divide=False): 91 | super(CIFAR_ResNet_AuxBN, self).__init__() 92 | self.in_planes = 64 93 | self.conv1 = conv3x3(3,64) 94 | self.bn1 = BatchNorm2d(64, divide=divide) 95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, divide=divide) 96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, divide=divide) 97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, divide=divide) 98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, divide=divide) 99 | self.linear = nn.Linear(512*block.expansion, num_classes, bias=bias) 100 | self.linear_rot = nn.Linear(512*block.expansion, 4, bias=bias) 101 | 102 | 103 | def _make_layer(self, block, planes, num_blocks, stride, divide): 104 | strides = [stride] + [1]*(num_blocks-1) 105 | layers = [] 106 | for stride in strides: 107 | layers.append(block(self.in_planes, planes, stride, divide)) 108 | self.in_planes = planes * block.expansion 109 | return mySequential(*layers) 110 | 111 | def forward(self, x, feature=False, aux=False): 112 | out = x 113 | out = self.conv1(out) 114 | out = self.bn1(out, aux) 115 | out = F.relu(out) 116 | out1 = self.layer1(out, aux) 117 | out2 = self.layer2(out1, aux) 118 | out3 = self.layer3(out2, aux) 119 | out = self.layer4(out3, aux) 120 | out = F.avg_pool2d(out, 4) 121 | out4 = out.view(out.size(0), -1) 122 | out = self.linear(out4) 123 | if feature: 124 | return out, out4 125 | else: 126 | return out 127 | 128 | def feature(self, x, aux=False): 129 | out = x 130 | out = self.conv1(out) 131 | out = self.bn1(out, aux) 132 | out = F.relu(out) 133 | out1 = self.layer1(out, aux) 134 | out2 = self.layer2(out1, aux) 135 | out3 = self.layer3(out2, aux) 136 | out = self.layer4(out3, aux) 137 | out = F.avg_pool2d(out, 4) 138 | out4 = out.view(out.size(0), -1) 139 | return out4 140 | 141 | def rot(self, x, aux=False): 142 | out = x 143 | out = self.conv1(out) 144 | out = self.bn1(out, aux) 145 | out = F.relu(out) 146 | out1 = self.layer1(out, aux) 147 | out2 = self.layer2(out1, aux) 148 | out3 = self.layer3(out2, aux) 149 | out = self.layer4(out3, aux) 150 | out = F.avg_pool2d(out, 4) 151 | out4 = out.view(out.size(0), -1) 152 | out_rot = self.linear_rot(out4) 153 | return out_rot 154 | 155 | 156 | def forward_rot(self, x, aux=False): 157 | out = x 158 | out = self.conv1(out) 159 | out = self.bn1(out, aux) 160 | out = F.relu(out) 161 | out1 = self.layer1(out, aux) 162 | out2 = self.layer2(out1, 
aux) 163 | out3 = self.layer3(out2, aux) 164 | out = self.layer4(out3, aux) 165 | out = F.avg_pool2d(out, 4) 166 | out4 = out.view(out.size(0), -1) 167 | out = self.linear(out4) 168 | out_rot = self.linear_rot(out4) 169 | return out, out_rot 170 | 171 | 172 | def CIFAR_ResNet10_AuxBN(pretrained=False, **kwargs): 173 | return CIFAR_ResNet_AuxBN(PreActBlock, [1,1,1,1], **kwargs) 174 | 175 | def CIFAR_ResNet18_AuxBN(pretrained=False, **kwargs): 176 | return CIFAR_ResNet_AuxBN(PreActBlock, [2,2,2,2], **kwargs) 177 | 178 | def CIFAR_ResNet34_AuxBN(pretrained=False, **kwargs): 179 | return CIFAR_ResNet_AuxBN(PreActBlock, [3,4,6,3], **kwargs) 180 | 181 | def CIFAR_ResNet50_AuxBN(pretrained=False, **kwargs): 182 | return CIFAR_ResNet_AuxBN(Bottleneck, [3,4,6,3], **kwargs) 183 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenCoS: Contrastive Semi-supervised Learning for Handling Open-set Unlabeled Data 2 | The code is compatible with CUDA 10.1 and Python 3.6. See requirements.txt for all prerequisites; you can install them with the following command. 3 | ``` 4 | pip install -r requirements.txt 5 | ``` 6 | 7 | ## Overview 8 | * Stage 1. Unsupervised pre-training (--nproc_per_node=8, --ngpu 8: number of GPUs; --dataset: cifar-10, cifar-100; --ooc_data: None, svhn, tiny; --model: ResNet50) 9 | * Stage 2. OpenCoS + ReMixMatch (--sgpu: GPU id; --dataset: animal, cifar10, cifar100; --pc: 4, 25; --udata: cten, svhn, tiny; --model: CIFAR_ResNet50_AuxBN; --model_path: pre-trained SimCLR model) 10 | 11 | - --ood_samples: number of out-of-class (OOC) unlabeled samples (we use 40,000 out-of-class and 10,000 in-class samples) 12 | - --model_path: pre-trained model directory of Stage 1.
(default: code/SimCLR/checkpoint folder) 13 | - --dataroot: CIFAR-10, CIFAR-100, SVHN datasets directory (default: /data folder) 14 | - --tinyroot: TinyImageNet dataset directory (default: /data/tinyimagenet/tiny-imagenet-200 folder) 15 | 16 | ## Running scripts (OpenCoS + ReMixMatch) 17 | ### CIFAR-Animals + CIFAR-Others benchmark 18 | ``` 19 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-10 --model ResNet50 --batch-size 128 --name c10_U0 --ngpu 8 --ooc_data None 20 | python3 train_opencos_remixmatch.py --sgpu 0 --dataset animal --ema --model CIFAR_ResNet50_AuxBN --name opencos_remixmatch_Uothers_4pc --udata cten --pc 4 --naug 1 --batch-size 64 -ft --num_iters 50000 --model_path ../SimCLR/checkpoint/ckpt.t7c10_U0contrastive_learning_ResNet50_cifar-10_0 --ood_samples 40000 --lr 0.03 --fix_optim --lmd_pre 0 --lmd_rot 0 --lmd_unif 0.5 --aux_divide --ths 2 --use_jitter --temp_s2 0.1 --top_ratio 0.1 21 | ``` 22 | 23 | ### CIFAR-10 + SVHN benchmark 24 | ``` 25 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-10 --model ResNet50 --batch-size 128 --name c10_Usvhn40000 --ngpu 8 --ooc_data svhn 26 | python3 train_opencos_remixmatch.py --sgpu 0 --dataset cifar10 --ema --model CIFAR_ResNet50_AuxBN --name opencos_remixmatch_Usvhn_4pc --udata svhn --pc 4 --naug 1 --batch-size 64 -ft --num_iters 50000 --model_path ../SimCLR/checkpoint/ckpt.t7c10_Usvhn40000contrastive_learning_ResNet50_cifar-10_0 --ood_samples 40000 --lr 0.03 --fix_optim --lmd_pre 0 --lmd_rot 0 --lmd_unif 0.5 --aux_divide --ths 2 --use_jitter --temp_s2 0.1 --top_ratio 0.1 27 | ``` 28 | 29 | ### CIFAR-10 + TinyImageNet benchmark 30 | ``` 31 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-10 --model ResNet50 --batch-size 128 --name c10_Utiny40000 --ngpu 8 --ooc_data tiny 32 | python3 train_opencos_remixmatch.py --sgpu 0 --dataset cifar10 --ema --model CIFAR_ResNet50_AuxBN --name opencos_remixmatch_Utiny_4pc --udata tiny --pc 4 --naug 1 --batch-size 64 -ft --num_iters 50000 --model_path ../SimCLR/checkpoint/ckpt.t7c10_Utiny40000contrastive_learning_ResNet50_cifar-10_0 --ood_samples 40000 --lr 0.03 --fix_optim --lmd_pre 0 --lmd_rot 0 --lmd_unif 0.5 --aux_divide --ths 2 --use_jitter --temp_s2 0.1 --top_ratio 0.1 33 | ``` 34 | 35 | ### CIFAR-100 + SVHN benchmark 36 | ``` 37 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-100 --model ResNet50 --batch-size 128 --name c100_Usvhn40000 --ngpu 8 --ooc_data svhn 38 | python3 train_opencos_remixmatch.py --sgpu 0 --dataset cifar100 --ema --model CIFAR_ResNet50_AuxBN --name opencos_remixmatch_Usvhn_4pc --udata svhn --pc 4 --naug 1 --batch-size 64 -ft --num_iters 50000 --model_path ../SimCLR/checkpoint/ckpt.t7c100_Usvhn40000contrastive_learning_ResNet50_cifar-100_0 --ood_samples 40000 --lr 0.03 --fix_optim --lmd_pre 0 --lmd_rot 0 --lmd_unif 0.5 --aux_divide --ths 2 --use_jitter --temp_s2 0.1 --top_ratio 0.1 39 | ``` 40 | 41 | ### CIFAR-100 + TinyImageNet benchmark 42 | ``` 43 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-100 --model ResNet50 --batch-size 128 --name c100_Utiny40000 --ngpu 8 --ooc_data tiny 44 | python3 train_opencos_remixmatch.py --sgpu 0 --dataset cifar100 --ema --model CIFAR_ResNet50_AuxBN --name opencos_remixmatch_Utiny_4pc --udata tiny --pc 4 --naug 1 --batch-size 64 -ft --num_iters 50000 --model_path 
../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_ResNet50_cifar-100_0 --ood_samples 40000 --lr 0.03 --fix_optim --lmd_pre 0 --lmd_rot 0 --lmd_unif 0.5 --aux_divide --ths 2 --use_jitter --temp_s2 0.1 --top_ratio 0.1 45 | ``` 46 | 47 | ## Running scripts for baseline methods (CIFAR-100 + TinyImageNet benchmark) 48 | ### Pre-training (wide_resnet / CIFAR_ResNet50) 49 | ``` 50 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-100 --model wide_resnet --batch-size 128 --name c100_Utiny40000 --ngpu 8 --ooc_data tiny 51 | python3 -m torch.distributed.launch --nproc_per_node=8 train_contrastive.py --dataset cifar-100 --model ResNet50 --batch-size 128 --name c100_Utiny40000 --ngpu 8 --ooc_data tiny 52 | ``` 53 | 54 | ### SimCLR-le (wide_resnet / CIFAR_ResNet50) 55 | ``` 56 | python3 train.py --sgpu 0 --dataset cifar100 --multinomial --model wide_resnet --name multinomial_4pc --udata tiny --pc 4 -ft --batch-size 128 --ood_samples 40000 --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_wide_resnet_cifar-100_0 57 | python3 train.py --sgpu 0 --dataset cifar100 --multinomial --model CIFAR_ResNet50 --name multinomial_4pc --udata tiny --pc 4 -ft --batch-size 128 --ood_samples 40000 --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_ResNet50_cifar-100_0 58 | ``` 59 | 60 | ### SimCLR-ft (wide_resnet / CIFAR_ResNet50) 61 | ``` 62 | python3 train.py --sgpu 0 --dataset cifar100 --model wide_resnet --name finetune_4pc --udata tiny --pc 4 -ft --batch-size 128 --ood_samples 40000 --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_wide_resnet_cifar-100_0 --lr 0.03 --fix_optim --num_iters 50000 63 | python3 train.py --sgpu 0 --dataset cifar100 --model CIFAR_ResNet50 --name finetune_4pc --udata tiny --pc 4 -ft --batch-size 128 --ood_samples 40000 --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_ResNet50_cifar-100_0 --lr 0.03 --fix_optim --num_iters 50000 64 | ``` 65 | 66 | ### ReMixMatch-ft (wide_resnet / CIFAR_ResNet50) 67 | ``` 68 | python3 train_remixmatch.py --sgpu 0 --dataset cifar100 --ema --model wide_resnet --name remixmatch_4pc --udata tiny --pc 4 --naug 1 --batch-size 64 --num_iters 50000 --ood_samples 40000 --lr 0.03 --fix_optim --use_jitter --no_rampup --lmd_pre 0 --lmd_rot 0 -ft --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_wide_resnet_cifar-100_0 69 | python3 train_remixmatch.py --sgpu 0 --dataset cifar100 --ema --model CIFAR_ResNet50 --name remixmatch_4pc --udata tiny --pc 4 --naug 1 --batch-size 64 --num_iters 50000 --ood_samples 40000 --lr 0.03 --fix_optim --use_jitter --no_rampup --lmd_pre 0 --lmd_rot 0 -ft --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_ResNet50_cifar-100_0 70 | ``` 71 | 72 | ### FixMatch-ft (wide_resnet / CIFAR_ResNet50) 73 | ``` 74 | python3 train_fixmatch.py --sgpu 0 --dataset cifar100 --ema --model wide_resnet --name fixmatch_4pc --udata tiny --pc 4 --mu 1 --batch-size 64 --num_iters 50000 --ood_samples 40000 --lr 0.03 --fix_optim --use_jitter -ft --model_path ../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_wide_resnet_cifar-100_0 75 | python3 train_fixmatch.py --sgpu 0 --dataset cifar100 --ema --model CIFAR_ResNet50 --name fixmatch_4pc --udata tiny --pc 4 --mu 1 --batch-size 64 --num_iters 50000 --ood_samples 40000 --lr 0.03 --fix_optim --use_jitter -ft --model_path 
../SimCLR/checkpoint/ckpt.t7c100_Utiny40000contrastive_learning_ResNet50_cifar-100_0 76 | ``` 77 | -------------------------------------------------------------------------------- /SimCLR/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | BasicBlock and Bottleneck module is from the original ResNet paper: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | PreActBlock and PreActBottleneck module is from the later paper: 8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from torch.autograd import Variable 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(in_planes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion*planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 36 | nn.BatchNorm2d(self.expansion*planes) 37 | ) 38 | 39 | def forward(self, x): 40 | out = F.relu(self.bn1(self.conv1(x))) 41 | out = self.bn2(self.conv2(out)) 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class PreActBlock(nn.Module): 48 | '''Pre-activation version of the BasicBlock.''' 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1): 52 | super(PreActBlock, self).__init__() 53 | self.bn1 = nn.BatchNorm2d(in_planes) 54 | self.conv1 = conv3x3(in_planes, planes, stride) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv2 = conv3x3(planes, planes) 57 | 58 | self.shortcut = nn.Sequential() 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | out += shortcut 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, in_planes, planes, stride=1): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 84 | 85 | self.shortcut = nn.Sequential() 86 | if stride != 1 or in_planes != self.expansion*planes: 87 | self.shortcut = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(self.expansion*planes) 90 | ) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = 
F.relu(self.bn2(self.conv2(out))) 95 | out = self.bn3(self.conv3(out)) 96 | out += self.shortcut(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | 101 | class PreActBottleneck(nn.Module): 102 | '''Pre-activation version of the original Bottleneck module.''' 103 | expansion = 4 104 | 105 | def __init__(self, in_planes, planes, stride=1): 106 | super(PreActBottleneck, self).__init__() 107 | self.bn1 = nn.BatchNorm2d(in_planes) 108 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(planes) 110 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes) 112 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 113 | 114 | self.shortcut = nn.Sequential() 115 | if stride != 1 or in_planes != self.expansion*planes: 116 | self.shortcut = nn.Sequential( 117 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 118 | ) 119 | 120 | def forward(self, x): 121 | out = F.relu(self.bn1(x)) 122 | shortcut = self.shortcut(out) 123 | out = self.conv1(out) 124 | out = self.conv2(F.relu(self.bn2(out))) 125 | out = self.conv3(F.relu(self.bn3(out))) 126 | out += shortcut 127 | return out 128 | 129 | 130 | class ResNet(nn.Module): 131 | def __init__(self, block, num_blocks, num_classes=10, contranstive_learning=False): 132 | super(ResNet, self).__init__() 133 | self.in_planes = 64 134 | 135 | self.conv1 = conv3x3(3,64) 136 | self.bn1 = nn.BatchNorm2d(64) 137 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 138 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 139 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 140 | 141 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 142 | 143 | self.contranstive_learning = contranstive_learning 144 | 145 | if not contranstive_learning: 146 | self.linear = nn.Linear(512*block.expansion, num_classes) 147 | 148 | def _make_layer(self, block, planes, num_blocks, stride): 149 | strides = [stride] + [1]*(num_blocks-1) 150 | layers = [] 151 | for stride in strides: 152 | layers.append(block(self.in_planes, planes, stride)) 153 | self.in_planes = planes * block.expansion 154 | return nn.Sequential(*layers) 155 | 156 | def forward(self, x, lin=0, lout=5, internal_outputs=False): 157 | out = x 158 | out_list = [] 159 | 160 | out = self.conv1(out) 161 | out = self.bn1(out) 162 | out = F.relu(out) 163 | out_list.append(out) 164 | 165 | out = self.layer1(out) 166 | out_list.append(out) 167 | 168 | out = self.layer2(out) 169 | out_list.append(out) 170 | 171 | out = self.layer3(out) 172 | out_list.append(out) 173 | 174 | out = self.layer4(out) 175 | out_list.append(out) 176 | 177 | out = F.avg_pool2d(out, 4) 178 | out = out.view(out.size(0), -1) 179 | 180 | if not self.contranstive_learning: 181 | out = self.linear(out) 182 | 183 | if internal_outputs: 184 | return out, out_list 185 | 186 | return out 187 | 188 | def PreResNet18(num_classes,contranstive_learning): 189 | return ResNet(PreActBlock, [2,2,2,2],num_classes=num_classes,contranstive_learning=contranstive_learning) 190 | 191 | def ResNet18(num_classes,contranstive_learning): 192 | return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes,contranstive_learning=contranstive_learning) 193 | 194 | def ResNet34(num_classes,contranstive_learning): 195 | return ResNet(BasicBlock, [3,4,6,3],num_classes=num_classes,contranstive_learning=contranstive_learning) 196 | 197 | 
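# Usage sketch (an assumption based on the training script, since model_loader.py is not shown
# here): model_loader.get_model presumably picks one of these constructors from args.model, e.g.
# ResNet50(num_classes=10, contranstive_learning=True) skips the linear head and returns
# 2048-d pooled features that the Projector then maps into the contrastive embedding space. 198 | 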
def ResNet50(num_classes, contranstive_learning): 199 | return ResNet(Bottleneck, [3,4,6,3],num_classes=num_classes,contranstive_learning=contranstive_learning) 200 | 201 | def ResNet101(num_classes, contranstive_learning): 202 | return ResNet(Bottleneck, [3,4,23,3],num_classes=num_classes,contranstive_learning=contranstive_learning) 203 | 204 | def ResNet152(num_classes, contranstive_learning): 205 | return ResNet(Bottleneck, [3,8,36,3],num_classes=num_classes,contranstive_learning=contranstive_learning) 206 | 207 | def test(): 208 | net = ResNet18(num_classes=10, contranstive_learning=False) 209 | y = net(Variable(torch.randn(1,3,32,32))) 210 | print(y.size()) 211 | 212 | -------------------------------------------------------------------------------- /SimCLR/models/imagenet_resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | The BasicBlock and Bottleneck modules are from the original ResNet paper: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | The PreActBlock and PreActBottleneck modules are from the later paper: 8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from torch.autograd import Variable 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=dilation, groups=groups, bias=False, dilation=dilation) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | __constants__ = ['downsample'] 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 32 | base_width=64, dilation=1, norm_layer=None): 33 | super(BasicBlock, self).__init__() 34 | if norm_layer is None: 35 | norm_layer = nn.BatchNorm2d 36 | if groups != 1 or base_width != 64: 37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 38 | if dilation > 1: 39 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 40 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = norm_layer(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = norm_layer(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | identity = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | __constants__ = ['downsample'] 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 72 | base_width=64, dilation=1, norm_layer=None): 73 | super(Bottleneck, self).__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | width = int(planes * (base_width / 64.)) * groups 77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 78 | self.conv1 =
conv1x1(inplanes, width) 79 | self.bn1 = norm_layer(width) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.conv3 = conv1x1(width, planes * self.expansion) 83 | self.bn3 = norm_layer(planes * self.expansion) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | class ResNet(nn.Module): 111 | 112 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 113 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 114 | norm_layer=None): 115 | super(ResNet, self).__init__() 116 | if norm_layer is None: 117 | norm_layer = nn.BatchNorm2d 118 | self._norm_layer = norm_layer 119 | 120 | self.inplanes = 64 121 | self.dilation = 1 122 | if replace_stride_with_dilation is None: 123 | # each element in the tuple indicates if we should replace 124 | # the 2x2 stride with a dilated convolution instead 125 | replace_stride_with_dilation = [False, False, False] 126 | if len(replace_stride_with_dilation) != 3: 127 | raise ValueError("replace_stride_with_dilation should be None " 128 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 129 | self.groups = groups 130 | self.base_width = width_per_group 131 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = norm_layer(self.inplanes) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, 64, layers[0]) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 138 | dilate=replace_stride_with_dilation[0]) 139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 140 | dilate=replace_stride_with_dilation[1]) 141 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 142 | dilate=replace_stride_with_dilation[2]) 143 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 144 | self.fc = nn.Linear(512 * block.expansion, num_classes) 145 | 146 | for m in self.modules(): 147 | if isinstance(m, nn.Conv2d): 148 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 149 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 150 | nn.init.constant_(m.weight, 1) 151 | nn.init.constant_(m.bias, 0) 152 | 153 | # Zero-initialize the last BN in each residual branch, 154 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
155 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 156 | if zero_init_residual: 157 | for m in self.modules(): 158 | if isinstance(m, Bottleneck): 159 | nn.init.constant_(m.bn3.weight, 0) 160 | elif isinstance(m, BasicBlock): 161 | nn.init.constant_(m.bn2.weight, 0) 162 | 163 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 164 | norm_layer = self._norm_layer 165 | downsample = None 166 | previous_dilation = self.dilation 167 | if dilate: 168 | self.dilation *= stride 169 | stride = 1 170 | if stride != 1 or self.inplanes != planes * block.expansion: 171 | downsample = nn.Sequential( 172 | conv1x1(self.inplanes, planes * block.expansion, stride), 173 | norm_layer(planes * block.expansion), 174 | ) 175 | 176 | layers = [] 177 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 178 | self.base_width, previous_dilation, norm_layer)) 179 | self.inplanes = planes * block.expansion 180 | for _ in range(1, blocks): 181 | layers.append(block(self.inplanes, planes, groups=self.groups, 182 | base_width=self.base_width, dilation=self.dilation, 183 | norm_layer=norm_layer)) 184 | 185 | return nn.Sequential(*layers) 186 | 187 | def _forward(self, x, internal_outputs=False): 188 | output_list = [] 189 | 190 | x = self.conv1(x) 191 | x = self.bn1(x) 192 | x = self.relu(x) 193 | x = self.maxpool(x) 194 | 195 | x = self.layer1(x) 196 | output_list.append(x) 197 | 198 | x = self.layer2(x) 199 | output_list.append(x) 200 | 201 | x = self.layer3(x) 202 | output_list.append(x) 203 | 204 | x = self.layer4(x) 205 | output_list.append(x) 206 | 207 | x = self.avgpool(x) 208 | x = torch.flatten(x, 1) 209 | #x = self.fc(x) # the fc head is bypassed: the encoder returns pooled features for the contrastive projector 210 | 211 | if internal_outputs: 212 | return x, output_list 213 | 214 | return x 215 | 216 | forward = _forward 217 | 218 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 219 | model = ResNet(block, layers, **kwargs) # pretrained/progress are accepted for API compatibility but ignored 220 | return model 221 | 222 | def resnet18(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 231 | **kwargs) 232 | 233 | def resnet34(pretrained=False, progress=True, **kwargs): 234 | r"""ResNet-34 model from 235 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 236 | 237 | Args: 238 | pretrained (bool): If True, returns a model pre-trained on ImageNet 239 | progress (bool): If True, displays a progress bar of the download to stderr 240 | """ 241 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 242 | **kwargs) 243 | 244 | def resnet50(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-50 model from 246 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | progress (bool): If True, displays a progress bar of the download to stderr 251 | """ 252 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 253 | **kwargs) 254 | 255 | def resnet101(pretrained=False, progress=True, **kwargs): 256 | r"""ResNet-101 model from 257 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If
True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | 266 | def resnet152(pretrained=False, progress=True, **kwargs): 267 | r"""ResNet-152 model from 268 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | progress (bool): If True, displays a progress bar of the download to stderr 272 | """ 273 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 274 | **kwargs) 275 | -------------------------------------------------------------------------------- /OpenCoS/models/resnet_auxbn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn.functional as F 4 | from .aux_batchnorm import BatchNorm2d 5 | from .cifar_resnet_auxbn import mySequential 6 | from .resnet import conv3x3, conv1x1 7 | 8 | __all__ = ['resnet10_auxbn', 'resnet18_auxbn', 'resnet34_auxbn', 'resnet50_auxbn', 'resnet101_auxbn', 'resnet152_auxbn'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 24 | base_width=64, norm_layer=None): 25 | super(BasicBlock, self).__init__() 26 | if norm_layer is None: 27 | norm_layer = BatchNorm2d 28 | if groups != 1 or base_width != 64: 29 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 30 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = norm_layer(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = norm_layer(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x, aux=False): # aux=True routes through the auxiliary BN branch (see aux_batchnorm.py) 40 | identity = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out,aux) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out,aux) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x,aux) 51 | 52 | out += identity 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 62 | base_width=64, norm_layer=None, divide=False): 63 | super(Bottleneck, self).__init__() 64 | if norm_layer is None: 65 | norm_layer = BatchNorm2d 66 | width = int(planes * (base_width / 64.)) * groups 67 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 68 | self.conv1 = conv1x1(inplanes, width) 69 | self.bn1 = norm_layer(width, divide=divide) 70 | self.conv2 = conv3x3(width, width, stride, groups) 71 | self.bn2 = norm_layer(width, divide=divide) 72 | self.conv3 = conv1x1(width, planes * self.expansion) 73 | self.bn3 = norm_layer(planes * self.expansion, divide=divide) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride
= stride 77 | 78 | def forward(self, x, aux=False): 79 | identity = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out,aux) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out,aux) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out,aux) 91 | 92 | if self.downsample is not None: 93 | identity = self.downsample(x,aux) 94 | 95 | out += identity 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | class PreActBlock(nn.Module): 101 | '''Pre-activation version of the BasicBlock.''' 102 | expansion = 1 103 | 104 | def __init__(self, in_planes, planes, stride=1): 105 | super(PreActBlock, self).__init__() 106 | self.bn1 = BatchNorm2d(in_planes) 107 | self.conv1 = conv3x3(in_planes, planes, stride) 108 | self.bn2 = BatchNorm2d(planes) 109 | self.conv2 = conv3x3(planes, planes) 110 | 111 | self.shortcut = nn.Sequential() 112 | if stride != 1 or in_planes != self.expansion*planes: 113 | self.shortcut = nn.Sequential( 114 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 115 | ) 116 | 117 | def forward(self, x, aux=False): 118 | out = F.relu(self.bn1(x,aux)) 119 | shortcut = self.shortcut(out) 120 | out = self.conv1(out) 121 | out = self.conv2(F.relu(self.bn2(out,aux))) 122 | out += shortcut 123 | return out 124 | 125 | class ResNet(nn.Module): 126 | 127 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 128 | groups=1, width_per_group=64, norm_layer=None, bias=True, divide=False): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = BatchNorm2d 132 | 133 | self.inplanes = 64 134 | self.groups = groups 135 | self.base_width = width_per_group 136 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 137 | bias=False) 138 | self.bn1 = norm_layer(self.inplanes, divide=divide) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 141 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, divide=divide) 142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer, divide=divide) 143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer, divide=divide) 144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer, divide=divide) 145 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 146 | self.fc = nn.Linear(512 * block.expansion, num_classes, bias=bias) 147 | self.linear_rot = nn.Linear(512*block.expansion, 4, bias=bias) 148 | for m in self.modules(): 149 | if isinstance(m, nn.Conv2d): 150 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 151 | elif isinstance(m, BatchNorm2d): 152 | nn.init.constant_(m.bn.weight, 1) 153 | nn.init.constant_(m.bn.bias, 0) 154 | elif isinstance(m, nn.GroupNorm): 155 | nn.init.constant_(m.weight, 1) 156 | nn.init.constant_(m.bias, 0) 157 | 158 | # Zero-initialize the last BN in each residual branch, 159 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
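Every residual block's `forward` above takes an extra `aux` flag, so a plain `nn.Sequential` cannot chain these blocks; the `mySequential` container imported from `.cifar_resnet_auxbn` has to thread the flag down to its children, including downsample paths that mix plain convolutions with aux-BN layers. A minimal sketch of what such a container must do (the real implementation may differ in detail):

```python
# Hypothetical sketch of mySequential: a Sequential that forwards the `aux`
# flag to children whose forward() accepts it (aux-BN layers and residual
# blocks) and calls the rest (plain convolutions) normally.
import inspect
import torch.nn as nn

class mySequential(nn.Sequential):
    def forward(self, x, aux=False):
        for module in self:
            if 'aux' in inspect.signature(module.forward).parameters:
                x = module(x, aux=aux)
            else:
                x = module(x)
        return x
```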
160 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 161 | if zero_init_residual: 162 | for m in self.modules(): 163 | if isinstance(m, Bottleneck): 164 | nn.init.constant_(m.bn3.bn.weight, 0) 165 | elif isinstance(m, BasicBlock): 166 | nn.init.constant_(m.bn2.bn.weight, 0) 167 | 168 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, divide=False): 169 | if norm_layer is None: 170 | norm_layer = BatchNorm2d 171 | downsample = None 172 | if stride != 1 or self.inplanes != planes * block.expansion: 173 | downsample = mySequential( 174 | conv1x1(self.inplanes, planes * block.expansion, stride), 175 | norm_layer(planes * block.expansion, divide=divide), 176 | ) 177 | 178 | layers = [] 179 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 180 | self.base_width, norm_layer, divide=divide)) 181 | self.inplanes = planes * block.expansion 182 | for _ in range(1, blocks): 183 | layers.append(block(self.inplanes, planes, groups=self.groups, 184 | base_width=self.base_width, norm_layer=norm_layer, divide=divide)) 185 | 186 | return mySequential(*layers) 187 | 188 | def forward(self, x, feature=False, aux=False): 189 | x = self.conv1(x) 190 | x = self.bn1(x,aux) 191 | x = self.relu(x) 192 | x = self.maxpool(x) 193 | 194 | x = self.layer1(x,aux) 195 | x = self.layer2(x,aux) 196 | x = self.layer3(x,aux) 197 | x = self.layer4(x,aux) 198 | 199 | x = self.avgpool(x) 200 | x = x.view(x.size(0), -1) 201 | o = self.fc(x) 202 | 203 | if feature: 204 | return o, x 205 | else: 206 | return o 207 | 208 | def feature(self, x, aux=False): 209 | x = self.conv1(x) 210 | x = self.bn1(x,aux) 211 | x = self.relu(x) 212 | x = self.maxpool(x) 213 | 214 | x = self.layer1(x,aux) 215 | x = self.layer2(x,aux) 216 | x = self.layer3(x,aux) 217 | x = self.layer4(x,aux) 218 | 219 | x = self.avgpool(x) 220 | x = x.view(x.size(0), -1) 221 | 222 | return x 223 | 224 | 225 | def rot(self, x, aux=False): 226 | x = self.conv1(x) 227 | x = self.bn1(x,aux) 228 | x = self.relu(x) 229 | x = self.maxpool(x) 230 | 231 | x = self.layer1(x,aux) 232 | x = self.layer2(x,aux) 233 | x = self.layer3(x,aux) 234 | x = self.layer4(x,aux) 235 | 236 | x = self.avgpool(x) 237 | x = x.view(x.size(0), -1) 238 | o_rot = self.linear_rot(x) 239 | 240 | return o_rot 241 | 242 | def forward_rot(self, x, aux=False): 243 | x = self.conv1(x) 244 | x = self.bn1(x,aux) 245 | x = self.relu(x) 246 | x = self.maxpool(x) 247 | 248 | x = self.layer1(x,aux) 249 | x = self.layer2(x,aux) 250 | x = self.layer3(x,aux) 251 | x = self.layer4(x,aux) 252 | 253 | x = self.avgpool(x) 254 | x = x.view(x.size(0), -1) 255 | o = self.fc(x) 256 | o_rot = self.linear_rot(x) 257 | 258 | return o, o_rot 259 | 260 | 261 | def resnet10_auxbn(pretrained=False, **kwargs): 262 | """Constructs a ResNet-10 model. 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | """ 266 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 267 | return model 268 | 269 | def resnet18_auxbn(pretrained=False, **kwargs): 270 | """Constructs a ResNet-18 model. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 275 | if pretrained: 276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 277 | return model 278 | 279 | 280 | def resnet34_auxbn(pretrained=False, **kwargs): 281 | """Constructs a ResNet-34 model. 
282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | """ 285 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 286 | if pretrained: 287 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 288 | return model 289 | 290 | 291 | def resnet50_auxbn(pretrained=False, **kwargs): 292 | """Constructs a ResNet-50 model. 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | """ 296 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 297 | if pretrained: 298 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 299 | return model 300 | 301 | 302 | def resnet101_auxbn(pretrained=False, **kwargs): 303 | """Constructs a ResNet-101 model. 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | """ 307 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 308 | if pretrained: 309 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 310 | return model 311 | 312 | 313 | def resnet152_auxbn(pretrained=False, **kwargs): 314 | """Constructs a ResNet-152 model. 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | """ 318 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 319 | if pretrained: 320 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 321 | return model 322 | -------------------------------------------------------------------------------- /OpenCoS/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 6 | 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, groups=groups, bias=False) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 32 | base_width=64, norm_layer=None): 33 | super(BasicBlock, self).__init__() 34 | if norm_layer is None: 35 | norm_layer = nn.BatchNorm2d 36 | if groups != 1 or base_width != 64: 37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 38 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 39 | self.conv1 = conv3x3(inplanes, planes, stride) 40 | self.bn1 = norm_layer(planes) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = conv3x3(planes, planes) 43 | self.bn2 = norm_layer(planes) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | identity = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if 
self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class Bottleneck(nn.Module): 67 | expansion = 4 68 | 69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 70 | base_width=64, norm_layer=None): 71 | super(Bottleneck, self).__init__() 72 | if norm_layer is None: 73 | norm_layer = nn.BatchNorm2d 74 | width = int(planes * (base_width / 64.)) * groups 75 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 76 | self.conv1 = conv1x1(inplanes, width) 77 | self.bn1 = norm_layer(width) 78 | self.conv2 = conv3x3(width, width, stride, groups) 79 | self.bn2 = norm_layer(width) 80 | self.conv3 = conv1x1(width, planes * self.expansion) 81 | self.bn3 = norm_layer(planes * self.expansion) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | identity = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | identity = self.downsample(x) 102 | 103 | out += identity 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | class PreActBlock(nn.Module): 109 | '''Pre-activation version of the BasicBlock.''' 110 | expansion = 1 111 | 112 | def __init__(self, in_planes, planes, stride=1): 113 | super(PreActBlock, self).__init__() 114 | self.bn1 = nn.BatchNorm2d(in_planes) 115 | self.conv1 = conv3x3(in_planes, planes, stride) 116 | self.bn2 = nn.BatchNorm2d(planes) 117 | self.conv2 = conv3x3(planes, planes) 118 | 119 | self.shortcut = nn.Sequential() 120 | if stride != 1 or in_planes != self.expansion*planes: 121 | self.shortcut = nn.Sequential( 122 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 123 | ) 124 | 125 | def forward(self, x): 126 | out = F.relu(self.bn1(x)) 127 | shortcut = self.shortcut(out) 128 | out = self.conv1(out) 129 | out = self.conv2(F.relu(self.bn2(out))) 130 | out += shortcut 131 | return out 132 | 133 | class ResNet(nn.Module): 134 | 135 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 136 | groups=1, width_per_group=64, norm_layer=None, bias=True): 137 | super(ResNet, self).__init__() 138 | if norm_layer is None: 139 | norm_layer = nn.BatchNorm2d 140 | 141 | self.inplanes = 64 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 152 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 153 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 154 | self.fc = nn.Linear(512 * block.expansion, num_classes, bias=bias) 155 | self.linear_rot = nn.Linear(512*block.expansion, 4, bias=bias) 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 159 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | 163 | # Zero-initialize the last BN in each residual branch, 164 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 165 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 166 | if zero_init_residual: 167 | for m in self.modules(): 168 | if isinstance(m, Bottleneck): 169 | nn.init.constant_(m.bn3.weight, 0) 170 | elif isinstance(m, BasicBlock): 171 | nn.init.constant_(m.bn2.weight, 0) 172 | 173 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 174 | if norm_layer is None: 175 | norm_layer = nn.BatchNorm2d 176 | downsample = None 177 | if stride != 1 or self.inplanes != planes * block.expansion: 178 | downsample = nn.Sequential( 179 | conv1x1(self.inplanes, planes * block.expansion, stride), 180 | norm_layer(planes * block.expansion), 181 | ) 182 | 183 | layers = [] 184 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 185 | self.base_width, norm_layer)) 186 | self.inplanes = planes * block.expansion 187 | for _ in range(1, blocks): 188 | layers.append(block(self.inplanes, planes, groups=self.groups, 189 | base_width=self.base_width, norm_layer=norm_layer)) 190 | 191 | return nn.Sequential(*layers) 192 | 193 | def forward(self, x, feature=False, aux=False): 194 | x = self.conv1(x) 195 | x = self.bn1(x) 196 | x = self.relu(x) 197 | x = self.maxpool(x) 198 | 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | 204 | x = self.avgpool(x) 205 | x = x.view(x.size(0), -1) 206 | o = self.fc(x) 207 | 208 | if feature: 209 | return o, x 210 | else: 211 | return o 212 | 213 | def feature(self, x, aux=False): 214 | x = self.conv1(x) 215 | x = self.bn1(x) 216 | x = self.relu(x) 217 | x = self.maxpool(x) 218 | 219 | x = self.layer1(x) 220 | x = self.layer2(x) 221 | x = self.layer3(x) 222 | x = self.layer4(x) 223 | 224 | x = self.avgpool(x) 225 | x = x.view(x.size(0), -1) 226 | 227 | return x 228 | 229 | def feature_list(self, x): 230 | output_list = [] 231 | x = self.conv1(x) 232 | x = self.bn1(x) 233 | x = self.relu(x) 234 | x = self.maxpool(x) 235 | 236 | x = self.layer1(x) 237 | output_list.append(self.avgpool(x)) 238 | x = self.layer2(x) 239 | output_list.append(self.avgpool(x)) 240 | x = self.layer3(x) 241 | output_list.append(self.avgpool(x)) 242 | x = self.layer4(x) 243 | output_list.append(self.avgpool(x)) 244 | 245 | return output_list 246 | 247 | def rot(self, x, aux=False): 248 | x = self.conv1(x) 249 | x = self.bn1(x) 250 | x = self.relu(x) 251 | x = self.maxpool(x) 252 | 253 | x = self.layer1(x) 254 | x = self.layer2(x) 255 | x = self.layer3(x) 256 | x = self.layer4(x) 257 | 258 | x = self.avgpool(x) 259 | x = x.view(x.size(0), -1) 260 | o_rot = self.linear_rot(x) 261 | 262 | return o_rot 263 | 264 | def forward_rot(self, x, aux=False): 265 | x = self.conv1(x) 266 | x = self.bn1(x) 267 | x = self.relu(x) 268 | x = self.maxpool(x) 269 | 270 | x = self.layer1(x) 271 | x = self.layer2(x) 272 | x = self.layer3(x) 273 | x = self.layer4(x) 274 | 275 | x = self.avgpool(x) 276 | x = x.view(x.size(0), -1) 277 | o = self.fc(x) 278 | o_rot = self.linear_rot(x) 279 | 280 | return o, o_rot 281 | 282 | 283 | def resnet10(pretrained=False, **kwargs): 284 | """Constructs a ResNet-10 model. 
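Unlike the torchvision original, this `ResNet` also exposes `feature_list`, which collects the average-pooled output of every residual stage. A small usage sketch; the OOD-scoring use is an assumption, the method itself just gathers per-stage features:

```python
# feature_list returns one avg-pooled tensor per residual stage, e.g. for
# per-stage feature statistics in Mahalanobis-style OOD scoring (assumed use).
import torch

net = resnet50()  # factory defined just below in this file
net.eval()
with torch.no_grad():
    feats = net.feature_list(torch.randn(2, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(2, 256, 1, 1), (2, 512, 1, 1), (2, 1024, 1, 1), (2, 2048, 1, 1)]
```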
285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 289 | return model 290 | 291 | def resnet18(pretrained=False, **kwargs): 292 | """Constructs a ResNet-18 model. 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | """ 296 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 297 | if pretrained: 298 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 299 | return model 300 | 301 | 302 | def resnet34(pretrained=False, **kwargs): 303 | """Constructs a ResNet-34 model. 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | """ 307 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 308 | if pretrained: 309 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 310 | return model 311 | 312 | 313 | def resnet50(pretrained=False, **kwargs): 314 | """Constructs a ResNet-50 model. 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | """ 318 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 319 | if pretrained: 320 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 321 | return model 322 | 323 | 324 | def resnet101(pretrained=False, **kwargs): 325 | """Constructs a ResNet-101 model. 326 | Args: 327 | pretrained (bool): If True, returns a model pre-trained on ImageNet 328 | """ 329 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 330 | if pretrained: 331 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 332 | return model 333 | 334 | 335 | def resnet152(pretrained=False, **kwargs): 336 | """Constructs a ResNet-152 model. 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | """ 340 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 341 | if pretrained: 342 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 343 | return model 344 | -------------------------------------------------------------------------------- /OpenCoS/datasets.py: -------------------------------------------------------------------------------- 1 | import csv, torchvision, numpy as np, random, os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Sampler, Dataset, DataLoader, BatchSampler, SequentialSampler, RandomSampler, Subset, ConcatDataset 5 | from torchvision import transforms, datasets 6 | from collections import defaultdict 7 | from randaugment import RandAugmentMC 8 | import bisect 9 | import warnings 10 | import numpy as np 11 | import torch 12 | 13 | 14 | 15 | 16 | class TransformTwice: 17 | def __init__(self, transform): 18 | self.transform = transform 19 | 20 | def __call__(self, inp): 21 | out1 = self.transform(inp) 22 | out2 = self.transform(inp) 23 | return out1, out2 24 | 25 | class TransformDouble: 26 | def __init__(self, transform1, transform2): 27 | self.transform1 = transform1 28 | self.transform2 = transform2 29 | 30 | def __call__(self, inp): 31 | out1 = self.transform1(inp) 32 | out2 = self.transform2(inp) 33 | return out1, out2 34 | 35 | class TransformList: 36 | def __init__(self, transform1, transform2, K): 37 | self.transform1 = transform1 38 | self.transform2 = transform2 39 | self.K = K 40 | 41 | def __call__(self, inp): 42 | return self.transform1(inp), [self.transform2(inp) for _ in range(self.K)] 43 | 44 | class DatasetWrapper(Dataset): 45 | # Additional attributes 46 | # - indices 47 | # - classwise_indices 48 | # - num_classes 49 | # -
get_class 50 | 51 | def __init__(self, dataset, indices=None): 52 | self.base_dataset = dataset 53 | if indices is None: 54 | self.indices = list(range(len(dataset))) 55 | else: 56 | self.indices = indices 57 | 58 | # torchvision 0.2.0 compatibility 59 | if torchvision.__version__.startswith('0.2'): 60 | if isinstance(self.base_dataset, datasets.ImageFolder): 61 | self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs] 62 | else: 63 | if self.base_dataset.train: 64 | self.base_dataset.targets = self.base_dataset.train_labels 65 | else: 66 | self.base_dataset.targets = self.base_dataset.test_labels 67 | 68 | self.classwise_indices = defaultdict(list) 69 | for i in range(len(self)): 70 | y = self.base_dataset.targets[int(self.indices[i])] 71 | self.classwise_indices[y].append(i) 72 | self.num_classes = len(self.classwise_indices.keys()) 73 | 74 | def __getitem__(self, i): 75 | return self.base_dataset[self.indices[i]] 76 | 77 | def __len__(self): 78 | return len(self.indices) 79 | 80 | def get_class(self, i): 81 | return self.base_dataset.targets[self.indices[i]] 82 | 83 | def reset(self): 84 | self.__init__(self.base_dataset, self.indices) 85 | 86 | 87 | def load_dataset(name, root, sample='default', **kwargs): 88 | 89 | if 'imagenet' in kwargs['uroot']: 90 | imagesize = 224 91 | else: 92 | imagesize = 32 93 | 94 | if imagesize==32: 95 | transform_train = transforms.Compose([ 96 | transforms.RandomCrop(32, padding=4), 97 | transforms.RandomHorizontalFlip(), 98 | transforms.ToTensor(), 99 | ]) 100 | transform_test = transforms.Compose([ 101 | transforms.Resize(imagesize), 102 | transforms.ToTensor(), 103 | ]) 104 | 105 | if kwargs['use_jitter']: 106 | ### color augmentation ### 107 | color_jitter_strength = 0.5 108 | color_jitter = transforms.ColorJitter(0.8*color_jitter_strength, 0.8*color_jitter_strength, 0.8*color_jitter_strength, 0.2*color_jitter_strength) 109 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 110 | rnd_gray = transforms.RandomGrayscale(p=0.2) 111 | transform_aug = transforms.Compose([ 112 | rnd_color_jitter, 113 | rnd_gray, 114 | transforms.RandomResizedCrop(32), 115 | transforms.RandomHorizontalFlip(), 116 | transforms.ToTensor(), 117 | ]) 118 | else: 119 | ### RandAugment ### 120 | transform_aug = transforms.Compose([ 121 | transforms.RandomCrop(32, padding=4), 122 | transforms.RandomHorizontalFlip(), 123 | RandAugmentMC(n=2, m=10), 124 | transforms.ToTensor(), 125 | ]) 126 | else: 127 | transform_train = transforms.Compose([ 128 | transforms.RandomResizedCrop(224), 129 | transforms.RandomHorizontalFlip(), 130 | transforms.ToTensor(), 131 | ]) 132 | transform_test = transforms.Compose([ 133 | transforms.Resize(256), 134 | transforms.CenterCrop(224), 135 | transforms.ToTensor(), 136 | ]) 137 | 138 | if kwargs['use_jitter']: 139 | ### color augmentation ### 140 | color_jitter_strength = 1 141 | color_jitter = transforms.ColorJitter(0.8*color_jitter_strength, 0.8*color_jitter_strength, 0.8*color_jitter_strength, 0.2*color_jitter_strength) 142 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 143 | rnd_gray = transforms.RandomGrayscale(p=0.2) 144 | transform_aug = transforms.Compose([ 145 | rnd_color_jitter, 146 | rnd_gray, 147 | transforms.RandomResizedCrop(224), 148 | transforms.RandomHorizontalFlip(), 149 | transforms.ToTensor(), 150 | ]) 151 | else: 152 | ### RandAugment ### 153 | transform_aug = transforms.Compose([ 154 | transforms.RandomResizedCrop(224), 155 | transforms.RandomHorizontalFlip(), 156 | RandAugmentMC(n=2, 
m=10), 157 | transforms.ToTensor(), 158 | ]) 159 | 160 | if name == 'cifar10': 161 | Label = datasets.CIFAR10 162 | if int(kwargs['pc']) == 5000: 163 | trainidx = None 164 | else: 165 | trainidx = np.load(os.path.join('splits', name + '_' + kwargs['pc'] + 'pc_label_idx.npy')).astype(np.int64) 166 | unlabel_idx = [i for i in range(50000) if i not in trainidx] 167 | trainset = DatasetWrapper(Label(root, train=True, download=True, transform=TransformDouble(transform_train, transform_aug)), trainidx) 168 | testset = Label(root, train=False, download=True, transform=transform_test) 169 | 170 | elif name == 'cifar100': 171 | Label = datasets.CIFAR100 172 | if int(kwargs['pc']) == 500: 173 | trainidx = None 174 | else: 175 | trainidx = np.load(os.path.join('splits', name + '_' + kwargs['pc'] + 'pc_label_idx.npy')).astype(np.int64) 176 | unlabel_idx = [i for i in range(50000) if i not in trainidx] 177 | trainset = DatasetWrapper(Label(root, train=True, download=True, transform=TransformDouble(transform_train, transform_aug)), trainidx) 178 | testset = Label(root, train=False, download=True, transform=transform_test) 179 | 180 | elif name == 'animal': 181 | animal_class = [2,3,4,5,6,7] 182 | def target_transformA(target): 183 | return target - 2 184 | 185 | Label = datasets.CIFAR10 186 | trainidx = np.load(os.path.join('splits', 'cifar10_animal_' + kwargs['pc'] + 'pc_label_idx.npy')).astype(np.int64) 187 | testidx = np.load(os.path.join('splits', 'cifar10_animal_test_idx.npy')).astype(np.int64) 188 | trainset = DatasetWrapper(Label(root, train=True, download=True, transform=TransformDouble(transform_train, transform_aug), target_transform=target_transformA), trainidx) 189 | testset = DatasetWrapper(Label(root, train=False, download=True, transform=transform_test, target_transform=target_transformA),testidx) 190 | notanimal_idx = np.load(os.path.join('splits', 'cifar10_notanimal_unlabel_idx.npy')).astype(np.int64) 191 | unlabel_idx = [i for i in range(50000) if i not in trainidx and i not in notanimal_idx] 192 | 193 | elif name in ['dog_cls', 'bird_cls', 'primate_cls', 'insect_cls', 'reptile_cls', 'aquatic_animal_cls', 'food_cls', 'produce_cls', 'scenery_cls']: 194 | train_val_dataset_dir = os.path.join(kwargs['imgroot'], "train") 195 | test_dataset_dir = os.path.join(kwargs['imgroot'], "val") 196 | 197 | superclasses = np.load(os.path.join('splits_img', name + '_idx.npy')).astype(np.int64) 198 | def target_transformS(target): 199 | return int(np.where(superclasses == target)[0]) 200 | 201 | trainidx = np.load(os.path.join('splits_img', name + '_'+ kwargs['pc'] +'pc_train.npy')).astype(np.int64) 202 | testidx = np.load(os.path.join('splits_img', name + '_test.npy')).astype(np.int64) 203 | 204 | trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=TransformDouble(transform_train, transform_aug), target_transform=target_transformS), trainidx) 205 | testset = Subset(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test, target_transform=target_transformS), testidx) 206 | 207 | else: 208 | raise Exception('Unknown dataset: {}'.format(name)) 209 | 210 | if kwargs['method']=='default': 211 | unlabel_transform = TransformDouble(transform_train, transform_aug) 212 | elif kwargs['method']=='mixmatch': 213 | unlabel_transform = TransformTwice(transform_train) 214 | elif kwargs['method']=='remixmatch': 215 | unlabel_transform = TransformList(transform_train, transform_aug, kwargs['naug']) 216 | else: 217 | raise Exception('Unknown methods: 
{}'.format(kwargs['method'])) 218 | 219 | Unlabel = [] 220 | if 'imagenet' in kwargs['uroot']: 221 | unlabel_dataset_dir = os.path.join(kwargs['imgroot'], "train") 222 | if name in ['dog_cls', 'bird_cls', 'primate_cls', 'insect_cls' , 'reptile_cls', 'aquatic_animal_cls', 'food_cls', 'produce_cls', 'scenery_cls']: 223 | unlabel_idx = np.arange(1281167) 224 | unlabel_idx = list(set(unlabel_idx) - set(trainidx)) 225 | Unlabel.append(Subset(datasets.ImageFolder(root=unlabel_dataset_dir, transform=unlabel_transform), unlabel_idx)) 226 | else: 227 | Unlabel.append(datasets.ImageFolder(root=unlabel_dataset_dir, transform=unlabel_transform)) 228 | 229 | if 'tiny' in kwargs['uroot']: 230 | unlabel_dataset_dir = os.path.join(kwargs['tinyroot'], "train") 231 | if kwargs['ood_samples']>0: 232 | tiny_index = np.load(os.path.join('splits', 'tiny_unlabel_train_idx.npy')).astype(np.int64) 233 | tiny_index = tiny_index[:kwargs['ood_samples']] 234 | Unlabel.append(Subset(datasets.ImageFolder(root=unlabel_dataset_dir, transform=unlabel_transform), tiny_index)) 235 | if name in ['cifar10']: 236 | cifar_unlabel = np.load(os.path.join('splits', 'cifar10_unlabel_train_idx.npy')).astype(np.int64) 237 | cifar_unlabel = cifar_unlabel[:50000 - len(trainidx) - kwargs['ood_samples']] 238 | Unlabel.append(Subset(datasets.CIFAR10(root, train=True, download=True, transform=unlabel_transform), cifar_unlabel)) 239 | elif name in ['cifar100']: 240 | cifar_unlabel = np.load(os.path.join('splits', 'cifar100_unlabel_train_idx.npy')).astype(np.int64) 241 | cifar_unlabel = cifar_unlabel[:50000 - len(trainidx) - kwargs['ood_samples']] 242 | Unlabel.append(Subset(datasets.CIFAR100(root, train=True, download=True, transform=unlabel_transform), cifar_unlabel)) 243 | else: 244 | raise Exception('Unknown labeled dataset: {}'.format(name)) 245 | 246 | if 'cten' in kwargs['uroot']: 247 | if name in ['cifar10', 'animal']: 248 | if name == 'animal': 249 | Unlabel.append(Subset(datasets.CIFAR10(root, train=True, download=True, transform=unlabel_transform), notanimal_idx)) 250 | Unlabel.append(Subset(datasets.CIFAR10(root, train=True, download=True, transform=unlabel_transform), unlabel_idx)) 251 | else: 252 | Unlabel.append(datasets.CIFAR10(root, train=True, download=True, transform=unlabel_transform)) 253 | if 'chund' in kwargs['uroot']: 254 | if name in ['cifar100']: 255 | Unlabel.append(Subset(datasets.CIFAR100(root, train=True, download=True, transform=unlabel_transform), unlabel_idx)) 256 | else: 257 | Unlabel.append(datasets.CIFAR100(root, train=True, download=True, transform=unlabel_transform)) 258 | if 'svhn' in kwargs['uroot']: 259 | if kwargs['ood_samples']>0: 260 | svhn_index = np.load(os.path.join('splits', 'svhn_unlabel_train_idx.npy')).astype(np.int64) 261 | svhn_index = svhn_index[:kwargs['ood_samples']] 262 | Unlabel.append(Subset(datasets.SVHN(root, split='train', download=True, transform=unlabel_transform), svhn_index)) 263 | if name in ['cifar10']: 264 | cifar_unlabel = np.load(os.path.join('splits', 'cifar10_unlabel_train_idx.npy')).astype(np.int64) 265 | cifar_unlabel = cifar_unlabel[:50000 - len(trainidx) - kwargs['ood_samples']] 266 | Unlabel.append(Subset(datasets.CIFAR10(root, train=True, download=True, transform=unlabel_transform), cifar_unlabel)) 267 | elif name in ['cifar100']: 268 | cifar_unlabel = np.load(os.path.join('splits', 'cifar100_unlabel_train_idx.npy')).astype(np.int64) 269 | cifar_unlabel = cifar_unlabel[:50000 - len(trainidx) - kwargs['ood_samples']] 270 | 
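The per-method unlabeled transforms chosen above map directly onto the wrapper classes defined at the top of this file; a sketch of what each wrapper yields per image, assuming `weak` and `strong` stand for the `transform_train` and `transform_aug` pipelines built earlier in `load_dataset`:

```python
# What each wrapper returns for a single PIL image x:
#   TransformTwice(weak)           -> (weak(x), weak(x))                  # mixmatch
#   TransformDouble(weak, strong)  -> (weak(x), strong(x))                # default
#   TransformList(weak, strong, K) -> (weak(x), [strong(x), ... K times]) # remixmatch, K=naug
from torchvision import transforms

weak = transforms.Compose([transforms.RandomCrop(32, padding=4),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor()])
two_views = TransformTwice(weak)
# v1, v2 = two_views(pil_image)  # two independently augmented views of one image
```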
Unlabel.append(Subset(datasets.CIFAR100(root, train=True, download=True, transform=unlabel_transform), cifar_unlabel)) 271 | else: 272 | assert(False) 273 | 274 | unlabeled_trainset = ConcatDataset(Unlabel) 275 | 276 | return trainset, unlabeled_trainset, testset 277 | 278 | -------------------------------------------------------------------------------- /OpenCoS/models/aux_batchnorm.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.nn import Module 9 | from torch.nn import Parameter 10 | from torch.nn import init 11 | 12 | 13 | class _NormBase(Module): 14 | """Common base of _InstanceNorm and _BatchNorm""" 15 | _version = 2 16 | __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias', 17 | 'running_mean', 'running_var', 'num_batches_tracked', 18 | 'num_features', 'affine'] 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True): 22 | super(_NormBase, self).__init__() 23 | self.num_features = num_features 24 | self.eps = eps 25 | self.momentum = momentum 26 | self.affine = affine 27 | self.track_running_stats = track_running_stats 28 | if self.affine: 29 | self.weight = Parameter(torch.Tensor(num_features)) 30 | self.bias = Parameter(torch.Tensor(num_features)) 31 | else: 32 | self.register_parameter('weight', None) 33 | self.register_parameter('bias', None) 34 | if self.track_running_stats: 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 38 | else: 39 | self.register_parameter('running_mean', None) 40 | self.register_parameter('running_var', None) 41 | self.register_parameter('num_batches_tracked', None) 42 | self.reset_parameters() 43 | 44 | def reset_running_stats(self): 45 | if self.track_running_stats: 46 | self.running_mean.zero_() 47 | self.running_var.fill_(1) 48 | self.num_batches_tracked.zero_() 49 | 50 | def reset_parameters(self): 51 | self.reset_running_stats() 52 | if self.affine: 53 | init.ones_(self.weight) 54 | init.zeros_(self.bias) 55 | 56 | def _check_input_dim(self, input): 57 | raise NotImplementedError 58 | 59 | def extra_repr(self): 60 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 61 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 62 | 63 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 64 | missing_keys, unexpected_keys, error_msgs): 65 | version = local_metadata.get('version', None) 66 | 67 | if (version is None or version < 2) and self.track_running_stats: 68 | # at version 2: added num_batches_tracked buffer 69 | # this should have a default value of 0 70 | num_batches_tracked_key = prefix + 'num_batches_tracked' 71 | if num_batches_tracked_key not in state_dict: 72 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 73 | 74 | super(_NormBase, self)._load_from_state_dict( 75 | state_dict, prefix, local_metadata, strict, 76 | missing_keys, unexpected_keys, error_msgs) 77 | 78 | 79 | class _BatchNorm(Module): 80 | 81 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 82 | track_running_stats=True, divide=False): 83 | 
super(_BatchNorm, self).__init__() 84 | self.bn = _NormBase(num_features, eps, momentum, affine, track_running_stats) 85 | self.bn_aux = _NormBase(num_features, eps, momentum, affine if divide else False, track_running_stats) 86 | self.divide = divide 87 | 88 | def forward(self, input, aux=False): 89 | self._check_input_dim(input) 90 | 91 | if aux: 92 | bn = self.bn_aux 93 | else: 94 | bn = self.bn 95 | 96 | # exponential_average_factor is set to self.momentum 97 | # (when it is available) only so that if gets updated 98 | # in ONNX graph when this node is exported to ONNX. 99 | if bn.momentum is None: 100 | exponential_average_factor = 0.0 101 | else: 102 | exponential_average_factor = bn.momentum 103 | 104 | if bn.training and bn.track_running_stats: 105 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 106 | if bn.num_batches_tracked is not None: 107 | bn.num_batches_tracked = bn.num_batches_tracked + 1 108 | if bn.momentum is None: # use cumulative moving average 109 | exponential_average_factor = 1.0 / float(bn.num_batches_tracked) 110 | else: # use exponential moving average 111 | exponential_average_factor = bn.momentum 112 | 113 | return F.batch_norm( 114 | input, bn.running_mean, bn.running_var, bn.weight if self.divide else self.bn.weight, bn.bias if self.divide else self.bn.bias, 115 | bn.training or not bn.track_running_stats, 116 | exponential_average_factor, bn.eps) 117 | 118 | 119 | class BatchNorm1d(_BatchNorm): 120 | r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D 121 | inputs with optional additional channel dimension) as described in the paper 122 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 123 | .. math:: 124 | y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 125 | The mean and standard-deviation are calculated per-dimension over 126 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 127 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 128 | to 1 and the elements of :math:`\beta` are set to 0. 129 | Also by default, during training this layer keeps running estimates of its 130 | computed mean and variance, which are then used for normalization during 131 | evaluation. The running estimates are kept with a default :attr:`momentum` 132 | of 0.1. 133 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 134 | keep running estimates, and batch statistics are instead used during 135 | evaluation time as well. 136 | .. note:: 137 | This :attr:`momentum` argument is different from one used in optimizer 138 | classes and the conventional notion of momentum. Mathematically, the 139 | update rule for running statistics here is 140 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 141 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 142 | new observed value. 143 | Because the Batch Normalization is done over the `C` dimension, computing statistics 144 | on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. 145 | Args: 146 | num_features: :math:`C` from an expected input of size 147 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 148 | eps: a value added to the denominator for numerical stability. 149 | Default: 1e-5 150 | momentum: the value used for the running_mean and running_var 151 | computation. 
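The `forward` above is the crux of the auxiliary-BN design: one layer holds two `_NormBase` instances, the `aux` flag selects which set of running statistics is used and updated, and unless `divide=True` both branches share the main branch's affine parameters. A small usage sketch of the `BatchNorm2d` defined later in this file:

```python
# Usage sketch: labeled batches update the main statistics (bn.bn), while
# auxiliary (e.g. unlabeled/OOD) batches update bn.bn_aux when aux=True.
import torch

bn = BatchNorm2d(8, divide=False)  # divide=False: aux branch reuses main affine params
x_labeled = torch.randn(4, 8, 16, 16)
x_unlabeled = torch.randn(4, 8, 16, 16)

bn.train()
y_labeled = bn(x_labeled)                # normalizes with / updates bn.bn stats
y_unlabeled = bn(x_unlabeled, aux=True)  # normalizes with / updates bn.bn_aux stats
```

One consequence of this wrapping is that parameters are nested (`bn1.bn.weight` instead of `bn1.weight`), so the torchvision checkpoints referenced in `resnet_auxbn.py` cannot be loaded into the `*_auxbn` models with a default strict `load_state_dict` without remapping keys first.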
Can be set to ``None`` for cumulative moving average 152 | (i.e. simple average). Default: 0.1 153 | affine: a boolean value that when set to ``True``, this module has 154 | learnable affine parameters. Default: ``True`` 155 | track_running_stats: a boolean value that when set to ``True``, this 156 | module tracks the running mean and variance, and when set to ``False``, 157 | this module does not track such statistics and always uses batch 158 | statistics in both training and eval modes. Default: ``True`` 159 | Shape: 160 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 161 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 162 | Examples:: 163 | >>> # With Learnable Parameters 164 | >>> m = nn.BatchNorm1d(100) 165 | >>> # Without Learnable Parameters 166 | >>> m = nn.BatchNorm1d(100, affine=False) 167 | >>> input = torch.randn(20, 100) 168 | >>> output = m(input) 169 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 170 | https://arxiv.org/abs/1502.03167 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | 178 | 179 | class BatchNorm2d(_BatchNorm): 180 | r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs 181 | with additional channel dimension) as described in the paper 182 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 183 | .. math:: 184 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 185 | The mean and standard-deviation are calculated per-dimension over 186 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 187 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 188 | to 1 and the elements of :math:`\beta` are set to 0. 189 | Also by default, during training this layer keeps running estimates of its 190 | computed mean and variance, which are then used for normalization during 191 | evaluation. The running estimates are kept with a default :attr:`momentum` 192 | of 0.1. 193 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 194 | keep running estimates, and batch statistics are instead used during 195 | evaluation time as well. 196 | .. note:: 197 | This :attr:`momentum` argument is different from one used in optimizer 198 | classes and the conventional notion of momentum. Mathematically, the 199 | update rule for running statistics here is 200 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 201 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 202 | new observed value. 203 | Because the Batch Normalization is done over the `C` dimension, computing statistics 204 | on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. 205 | Args: 206 | num_features: :math:`C` from an expected input of size 207 | :math:`(N, C, H, W)` 208 | eps: a value added to the denominator for numerical stability. 209 | Default: 1e-5 210 | momentum: the value used for the running_mean and running_var 211 | computation. Can be set to ``None`` for cumulative moving average 212 | (i.e. simple average). Default: 0.1 213 | affine: a boolean value that when set to ``True``, this module has 214 | learnable affine parameters. 
Default: ``True`` 215 | track_running_stats: a boolean value that when set to ``True``, this 216 | module tracks the running mean and variance, and when set to ``False``, 217 | this module does not track such statistics and always uses batch 218 | statistics in both training and eval modes. Default: ``True`` 219 | Shape: 220 | - Input: :math:`(N, C, H, W)` 221 | - Output: :math:`(N, C, H, W)` (same shape as input) 222 | Examples:: 223 | >>> # With Learnable Parameters 224 | >>> m = nn.BatchNorm2d(100) 225 | >>> # Without Learnable Parameters 226 | >>> m = nn.BatchNorm2d(100, affine=False) 227 | >>> input = torch.randn(20, 100, 35, 45) 228 | >>> output = m(input) 229 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 230 | https://arxiv.org/abs/1502.03167 231 | """ 232 | 233 | def _check_input_dim(self, input): 234 | if input.dim() != 4: 235 | raise ValueError('expected 4D input (got {}D input)' 236 | .format(input.dim())) 237 | 238 | 239 | class BatchNorm3d(_BatchNorm): 240 | r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs 241 | with additional channel dimension) as described in the paper 242 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 243 | .. math:: 244 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 245 | The mean and standard-deviation are calculated per-dimension over 246 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 247 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 248 | to 1 and the elements of :math:`\beta` are set to 0. 249 | Also by default, during training this layer keeps running estimates of its 250 | computed mean and variance, which are then used for normalization during 251 | evaluation. The running estimates are kept with a default :attr:`momentum` 252 | of 0.1. 253 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 254 | keep running estimates, and batch statistics are instead used during 255 | evaluation time as well. 256 | .. note:: 257 | This :attr:`momentum` argument is different from one used in optimizer 258 | classes and the conventional notion of momentum. Mathematically, the 259 | update rule for running statistics here is 260 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 261 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 262 | new observed value. 263 | Because the Batch Normalization is done over the `C` dimension, computing statistics 264 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization 265 | or Spatio-temporal Batch Normalization. 266 | Args: 267 | num_features: :math:`C` from an expected input of size 268 | :math:`(N, C, D, H, W)` 269 | eps: a value added to the denominator for numerical stability. 270 | Default: 1e-5 271 | momentum: the value used for the running_mean and running_var 272 | computation. Can be set to ``None`` for cumulative moving average 273 | (i.e. simple average). Default: 0.1 274 | affine: a boolean value that when set to ``True``, this module has 275 | learnable affine parameters. 
Default: ``True`` 276 | track_running_stats: a boolean value that when set to ``True``, this 277 | module tracks the running mean and variance, and when set to ``False``, 278 | this module does not track such statistics and always uses batch 279 | statistics in both training and eval modes. Default: ``True`` 280 | Shape: 281 | - Input: :math:`(N, C, D, H, W)` 282 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 283 | Examples:: 284 | >>> # With Learnable Parameters 285 | >>> m = nn.BatchNorm3d(100) 286 | >>> # Without Learnable Parameters 287 | >>> m = nn.BatchNorm3d(100, affine=False) 288 | >>> input = torch.randn(20, 100, 35, 45, 10) 289 | >>> output = m(input) 290 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 291 | https://arxiv.org/abs/1502.03167 292 | """ 293 | 294 | def _check_input_dim(self, input): 295 | if input.dim() != 5: 296 | raise ValueError('expected 5D input (got {}D input)' 297 | .format(input.dim())) 298 | 299 | 300 | #if __name__ == '__main__': 301 | # l = BatchNorm2d(3) 302 | # x = torch.randn(2,3,8,8) 303 | # l.train() 304 | # y1 = l(x) 305 | # y1 = l(x, aux=True) 306 | # l.eval() 307 | # y2 = l(x) 308 | # y2 = l(x, aux=True) 309 | # print (y1.size()) 310 | # print (y1 == y2) # train/eval 311 | -------------------------------------------------------------------------------- /OpenCoS/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import csv 7 | import os, logging 8 | import copy 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable, grad 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import torchvision.transforms as transforms 19 | 20 | import models 21 | from utils import progress_bar, set_logging_defaults 22 | from datasets import load_dataset 23 | from collections import OrderedDict 24 | 25 | # torch_version = [int(v) for v in torch.__version__.split('.')] 26 | tensorboardX_compat = True #(torch_version[0] >= 1) and (torch_version[1] >= 1) # PyTorch >= 1.1 27 | try: 28 | from tensorboardX import SummaryWriter 29 | except ImportError: 30 | print ('No tensorboardX package is found. Start training without tensorboardX') 31 | tensorboardX_compat = False 32 | #raise RuntimeError("No tensorboardX package is found. 
Please install with the command: \npip install tensorboardX") 33 | 34 | parser = argparse.ArgumentParser(description='SimCLR linear evaluation or fine-tuning training') 35 | parser.add_argument('--lr', default=0.002, type=float, help='learning rate') 36 | parser.add_argument('--model', default="wide_resnet", type=str, 37 | help='model type (default: wide_resnet)') 38 | parser.add_argument('--name', default='0', type=str, help='name of run') 39 | parser.add_argument('--batch-size', default=64, type=int, help='batch size') 40 | parser.add_argument('--num_iters', default=50000, type=int, help='total iterations to run') 41 | parser.add_argument('--decay', default=0, type=float, help='weight decay') 42 | parser.add_argument('--ngpu', default=1, type=int, help='number of gpus') 43 | parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)') 44 | parser.add_argument('--dataset', default='cifar10', type=str, help='name of the labeled dataset') 45 | parser.add_argument('--udata', default='svhn', type=str, help='type of unlabeled data') 46 | parser.add_argument('--tinyroot', default='/data/tinyimagenet/tiny-imagenet-200/', type=str, help='TinyImageNet directory') 47 | parser.add_argument('--imgroot', default='/data/ILSVRC/Data/CLS-LOC/', type=str, help='ImageNet directory') 48 | parser.add_argument('--dataroot', default='/data/', type=str, help='data directory') 49 | parser.add_argument('--saveroot', default='./results', type=str, help='save directory') 50 | parser.add_argument('--finetune', '-ft', action='store_true', help='finetuning') 51 | parser.add_argument('--pc', default=25, type=int, help='number of samples per class') 52 | parser.add_argument('--nworkers', default=4, type=int, help='num_workers') 53 | 54 | parser.add_argument('--multinomial', action='store_true', help='linear evaluation') 55 | parser.add_argument('--stop_iters', default=None, type=int, help='early stopping') 56 | parser.add_argument('--model_path', default=None, type=str, help='model path') 57 | parser.add_argument('--ood_samples', default=0, type=int, help='number of ood samples in [0,10000,20000,30000,40000]') 58 | parser.add_argument('--fix_optim', action='store_true', help='using optimizer of FixMatch') 59 | parser.add_argument('--simclr_optim', action='store_true', help='using optimizer of SimCLR semi finetune') 60 | args = parser.parse_args() 61 | use_cuda = torch.cuda.is_available() 62 | 63 | best_val = 0 # best validation accuracy 64 | start_iters = 0 # start from iteration 0 or last checkpoint iteration 65 | current_val = 0 66 | 67 | cudnn.benchmark = True 68 | 69 | # Data 70 | _labeled_trainset, _unlabeled_trainset, _labeled_testset = load_dataset(args.dataset, args.dataroot, batch_size=args.batch_size, pc=str(args.pc), method='default', uroot=args.udata, tinyroot=args.tinyroot, imgroot=args.imgroot, ood_samples=args.ood_samples, use_jitter=False) 71 | _labeled_num_class = _labeled_trainset.num_classes 72 | print('Numclass: ', _labeled_num_class) 73 | print('==> Preparing dataset: {}'.format(args.dataset)) 74 | print('Number of labeled samples: ', len(_labeled_trainset)) 75 | print('Number of unlabeled samples: ', len(_unlabeled_trainset)) 76 | print('Number of test samples: ', len(_labeled_testset)) 77 | 78 | 79 | logdir = os.path.join(args.saveroot, args.dataset, args.model, args.name) 80 | set_logging_defaults(logdir, args) 81 | logger = logging.getLogger('main') 82 | logname = os.path.join(logdir, 'log.csv') 83 | if args.multinomial: 84 | tensorboardX_compat = False 85 | if tensorboardX_compat: 86 | writer = 
SummaryWriter(logdir=logdir) 87 | 88 | if use_cuda: 89 | torch.cuda.set_device(args.sgpu) 90 | print(torch.cuda.device_count()) 91 | print('Using CUDA..') 92 | 93 | criterion = nn.CrossEntropyLoss() 94 | 95 | def cycle(iterable): 96 | while True: 97 | for x in iterable: 98 | yield x 99 | 100 | def train(): 101 | # Model 102 | print('==> Building model: {}'.format(args.model)) 103 | net = models.load_model(args.model, _labeled_num_class) 104 | if args.finetune: 105 | model_dict = net.state_dict() 106 | if (args.model in ['resnet50']): 107 | try: 108 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['model'] 109 | except KeyError: 110 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['net'] 111 | classifier = ['fc.weight', 'fc.bias'] 112 | imagesize = 224 113 | elif (args.model in ['wide_resnet', 'CIFAR_ResNet50']): 114 | try: 115 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['model'] 116 | except KeyError: 117 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['net'] 118 | classifier = ['linear.weight', 'linear.bias'] 119 | imagesize = 32 120 | new_state_dict = OrderedDict() 121 | for k, v in pretrained_dict.items(): 122 | if k[:6]=='module': 123 | name = k[7:] # remove `module.` 124 | else: 125 | name = k 126 | new_state_dict[name] = v 127 | new_state_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and k not in classifier} 128 | model_dict.update(new_state_dict) 129 | net.load_state_dict(model_dict) 130 | 131 | net.cuda() 132 | print(' Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0)) 133 | # print(net) 134 | if args.ngpu > 1: 135 | net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu))) 136 | 137 | if args.simclr_optim: 138 | assert (not args.fix_optim) 139 | args.lr = 0.05 * float(args.batch_size) / 256 140 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0, nesterov=True) 141 | elif args.fix_optim: 142 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay, nesterov=True) 143 | else: 144 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 145 | 146 | net.train() 147 | if len(_labeled_trainset) < args.batch_size: 148 | rand_sampler = torch.utils.data.RandomSampler(_labeled_trainset, num_samples=args.batch_size, replacement=True) 149 | _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, sampler=rand_sampler, num_workers=0) 150 | else: 151 | _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True) 152 | _labeled_testloader = torch.utils.data.DataLoader(_labeled_testset, batch_size=args.batch_size, shuffle=False, num_workers=4) 153 | _labeled_train_iter = iter(cycle(_labeled_trainloader)) 154 | train_loss = 0 155 | correct = 0 156 | total = 0 157 | 158 | run_iters = args.num_iters if args.stop_iters is None else args.stop_iters 159 | for batch_idx in range(start_iters, run_iters + 1): 160 | (inputs, inputs_aug), targets = next(_labeled_train_iter) 161 | 162 | if use_cuda: 163 | inputs = inputs.cuda() 164 | targets = targets.cuda() 165 | 166 | logits = net(inputs) 167 | 168 | loss = criterion(logits, targets) 169 | 170 | optimizer.zero_grad() 171 | loss.backward() 172 | optimizer.step() 173 | if args.fix_optim: 174 | adjust_learning_rate(optimizer, batch_idx+1) 175 | 176 | if batch_idx % 1000 == 0: 177 | if batch_idx // 1000 > (run_iters 
// 1000) - 5: 178 | median = True 179 | else: 180 | median = False 181 | logger = logging.getLogger('train') 182 | logger.info('[Iters {}] [Loss {:.3f}]'.format( 183 | batch_idx, 184 | train_loss/1000)) 185 | print('[Iters {}] [Loss {:.3f}]'.format( 186 | batch_idx, 187 | train_loss/1000)) 188 | if tensorboardX_compat: 189 | writer.add_scalar("training/loss", train_loss/1000, batch_idx+1) 190 | 191 | train_loss = 0 192 | save = val(net, batch_idx, _labeled_testloader, median=median) 193 | if save: 194 | checkpoint(net, optimizer, best_val, batch_idx) 195 | net.train() 196 | else: 197 | progress_bar(batch_idx % 1000, 1000, 'working...') 198 | 199 | checkpoint(net, optimizer, current_val, args.num_iters, last=True) 200 | 201 | 202 | class MergeDataset(torch.utils.data.Dataset): 203 | def __init__(self, dataset1, dataset2): 204 | assert len(dataset1)==len(dataset2) 205 | self.dataset1 = dataset1 206 | self.dataset2 = dataset2 207 | 208 | def __getitem__(self, i): 209 | return (self.dataset1[i][0],)+ self.dataset2[i] 210 | 211 | def __len__(self): 212 | return len(self.dataset1) 213 | 214 | 215 | def multinomial(): 216 | # Model 217 | print('==> Building model: {}'.format(args.model)) 218 | net = models.load_model(args.model, _labeled_num_class) 219 | if args.finetune: 220 | model_dict = net.state_dict() 221 | if (args.model in ['resnet50']): 222 | try: 223 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['model'] 224 | except KeyError: 225 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['net'] 226 | classifier = ['fc.weight', 'fc.bias'] 227 | imagesize = 224 228 | elif (args.model in ['wide_resnet', 'CIFAR_ResNet50']): 229 | try: 230 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['model'] 231 | except KeyError: 232 | pretrained_dict = torch.load(args.model_path, map_location='cpu')['net'] 233 | classifier = ['linear.weight', 'linear.bias'] 234 | imagesize = 32 235 | new_state_dict = OrderedDict() 236 | for k, v in pretrained_dict.items(): 237 | if k[:6]=='module': 238 | name = k[7:] # remove `module.` 239 | else: 240 | name = k 241 | new_state_dict[name] = v 242 | new_state_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and k not in classifier} 243 | model_dict.update(new_state_dict) 244 | net.load_state_dict(model_dict) 245 | 246 | net.cuda() 247 | print(' Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0)) 248 | # print(net) 249 | if args.ngpu > 1: 250 | net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu))) 251 | 252 | if (args.model in ['resnet50']): 253 | optimizer = optim.LBFGS(net.fc.parameters(), lr=1, max_iter=5000) 254 | transform_test = transforms.Compose([ 255 | transforms.Resize(224), # for linear eval 256 | transforms.CenterCrop(224), 257 | transforms.ToTensor(), 258 | ]) 259 | elif (args.model in ['wide_resnet', 'CIFAR_ResNet50']): 260 | optimizer = optim.LBFGS(net.linear.parameters(), lr=1, max_iter=5000) 261 | transform_test = transforms.Compose([ 262 | transforms.Resize(32), # for linear eval 263 | transforms.ToTensor(), 264 | ]) 265 | 266 | net.eval() 267 | 268 | labeled_trainset = copy.deepcopy(_labeled_trainset.base_dataset) # full dataset 269 | labeled_trainset.transform = transform_test 270 | labeled_trainset = torch.utils.data.Subset(labeled_trainset, _labeled_trainset.indices) # slicing 271 | 272 | labeled_trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=min(args.batch_size, len(labeled_trainset)), 
shuffle=False, num_workers=4) 273 | labeled_testset = copy.deepcopy(_labeled_testset) # full dataset 274 | labeled_testset.transform = transform_test 275 | labeled_testloader = torch.utils.data.DataLoader(labeled_testset, batch_size=args.batch_size, shuffle=False, num_workers=4) 276 | 277 | train_loss = 0 278 | correct = 0 279 | total = 0 280 | 281 | feats = [] 282 | labels = [] 283 | for batch_idx, (inputs, targets) in enumerate(labeled_trainloader): 284 | if use_cuda: 285 | inputs = inputs.cuda() 286 | targets = targets.cuda() 287 | with torch.no_grad(): 288 | feats.append(net.feature(inputs)) 289 | labels.append(targets) 290 | feats = torch.cat(feats, dim=0) 291 | labels = torch.cat(labels, dim=0) 292 | 293 | 294 | def closure1(): 295 | optimizer.zero_grad() 296 | outputs = net.fc(feats) 297 | loss = criterion(outputs, labels) 298 | for param in net.fc.parameters(): 299 | loss += 0.5 * param.pow(2).sum() * 1e-4 300 | print('loss:', loss.item()) 301 | loss.backward() 302 | return loss 303 | 304 | def closure2(): 305 | optimizer.zero_grad() 306 | outputs = net.linear(feats) 307 | loss = criterion(outputs, labels) 308 | for param in net.linear.parameters(): 309 | loss += 0.5 * param.pow(2).sum() * 1e-4 310 | print('loss:', loss.item()) 311 | loss.backward() 312 | return loss 313 | 314 | if (args.model in ['resnet50']): 315 | optimizer.step(closure1) 316 | elif (args.model in ['wide_resnet', 'CIFAR_ResNet50']): 317 | optimizer.step(closure2) 318 | 319 | save = val(net, 100, labeled_testloader) 320 | checkpoint(net, optimizer, best_val, 100) 321 | 322 | median_acc = [] 323 | 324 | def val(net, iters, testloader, median=False): 325 | global best_val 326 | global median_acc 327 | global current_val 328 | net.eval() 329 | val_loss = 0.0 330 | correct = 0.0 331 | total = 0.0 332 | 333 | with torch.no_grad(): 334 | for batch_idx, (inputs, targets) in enumerate(testloader): 335 | if use_cuda: 336 | inputs, targets = inputs.cuda(), targets.cuda() 337 | 338 | outputs = net(inputs) 339 | loss = torch.mean(criterion(outputs, targets)) 340 | val_loss += loss.item() 341 | _, predicted = torch.max(outputs, 1) 342 | total += targets.size(0) 343 | correct += predicted.eq(targets.data).cpu().sum().float() 344 | progress_bar(batch_idx, len(testloader), 345 | 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 346 | % (val_loss/(batch_idx+1), 100.*correct/total, correct, total)) 347 | 348 | logger = logging.getLogger('test') 349 | logger.info('[Loss {:.3f}] [Acc {:.3f}]'.format( 350 | val_loss/(batch_idx+1), 100.*correct/total)) 351 | 352 | acc = 100.*correct/total 353 | 354 | if median: 355 | median_acc.append(acc.item()) 356 | if tensorboardX_compat: 357 | writer.add_scalar("validation/loss", val_loss/(batch_idx+1), iters+1) 358 | writer.add_scalar("validation/top1_acc", acc, iters+1) 359 | current_val = acc 360 | if acc > best_val: 361 | best_val = acc 362 | return True 363 | else: 364 | return False 365 | 366 | def checkpoint(net, optimizer, acc, iters, last=False): 367 | # Save checkpoint. 
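    # Added note: the state saved below bundles the model and optimizer state
    # together with the current accuracy, iteration count, and torch's RNG
    # state, so an interrupted run can resume with identical randomness.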
368 | print('Saving..') 369 | state = { 370 | 'net': net.state_dict(), 371 | 'optimizer': optimizer.state_dict(), 372 | 'acc': acc, 373 | 'iters': iters, 374 | 'rng_state': torch.get_rng_state() 375 | } 376 | torch.save(state, os.path.join(logdir, 'ckpt.t7' if (not last) else 'last_ckpt.t7')) 377 | 378 | 379 | def adjust_learning_rate(optimizer, iters): 380 | """decrease the learning rate""" 381 | lr = args.lr * np.cos(iters/(args.num_iters+1) * (7 * np.pi) / (2 * 8)) 382 | for param_group in optimizer.param_groups: 383 | param_group['lr'] = lr 384 | 385 | if args.multinomial: 386 | multinomial() 387 | else: 388 | train() 389 | 390 | print("Best Accuracy : {}".format(best_val)) 391 | print("Median Accuracy : {}".format(np.median(median_acc))) 392 | logger = logging.getLogger('best') 393 | if args.multinomial: 394 | logger.info('[Acc {:.3f}]'.format(best_val)) 395 | else: 396 | logger.info('[Acc {:.3f}] [MEDIAN Acc {:.3f}]'.format(best_val, np.median(median_acc))) 397 | if tensorboardX_compat: 398 | writer.close() 399 | -------------------------------------------------------------------------------- /OpenCoS/train_fixmatch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import csv 7 | import os, logging 8 | import copy 9 | from collections import OrderedDict 10 | 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable, grad 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import torchvision.transforms as transforms 19 | 20 | import models 21 | from utils import progress_bar, set_logging_defaults 22 | from datasets import load_dataset 23 | 24 | # torch_version = [int(v) for v in torch.__version__.split('.')] 25 | tensorboardX_compat = True #(torch_version[0] >= 1) and (torch_version[1] >= 1) # PyTorch >= 1.1 26 | try: 27 | from tensorboardX import SummaryWriter 28 | except ImportError: 29 | print ('No tensorboardX package is found. Start training without tensorboardX') 30 | tensorboardX_compat = False 31 | #raise RuntimeError("No tensorboardX package is found. 
Please install with the command: \npip install tensorboardX")
32 | 
33 | 
34 | parser = argparse.ArgumentParser(description='FixMatch Training')
35 | parser.add_argument('--lr', default=0.002, type=float, help='learning rate')
36 | parser.add_argument('--model', default="wide_resnet", type=str,
37 |                     help='model type (default: wide_resnet)')
38 | parser.add_argument('--name', default='0', type=str, help='name of run')
39 | parser.add_argument('--batch-size', default=64, type=int, help='batch size')
40 | parser.add_argument('--num_iters', default=50000, type=int, help='total iterations to run')
41 | parser.add_argument('--decay', default=0, type=float, help='weight decay')
42 | parser.add_argument('--ngpu', default=1, type=int, help='number of GPUs')
43 | parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)')
44 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name')
45 | parser.add_argument('--dataroot', default='/data/', type=str, help='data directory')
46 | parser.add_argument('--udata', default='svhn', type=str, help='type of unlabeled data')
47 | parser.add_argument('--tinyroot', default='/data/tinyimagenet/tiny-imagenet-200/', type=str, help='TinyImageNet directory')
48 | parser.add_argument('--imgroot', default='/data/ILSVRC/Data/CLS-LOC/', type=str, help='unlabeled data directory')
49 | parser.add_argument('--saveroot', default='./results', type=str, help='directory to save results')
50 | parser.add_argument('--finetune', '-ft', action='store_true', help='finetuning')
51 | parser.add_argument('--pc', default=25, type=int, help='number of samples per class')
52 | parser.add_argument('--ema', action='store_true', help='EMA training')
53 | parser.add_argument('--nworkers', default=4, type=int, help='number of data loader workers')
54 | parser.add_argument('--mu', default=7, type=int, help='ratio of unlabeled to labeled batch size')
55 | parser.add_argument('--lmd_u', default=1., type=float, help='Lu loss weight')
56 | parser.add_argument('--ths_pred', default=0.95, type=float, help='confidence threshold for pseudo-labeling')
57 | 
58 | parser.add_argument('--model_path', default=None, type=str, help='(unsupervised) pretrained model path')
59 | parser.add_argument('--ood_samples', default=0, type=int, help='number of ood samples in [0,10000,20000,30000,40000]')
60 | parser.add_argument('--fix_optim', action='store_true', help='use the optimizer settings of FixMatch')
61 | parser.add_argument('--stop_iters', default=None, type=int, help='stop training early at this iteration')
62 | parser.add_argument('--use_jitter', action='store_true', help='using jitter augmentation for unlabeled data')
63 | parser.add_argument('--simclr_optim', action='store_true', help='use the optimizer settings of SimCLR semi-supervised finetuning')
64 | args = parser.parse_args()
65 | use_cuda = torch.cuda.is_available()
66 | 
67 | best_val = 0 # best validation accuracy
68 | best_val_ema = 0 # best validation accuracy of the EMA model
69 | start_iters = 0 # start from iteration 0 or the last checkpoint iteration
70 | current_val = 0
71 | current_val_ema = 0
72 | 
73 | cudnn.benchmark = True
74 | 
75 | # Data
76 | _labeled_trainset, _unlabeled_trainset, _labeled_testset = load_dataset(args.dataset, args.dataroot, batch_size=args.batch_size, pc=str(args.pc), method='default', uroot=args.udata, tinyroot=args.tinyroot, imgroot=args.imgroot, ood_samples=args.ood_samples, use_jitter=args.use_jitter)
77 | _labeled_num_class = _labeled_trainset.num_classes
78 | print('Number of classes: ', _labeled_num_class)
79 | print('==> Preparing dataset: {}'.format(args.dataset))
80 | print('Number of labeled samples: ', len(_labeled_trainset))
81 | print('Number of unlabeled samples: ', len(_unlabeled_trainset))
82 | print('Number of test samples: ', len(_labeled_testset))
83 | 
84 | logdir = os.path.join(args.saveroot, args.dataset, args.model, args.name)
85 | set_logging_defaults(logdir, args)
86 | logger = logging.getLogger('main')
87 | logname = os.path.join(logdir, 'log.csv')
88 | 
89 | if tensorboardX_compat:
90 |     writer = SummaryWriter(logdir=logdir)
91 | 
92 | if use_cuda:
93 |     torch.cuda.set_device(args.sgpu)
94 |     print(torch.cuda.device_count())
95 |     print('Using CUDA..')
96 | 
97 | class KDLoss(nn.Module):
98 |     def __init__(self):
99 |         super(KDLoss, self).__init__()
100 |         self.kl_div = nn.KLDivLoss(reduction="none")
101 | 
102 |     def forward(self, input, target):
103 |         log_p = torch.log_softmax(input, dim=1)
104 |         q = target
105 |         loss = self.kl_div(log_p, q)
106 |         return loss
107 | 
108 | criterion = nn.CrossEntropyLoss()
109 | criterion_none = nn.CrossEntropyLoss(reduction='none')
110 | loss_kl = KDLoss()
111 | 
112 | def cycle(iterable):
113 |     while True:
114 |         for x in iterable:
115 |             yield x
116 | 
117 | def ema_train():
118 |     # Model
119 |     print('==> Building model: {}'.format(args.model))
120 |     net = models.load_model(args.model, _labeled_num_class)
121 | 
122 |     if args.finetune:
123 |         model_dict = net.state_dict()
124 |         if (args.model in ['resnet50', 'resnet50_auxbn']):
125 |             try:
126 |                 pretrained_dict = torch.load(args.model_path, map_location='cpu')['model']
127 |             except KeyError:
128 |                 pretrained_dict = torch.load(args.model_path, map_location='cpu')['net']
129 |             classifier = ['fc.weight', 'fc.bias', 'linear_rot.weight', 'linear_rot.bias']
130 |             imagesize = 224
131 |         elif (args.model in ['CIFAR_ResNet50', 'CIFAR_ResNet50_AuxBN', 'wide_resnet', 'wide_resnet_auxbn']):
132 |             try:
133 |                 pretrained_dict = torch.load(args.model_path, map_location='cpu')['model']
134 |             except KeyError:
135 |                 pretrained_dict = torch.load(args.model_path, map_location='cpu')['net']
136 |             classifier = ['linear.weight', 'linear.bias', 'linear_rot.weight', 'linear_rot.bias']
137 |             imagesize = 32
138 |         new_state_dict = OrderedDict()
139 |         for k, v in pretrained_dict.items():
140 |             if k[:6]=='module':
141 |                 name = k[7:] # remove `module.`
142 |             else:
143 |                 name = k
144 |             new_state_dict[name] = v
145 |         new_state_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and k not in classifier}
146 |         model_dict.update(new_state_dict)
147 |         net.load_state_dict(model_dict)
148 | 
149 |     net_ema = copy.deepcopy(net)
150 |     for param in net_ema.parameters():
151 |         param.detach_()
152 | 
153 |     net.cuda()
154 |     net_ema.cuda()
155 |     print(' Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0))
156 |     # print(net)
157 |     if args.ngpu > 1:
158 |         net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu)))
159 |         net_ema = torch.nn.DataParallel(net_ema, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu)))
160 | 
161 | 
162 |     if args.simclr_optim:
163 |         assert (not args.fix_optim)
164 |         #args.lr = 0.05 * float(args.batch_size) / 256
165 |         optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0, nesterov=True)
166 |     elif args.fix_optim:
167 |         args.lr = args.lr * args.mu / 7.
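        # Added note: with --fix_optim the base lr is rescaled linearly in mu
        # (the unlabeled/labeled batch-size ratio); mu=7 is the default here,
        # so the effective lr per unlabeled sample stays constant as mu varies.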
168 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay, nesterov=True) 169 | else: 170 | optimizer = optim.Adam(net.parameters(), lr=args.lr) # weight decay in ema_optimizer 171 | 172 | ema_optimizer= WeightEMA(net, net_ema, alpha=0.999, wd=(not args.fix_optim and not args.simclr_optim)) 173 | 174 | net.train() 175 | net_ema.train() 176 | 177 | if len(_labeled_trainset) < args.batch_size: 178 | rand_sampler = torch.utils.data.RandomSampler(_labeled_trainset, num_samples=args.batch_size, replacement=True) 179 | _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, sampler=rand_sampler, num_workers=0) 180 | else: 181 | _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True) 182 | _labeled_testloader = torch.utils.data.DataLoader(_labeled_testset, batch_size=args.batch_size, shuffle=False, num_workers=4) 183 | _labeled_train_iter = iter(cycle(_labeled_trainloader)) 184 | _unlabeled_trainloader = torch.utils.data.DataLoader(_unlabeled_trainset, batch_size=args.batch_size * args.mu, shuffle=True, num_workers=args.nworkers, drop_last=True) 185 | _unlabeled_train_iter = iter(cycle(_unlabeled_trainloader)) 186 | train_loss = 0 187 | correct = 0 188 | total = 0 189 | 190 | run_iters = args.num_iters if args.stop_iters is None else args.stop_iters 191 | for batch_idx in range(start_iters, run_iters + 1): 192 | (inputs, inputs_aug), targets = next(_labeled_train_iter) 193 | (inputs_o, inputs_o2), targets_u = next(_unlabeled_train_iter) 194 | 195 | if use_cuda: 196 | inputs = inputs.cuda() 197 | targets = targets.cuda() 198 | inputs_o = inputs_o.cuda() 199 | inputs_o2 = inputs_o2.cuda() 200 | 201 | inputs_total = torch.cat([inputs, inputs_o, inputs_o2], dim=0) 202 | inputs_total = list(torch.split(inputs_total, args.batch_size)) 203 | inputs_total = interleave(inputs_total, args.batch_size) 204 | 205 | logits = [net(inputs_total[0])] 206 | for input in inputs_total[1:]: 207 | logits.append(net(input)) 208 | 209 | # put interleaved samples back 210 | logits = interleave(logits, args.batch_size) 211 | 212 | outputs = logits[0] 213 | u_chunk = (len(logits)-1)//2 214 | outputs_o = torch.cat(logits[1:u_chunk+1], dim=0) 215 | outputs_o2 = torch.cat(logits[u_chunk+1:], dim=0) 216 | 217 | with torch.no_grad(): 218 | outputs_o.detach_() 219 | pred_o = torch.softmax(outputs_o, dim=1) 220 | pred_max = pred_o.max(1) 221 | mask = (pred_max[0] >= args.ths_pred) 222 | targets_pseudo = pred_max[1] 223 | 224 | loss = criterion(outputs, targets) + args.lmd_u * torch.mean(criterion_none(outputs_o2, targets_pseudo.detach()) * mask) 225 | 226 | train_loss += loss.item() 227 | 228 | optimizer.zero_grad() 229 | loss.backward() 230 | optimizer.step() 231 | ema_optimizer.step() 232 | if args.fix_optim: 233 | adjust_learning_rate(optimizer, batch_idx+1) 234 | 235 | if batch_idx % 1000 == 0: 236 | if batch_idx // 1000 > (run_iters // 1000) - 5: 237 | median = True 238 | else: 239 | median = False 240 | logger = logging.getLogger('train') 241 | logger.info('[Iters {}] [Loss {:.3f}]'.format( 242 | batch_idx, 243 | train_loss/1000)) 244 | print('[Iters {}] [Loss {:.3f}]'.format( 245 | batch_idx, 246 | train_loss/1000)) 247 | if tensorboardX_compat: 248 | writer.add_scalar("training/loss", train_loss/1000, batch_idx+1) 249 | 250 | train_loss = 0 251 | ema_optimizer.step(bn=True) 252 | save = val(net, batch_idx, _labeled_testloader, median=median) 253 
| if save: 254 | checkpoint(net, optimizer, best_val, batch_idx) 255 | save = val(net_ema, batch_idx, _labeled_testloader, ema=True, median=median) 256 | if save: 257 | checkpoint(net_ema, optimizer, best_val_ema, batch_idx, ema=True) 258 | net.train() 259 | net_ema.train() 260 | else: 261 | progress_bar(batch_idx % 1000, 1000, 'working...') 262 | 263 | checkpoint(net, optimizer, current_val, args.num_iters, last=True) 264 | checkpoint(net_ema, optimizer, current_val_ema, args.num_iters, ema=True, last=True) 265 | 266 | def interleave_offsets(batch, nu): 267 | groups = [batch // (nu + 1)] * (nu + 1) 268 | for x in range(batch - sum(groups)): 269 | groups[-x - 1] += 1 270 | offsets = [0] 271 | for g in groups: 272 | offsets.append(offsets[-1] + g) 273 | assert offsets[-1] == batch 274 | return offsets 275 | 276 | def interleave(xy, batch): 277 | nu = len(xy) - 1 278 | offsets = interleave_offsets(batch, nu) 279 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] 280 | for i in range(1, nu + 1): 281 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 282 | return [torch.cat(v, dim=0) for v in xy] 283 | 284 | median_acc = [] 285 | median_acc_ema = [] 286 | 287 | def val(net, iters, testloader, ema=False, median=False): 288 | global best_val 289 | global best_val_ema 290 | global median_acc 291 | global median_acc_ema 292 | global current_val 293 | global current_val_ema 294 | net.eval() 295 | val_loss = 0.0 296 | correct = 0.0 297 | total = 0.0 298 | 299 | with torch.no_grad(): 300 | for batch_idx, (inputs, targets) in enumerate(testloader): 301 | if use_cuda: 302 | inputs, targets = inputs.cuda(), targets.cuda() 303 | 304 | outputs = net(inputs) 305 | loss = torch.mean(criterion(outputs, targets)) 306 | val_loss += loss.item() 307 | _, predicted = torch.max(outputs, 1) 308 | total += targets.size(0) 309 | correct += predicted.eq(targets.data).cpu().sum().float() 310 | progress_bar(batch_idx, len(testloader), 311 | 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 312 | % (val_loss/(batch_idx+1), 100.*correct/total, correct, total)) 313 | 314 | logger = logging.getLogger('test') 315 | logger.info('[Loss {:.3f}] [Acc {:.3f}]'.format( 316 | val_loss/(batch_idx+1), 100.*correct/total)) 317 | 318 | acc = 100.*correct/total 319 | 320 | if ema: 321 | if median: 322 | median_acc_ema.append(acc.item()) 323 | if tensorboardX_compat: 324 | writer.add_scalar("validation/ema_loss", val_loss/(batch_idx+1), iters+1) 325 | writer.add_scalar("validation/ema_top1_acc", acc, iters+1) 326 | current_val_ema = acc 327 | if acc > best_val_ema: 328 | best_val_ema = acc 329 | return True 330 | else: 331 | return False 332 | else: 333 | if median: 334 | median_acc.append(acc.item()) 335 | if tensorboardX_compat: 336 | writer.add_scalar("validation/loss", val_loss/(batch_idx+1), iters+1) 337 | writer.add_scalar("validation/top1_acc", acc, iters+1) 338 | current_val = acc 339 | if acc > best_val: 340 | best_val = acc 341 | return True 342 | else: 343 | return False 344 | 345 | def checkpoint(net, optimizer, acc, iters, ema=False, last=False): 346 | # Save checkpoint. 
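    # Added note: same checkpoint format as in train.py; with ema=True the EMA
    # ("teacher") weights are written to a separate ema_ckpt.t7 / last_ema_ckpt.t7
    # file so the student and EMA models can be restored independently.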
347 | print('Saving..') 348 | state = { 349 | 'net': net.state_dict(), 350 | 'optimizer': optimizer.state_dict(), 351 | 'acc': acc, 352 | 'iters': iters, 353 | 'rng_state': torch.get_rng_state() 354 | } 355 | if ema: 356 | torch.save(state, os.path.join(logdir, 'ema_ckpt.t7' if (not last) else 'last_ema_ckpt.t7')) 357 | else: 358 | torch.save(state, os.path.join(logdir, 'ckpt.t7' if (not last) else 'last_ckpt.t7')) 359 | 360 | def adjust_learning_rate(optimizer, iters): 361 | """decrease the learning rate""" 362 | lr = args.lr * np.cos(iters/(args.num_iters+1) * (7 * np.pi) / (2 * 8)) 363 | for param_group in optimizer.param_groups: 364 | param_group['lr'] = lr 365 | 366 | class WeightEMA(object): 367 | def __init__(self, model, ema_model, alpha=0.999, wd=False): 368 | self.model = model 369 | self.ema_model = ema_model 370 | self.alpha = alpha 371 | self.tmp_model = models.load_model(args.model, _labeled_num_class) 372 | self.wd = 0.02 * args.lr if wd else 0 373 | 374 | for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()): 375 | ema_param.data.copy_(param.data) 376 | 377 | def step(self, bn=False): 378 | if bn: 379 | # copy batchnorm stats to ema model 380 | for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()): 381 | tmp_param.data.copy_(ema_param.data.detach()) 382 | 383 | self.ema_model.load_state_dict(self.model.state_dict()) 384 | 385 | for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()): 386 | ema_param.data.copy_(tmp_param.data.detach()) 387 | else: 388 | one_minus_alpha = 1.0 - self.alpha 389 | for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()): 390 | ema_param.data.mul_(self.alpha) 391 | ema_param.data.add_(param.data.detach() * one_minus_alpha) 392 | # customized weight decay 393 | param.data.mul_(1 - self.wd) 394 | 395 | if args.ema: 396 | ema_train() 397 | 398 | print("Best Accuracy : {}".format(best_val)) 399 | print("Best Accuracy EMA : {}".format(best_val_ema)) 400 | print("Median Accuracy : {}".format(np.median(median_acc))) 401 | print("Median Accuracy EMA : {}".format(np.median(median_acc_ema))) 402 | logger = logging.getLogger('best') 403 | logger.info('[Acc {:.3f}] [EMA Acc {:.3f}] [MEDIAN Acc {:.3f}] [MEDIAN EMA Acc {:.3f}]'.format(best_val, best_val_ema, np.median(median_acc), np.median(median_acc_ema))) 404 | else: 405 | raise NotImplementedError 406 | 407 | if tensorboardX_compat: 408 | writer.close() -------------------------------------------------------------------------------- /OpenCoS/models/meta_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import random 7 | 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class MetaModule(nn.Module): 16 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 17 | def params(self): 18 | for name, param in self.named_params(self): 19 | yield param 20 | 21 | def named_leaves(self): 22 | return [] 23 | 24 | def named_submodules(self): 25 | return [] 26 | 27 | def named_params(self, curr_module=None, memo=None, prefix=''): 28 | if memo is None: 29 | memo = set() 30 | 31 | if hasattr(curr_module, 'named_leaves'): 32 | for name, p in curr_module.named_leaves(): 33 | if p is not None and p not in memo: 
34 |                     memo.add(p)
35 |                     yield prefix + ('.' if prefix else '') + name, p
36 |         else:
37 |             for name, p in curr_module._parameters.items():
38 |                 if p is not None and p not in memo:
39 |                     memo.add(p)
40 |                     yield prefix + ('.' if prefix else '') + name, p
41 | 
42 |         for mname, module in curr_module.named_children():
43 |             submodule_prefix = prefix + ('.' if prefix else '') + mname
44 |             for name, p in self.named_params(module, memo, submodule_prefix):
45 |                 yield name, p
46 | 
47 |     def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
48 |         if source_params is not None:
49 |             for tgt, src in zip(self.named_params(self), source_params):
50 |                 name_t, param_t = tgt
51 |                 # name_s, param_s = src
52 |                 # grad = param_s.grad
53 |                 # name_s, param_s = src
54 |                 grad = src
55 |                 if first_order:
56 |                     grad = to_var(grad.detach().data)
57 |                 tmp = param_t - lr_inner * grad
58 |                 self.set_param(self, name_t, tmp)
59 |         else:
60 | 
61 |             for name, param in self.named_params(self):
62 |                 if not detach:
63 |                     grad = param.grad
64 |                     if first_order:
65 |                         grad = to_var(grad.detach().data)
66 |                     tmp = param - lr_inner * grad
67 |                     self.set_param(self, name, tmp)
68 |                 else:
69 |                     param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686
70 |                     self.set_param(self, name, param)
71 | 
72 |     def set_param(self, curr_mod, name, param):
73 |         if '.' in name:
74 |             n = name.split('.')
75 |             module_name = n[0]
76 |             rest = '.'.join(n[1:])
77 |             for name, mod in curr_mod.named_children():
78 |                 if module_name == name:
79 |                     self.set_param(mod, rest, param)
80 |                     break
81 |         else:
82 |             setattr(curr_mod, name, param)
83 | 
84 |     def detach_params(self):
85 |         for name, param in self.named_params(self):
86 |             self.set_param(self, name, param.detach())
87 | 
88 |     def copy(self, other, same_var=False):
89 |         for name, param in other.named_params(other): # named_params expects the module to traverse
90 |             if not same_var:
91 |                 param = to_var(param.data.clone(), requires_grad=True)
92 |             self.set_param(self, name, param) # set_param takes the target module as its first argument
93 | 
94 | 
95 | class MetaLinear(MetaModule):
96 |     def __init__(self, *args, **kwargs):
97 |         super().__init__()
98 |         ignore = nn.Linear(*args, **kwargs)
99 | 
100 |         self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
101 |         self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
102 | 
103 |     def forward(self, x):
104 |         return F.linear(x, self.weight, self.bias)
105 | 
106 |     def named_leaves(self):
107 |         return [('weight', self.weight), ('bias', self.bias)]
108 | 
109 | 
110 | class MetaConv2d(MetaModule):
111 |     def __init__(self, *args, **kwargs):
112 |         super().__init__()
113 |         ignore = nn.Conv2d(*args, **kwargs)
114 | 
115 |         self.in_channels = ignore.in_channels
116 |         self.out_channels = ignore.out_channels
117 |         self.stride = ignore.stride
118 |         self.padding = ignore.padding
119 |         self.dilation = ignore.dilation
120 |         self.groups = ignore.groups
121 |         self.kernel_size = ignore.kernel_size
122 | 
123 |         self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
124 | 
125 |         if ignore.bias is not None:
126 |             self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
127 |         else:
128 |             self.register_buffer('bias', None)
129 | 
130 |     def forward(self, x):
131 |         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
132 | 
133 |     def named_leaves(self):
134 |         return [('weight', self.weight), ('bias', self.bias)]
135 | 
136 | 
137 | class MetaConvTranspose2d(MetaModule):
138 |     def __init__(self, *args, **kwargs):
139 |         super().__init__()
140 |         ignore = 
nn.ConvTranspose2d(*args, **kwargs)
141 | 
142 |         self.stride = ignore.stride
143 |         self.padding = ignore.padding
144 |         self.dilation = ignore.dilation
145 |         self.groups = ignore.groups
146 |         self.output_padding = ignore.output_padding
147 |         self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
148 | 
149 |         if ignore.bias is not None:
150 |             self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
151 |         else:
152 |             self.register_buffer('bias', None)
153 | 
154 |     def forward(self, x, output_size=None):
155 |         output_padding = self.output_padding # this class has no nn.ConvTranspose2d base, so dynamic output_size-based padding is unsupported
156 |         return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding,
157 |                                   output_padding, self.groups, self.dilation)
158 | 
159 |     def named_leaves(self):
160 |         return [('weight', self.weight), ('bias', self.bias)]
161 | 
162 | 
163 | class MetaBatchNorm2d(MetaModule):
164 |     def __init__(self, *args, **kwargs):
165 |         super().__init__()
166 |         ignore = nn.BatchNorm2d(*args, **kwargs)
167 | 
168 |         self.num_features = ignore.num_features
169 |         self.eps = ignore.eps
170 |         self.momentum = ignore.momentum
171 |         self.affine = ignore.affine
172 |         self.track_running_stats = ignore.track_running_stats
173 | 
174 |         self.update_batch_stats = True
175 | 
176 |         if self.affine:
177 |             self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
178 |             self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
179 | 
180 |         if self.track_running_stats:
181 |             self.register_buffer('running_mean', torch.zeros(self.num_features))
182 |             self.register_buffer('running_var', torch.ones(self.num_features))
183 |         else:
184 |             self.register_parameter('running_mean', None)
185 |             self.register_parameter('running_var', None)
186 | 
187 |     def forward(self, x):
188 |         #if self.update_batch_stats:
189 |         #    return super().forward(x)
190 |         #else:
191 |         return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
192 |                             self.training or not self.track_running_stats, self.momentum, self.eps)
193 | 
194 |     def named_leaves(self):
195 |         return [('weight', self.weight), ('bias', self.bias)]
196 | 
197 | 
198 | 
199 | class CIFAR_Bottleneck(MetaModule):
200 |     expansion = 4
201 | 
202 |     def __init__(self, in_planes, planes, stride=1):
203 |         super(CIFAR_Bottleneck, self).__init__()
204 |         self.conv1 = MetaConv2d(in_planes, planes, kernel_size=1, bias=False)
205 |         self.bn1 = MetaBatchNorm2d(planes)
206 |         self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
207 |         self.bn2 = MetaBatchNorm2d(planes)
208 |         self.conv3 = MetaConv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
209 |         self.bn3 = MetaBatchNorm2d(self.expansion*planes)
210 | 
211 |         self.shortcut = nn.Sequential()
212 |         if stride != 1 or in_planes != self.expansion*planes:
213 |             self.shortcut = nn.Sequential(
214 |                 MetaConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
215 |                 MetaBatchNorm2d(self.expansion*planes)
216 |             )
217 | 
218 |     def forward(self, x):
219 |         out = F.relu(self.bn1(self.conv1(x)))
220 |         out = F.relu(self.bn2(self.conv2(out)))
221 |         out = self.bn3(self.conv3(out))
222 |         out += self.shortcut(x)
223 |         out = F.relu(out)
224 |         return out
225 | 
226 | 
227 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
228 |     """3x3 convolution with padding"""
229 |     return MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
230 |                       padding=1, groups=groups, bias=False)
231 | 
232 | 
233 | class Meta_CIFARResNet(MetaModule):
234 |     def __init__(self, block, num_blocks, num_classes=10,
bias=True): 235 | super(Meta_CIFARResNet, self).__init__() 236 | self.in_planes = 64 237 | self.conv1 = conv3x3(3,64) 238 | self.bn1 = MetaBatchNorm2d(64) 239 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 240 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 241 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 242 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 243 | self.linear = MetaLinear(512*block.expansion, num_classes, bias=bias) 244 | # self.linear_rot = MetaLinear(512*block.expansion, 4, bias=bias) 245 | 246 | 247 | def _make_layer(self, block, planes, num_blocks, stride): 248 | strides = [stride] + [1]*(num_blocks-1) 249 | layers = [] 250 | for stride in strides: 251 | layers.append(block(self.in_planes, planes, stride)) 252 | self.in_planes = planes * block.expansion 253 | return nn.Sequential(*layers) 254 | 255 | def forward(self, x, feature=False, aux=False): 256 | out = x 257 | out = self.conv1(out) 258 | out = self.bn1(out) 259 | out = F.relu(out) 260 | out1 = self.layer1(out) 261 | out2 = self.layer2(out1) 262 | out3 = self.layer3(out2) 263 | out = self.layer4(out3) 264 | out = F.avg_pool2d(out, 4) 265 | out4 = out.view(out.size(0), -1) 266 | out = self.linear(out4) 267 | if feature: 268 | return out, out4 269 | else: 270 | return out 271 | 272 | class WNet(MetaModule): 273 | def __init__(self, input, hidden, output): 274 | super(WNet, self).__init__() 275 | self.linear1 = MetaLinear(input, hidden) 276 | self.relu = nn.ReLU(inplace=True) 277 | self.linear2 = MetaLinear(hidden, output) 278 | 279 | def forward(self, x): 280 | x = self.linear1(x) 281 | x = self.relu(x) 282 | out = self.linear2(x) 283 | return torch.sigmoid(out) 284 | 285 | def Meta_ResNet50(pretrained=False, **kwargs): 286 | return Meta_CIFARResNet(CIFAR_Bottleneck, [3,4,6,3], **kwargs) 287 | 288 | 289 | class WideBasicBlock(MetaModule): 290 | expansion = 1 291 | 292 | def __init__(self, in_planes, planes, stride=1): 293 | super(WideBasicBlock, self).__init__() 294 | self.bn1 = MetaBatchNorm2d(in_planes) 295 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 296 | self.bn2 = MetaBatchNorm2d(planes) 297 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 298 | 299 | self.shortcut = nn.Sequential() 300 | if stride != 1 or in_planes != self.expansion*planes: 301 | self.shortcut = nn.Sequential( 302 | MetaConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 303 | ) 304 | 305 | def forward(self, x): 306 | o1 = F.leaky_relu(self.bn1(x), 0.1) 307 | y = self.conv1(o1) 308 | o2 = F.leaky_relu(self.bn2(y), 0.1) 309 | z = self.conv2(o2) 310 | if len(self.shortcut)==0: 311 | return z + x 312 | else: 313 | return z + self.shortcut(o1) 314 | 315 | 316 | class MetaWideResNet(MetaModule): 317 | """ WRN28-width with leaky relu (negative slope is 0.1)""" 318 | def __init__(self, block, depth, width, num_classes): 319 | super(MetaWideResNet , self).__init__() 320 | self.in_planes = 16 321 | 322 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 323 | n = (depth - 4) // 6 324 | widths = [int(v * width) for v in (16, 32, 64)] 325 | 326 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 327 | self.layer1 = self._make_layer(block, widths[0], n, stride=1) 328 | self.layer2 = self._make_layer(block, widths[1], n, stride=2) 329 | self.layer3 = self._make_layer(block, widths[2], n, stride=2) 330 | self.bn1 = 
MetaBatchNorm2d(widths[2]) 331 | self.linear = MetaLinear(widths[2]*block.expansion, num_classes) 332 | # self.linear_rot = MetaLinear(widths[2]*block.expansion, 4) 333 | 334 | for m in self.modules(): 335 | if isinstance(m, nn.Linear): 336 | nn.init.kaiming_normal_(m.weight) 337 | nn.init.constant_(m.bias, 0) 338 | elif isinstance(m, nn.Conv2d): 339 | nn.init.kaiming_normal_(m.weight) 340 | elif isinstance(m, nn.BatchNorm2d): 341 | nn.init.uniform_(m.weight) 342 | nn.init.constant_(m.bias, 0) 343 | nn.init.constant_(m.running_mean, 0) 344 | nn.init.constant_(m.running_var, 1) 345 | 346 | def _make_layer(self, block, planes, num_blocks, stride): 347 | strides = [stride] + [1]*(num_blocks-1) 348 | layers = [] 349 | for stride in strides: 350 | layers.append(block(self.in_planes, planes, stride)) 351 | self.in_planes = planes * block.expansion 352 | return nn.Sequential(*layers) 353 | 354 | def forward(self, x, feature=False, blocks=False, aux=False): 355 | f0 = self.conv1(x) 356 | f1 = self.layer1(f0) 357 | f2 = self.layer2(f1) 358 | f3 = self.layer3(f2) 359 | out = F.leaky_relu(self.bn1(f3), 0.1) 360 | out = F.avg_pool2d(out, 8) 361 | out4 = out.view(out.size(0), -1) 362 | out = self.linear(out4) 363 | if blocks: 364 | return out, [f1,f2,f3] 365 | elif feature: 366 | return out, out4 367 | else: 368 | return out 369 | 370 | def meta_wide_resnet(depth, width, num_classes=10): 371 | return MetaWideResNet(WideBasicBlock, depth, width, num_classes) 372 | 373 | def conv1x1(in_planes, out_planes, stride=1): 374 | """1x1 convolution""" 375 | return MetaConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 376 | 377 | class Bottleneck(MetaModule): 378 | expansion = 4 379 | 380 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 381 | base_width=64, norm_layer=None): 382 | super(Bottleneck, self).__init__() 383 | if norm_layer is None: 384 | norm_layer = MetaBatchNorm2d 385 | width = int(planes * (base_width / 64.)) * groups 386 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 387 | self.conv1 = conv1x1(inplanes, width) 388 | self.bn1 = norm_layer(width) 389 | self.conv2 = conv3x3(width, width, stride, groups) 390 | self.bn2 = norm_layer(width) 391 | self.conv3 = conv1x1(width, planes * self.expansion) 392 | self.bn3 = norm_layer(planes * self.expansion) 393 | self.relu = nn.ReLU(inplace=True) 394 | self.downsample = downsample 395 | self.stride = stride 396 | 397 | def forward(self, x): 398 | identity = x 399 | 400 | out = self.conv1(x) 401 | out = self.bn1(out) 402 | out = self.relu(out) 403 | 404 | out = self.conv2(out) 405 | out = self.bn2(out) 406 | out = self.relu(out) 407 | 408 | out = self.conv3(out) 409 | out = self.bn3(out) 410 | 411 | if self.downsample is not None: 412 | identity = self.downsample(x) 413 | 414 | out += identity 415 | out = self.relu(out) 416 | 417 | return out 418 | 419 | 420 | class Meta_ResNet(MetaModule): 421 | 422 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 423 | groups=1, width_per_group=64, norm_layer=None, bias=True): 424 | super(Meta_ResNet, self).__init__() 425 | if norm_layer is None: 426 | norm_layer = MetaBatchNorm2d 427 | 428 | self.inplanes = 64 429 | self.groups = groups 430 | self.base_width = width_per_group 431 | self.conv1 = MetaConv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 432 | bias=False) 433 | self.bn1 = norm_layer(self.inplanes) 434 | self.relu = nn.ReLU(inplace=True) 435 | self.maxpool = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1) 436 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 437 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 438 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 439 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 440 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 441 | self.fc = MetaLinear(512 * block.expansion, num_classes, bias=bias) 442 | # self.linear_rot = nn.Linear(512*block.expansion, 4, bias=bias) 443 | for m in self.modules(): 444 | if isinstance(m, MetaConv2d): 445 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 446 | elif isinstance(m, (MetaBatchNorm2d, nn.GroupNorm)): 447 | nn.init.constant_(m.weight, 1) 448 | nn.init.constant_(m.bias, 0) 449 | 450 | # Zero-initialize the last BN in each residual branch, 451 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 452 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 453 | if zero_init_residual: 454 | for m in self.modules(): 455 | if isinstance(m, Bottleneck): 456 | nn.init.constant_(m.bn3.weight, 0) 457 | elif isinstance(m, BasicBlock): 458 | nn.init.constant_(m.bn2.weight, 0) 459 | 460 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 461 | if norm_layer is None: 462 | norm_layer = MetaBatchNorm2d 463 | downsample = None 464 | if stride != 1 or self.inplanes != planes * block.expansion: 465 | downsample = nn.Sequential( 466 | conv1x1(self.inplanes, planes * block.expansion, stride), 467 | norm_layer(planes * block.expansion), 468 | ) 469 | 470 | layers = [] 471 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 472 | self.base_width, norm_layer)) 473 | self.inplanes = planes * block.expansion 474 | for _ in range(1, blocks): 475 | layers.append(block(self.inplanes, planes, groups=self.groups, 476 | base_width=self.base_width, norm_layer=norm_layer)) 477 | 478 | return nn.Sequential(*layers) 479 | 480 | def forward(self, x, feature=False, aux=False): 481 | x = self.conv1(x) 482 | x = self.bn1(x) 483 | x = self.relu(x) 484 | x = self.maxpool(x) 485 | 486 | x = self.layer1(x) 487 | x = self.layer2(x) 488 | x = self.layer3(x) 489 | x = self.layer4(x) 490 | 491 | x = self.avgpool(x) 492 | x = x.view(x.size(0), -1) 493 | o = self.fc(x) 494 | 495 | if feature: 496 | return o, x 497 | else: 498 | return o 499 | 500 | def meta_resnet50(pretrained=False, **kwargs): 501 | model = Meta_ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 502 | return model 503 | --------------------------------------------------------------------------------
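Two usage notes on the trickier pieces above, each with a minimal sketch. Neither snippet is part of the repository; both assume the relevant definitions are copied into scope, and the toy tensors are made up for illustration.

First, interleave() in OpenCoS/train_fixmatch.py is an involution: applying it twice restores the original chunk order. The training loop relies on this, calling it once so that each separate forward pass sees a mix of labeled and unlabeled rows (keeping BatchNorm statistics consistent across chunks) and once more to "put interleaved samples back".

import torch

# assumes interleave() and interleave_offsets() from train_fixmatch.py are in scope
batch = 64
chunks = [torch.full((batch, 3), float(i)) for i in range(3)]  # 1 labeled + 2 unlabeled chunks
mixed = interleave(chunks, batch)       # rows are swapped across chunks
restored = interleave(mixed, batch)     # a second application undoes the swap
assert all(torch.equal(a, b) for a, b in zip(chunks, restored))

Second, the MetaModule hierarchy in OpenCoS/models/meta_resnet.py registers weights as requires-grad buffers rather than nn.Parameters so that an inner SGD step can remain part of the autograd graph (the meta-weight-net / L2RW pattern). A minimal sketch of that inner step, using the WNet class defined above with made-up toy data:

import torch
import torch.nn.functional as F

net = WNet(input=10, hidden=32, output=1)    # MetaModule-based weight network
x, y = torch.randn(4, 10), torch.rand(4, 1)
if torch.cuda.is_available():                # to_var() places buffers on GPU when available
    x, y = x.cuda(), y.cuda()

loss = F.binary_cross_entropy(net(x), y)
grads = torch.autograd.grad(loss, list(net.params()), create_graph=True)
net.update_params(lr_inner=0.1, source_params=grads)
# net's weights are now a differentiable function of their pre-update values,
# so a subsequent (meta) loss can backpropagate through this inner step.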