├── models ├── __init__.py └── resnet.py ├── utils ├── __init__.py ├── ramps.py └── util.py ├── data ├── __init__.py ├── concat.py ├── omniglot.py ├── utils.py ├── imagenetloader.py ├── rotationloader.py ├── svhnloader.py ├── omniglotloader.py └── cifarloader.py ├── asset └── splash.png ├── .gitignore ├── scripts ├── download_pretrained_models.sh ├── download_imagenet_splits.sh ├── auto_novel_svhn.sh ├── auto_novel_IL_svhn.sh ├── auto_novel_cifar10.sh ├── auto_novel_IL_cifar10.sh ├── auto_novel_cifar100.sh └── auto_novel_IL_cifar100.sh ├── environment.yml ├── supervised_learning.py ├── selfsupervised_learning.py ├── README.md ├── auto_novel_omniglot.py ├── auto_novel_imagenet.py └── auto_novel.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /asset/splash.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/k-han/AutoNovel/HEAD/asset/splash.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | datasets 3 | experiments 4 | 5 | *.pyc 6 | *.txt 7 | *.pth 8 | *.mat 9 | *.out 10 | *.rar 11 | *.tar 12 | .DS_Store 13 | -------------------------------------------------------------------------------- /scripts/download_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | path="data/experiments/" 4 | mkdir -p $path 5 | cd $path 6 | 7 | wget http://www.robots.ox.ac.uk/~vgg/research/auto_novel/asset/pretrained.zip 8 | 9 | unzip pretrained.zip && rm pretrained.zip 10 | 11 | cd ../ -------------------------------------------------------------------------------- /scripts/download_imagenet_splits.sh: -------------------------------------------------------------------------------- 1 | path="data/datasets/ImageNet/imagenet_rand118/" 2 | mkdir -p $path 3 | cd $path 4 | 5 | for file in "imagenet_118.txt" "imagenet_30_A.txt" "imagenet_30_B.txt" "imagenet_30_C.txt"; do 6 | wget http://www.robots.ox.ac.uk/~vgg/research/DTC/data/datasets/ImageNet/imagenet_rand118/${file} 7 | done 8 | 9 | cd ../ 10 | -------------------------------------------------------------------------------- /scripts/auto_novel_svhn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python auto_novel.py \ 4 | --dataset_root $1 \ 5 | --exp_root $2 \ 6 | --warmup_model_dir $3 \ 7 | --lr 0.1 \ 8 | --step_size 170 \ 9 | --batch_size 128 \ 10 | --epochs 200 \ 11 | --rampup_length 80 \ 12 | --rampup_coefficient 50 \ 13 | --seed 0 \ 14 | --dataset_name svhn \ 15 | --model_name resnet_svhn \ 16 | --mode train -------------------------------------------------------------------------------- /scripts/auto_novel_IL_svhn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python auto_novel.py \ 4 | --dataset_root $1 \ 5 | --exp_root $2 \ 6 | --warmup_model_dir $3 \ 7 | --lr 0.1 \ 8 | --step_size 170 \ 9 | --batch_size 128 \ 10 | --epochs 200 \ 11 | --rampup_length 80 \ 12 | --rampup_coefficient 50 \ 13 | --IL \ 14 | --increment_coefficient 0.05 \ 15 | --seed 0 \ 16 | --dataset_name svhn \ 17 | --model_name resnet_IL_svhn \ 18 | --mode train -------------------------------------------------------------------------------- /scripts/auto_novel_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python auto_novel.py \ 4 | --dataset_root $1 \ 5 | --exp_root $2 \ 6 | --warmup_model_dir $3 \ 7 | --lr 0.1 \ 8 | --gamma 0.1 \ 9 | --weight_decay 1e-4 \ 10 | --step_size 170 \ 11 | --batch_size 128 \ 12 | --epochs 200 \ 13 | --rampup_length 50 \ 14 | --rampup_coefficient 5.0 \ 15 | --dataset_name cifar10 \ 16 | --seed 0 \ 17 | --model_name resnet_cifar10 \ 18 | --mode train -------------------------------------------------------------------------------- /scripts/auto_novel_IL_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python auto_novel.py \ 4 | --dataset_root $1 \ 5 | --exp_root $2 \ 6 | --warmup_model_dir $3 \ 7 | --lr 0.1 \ 8 | --gamma 0.1 \ 9 | --weight_decay 1e-4 \ 10 | --step_size 170 \ 11 | --batch_size 128 \ 12 | --epochs 200 \ 13 | --rampup_length 50 \ 14 | --rampup_coefficient 5.0 \ 15 | --dataset_name cifar10 \ 16 | --IL \ 17 | --increment_coefficient 0.05 \ 18 | --seed 0 \ 19 | --model_name resnet_IL_cifar10 \ 20 | --mode train -------------------------------------------------------------------------------- /scripts/auto_novel_cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python auto_novel.py \ 4 | --dataset_root $1 \ 5 | --exp_root $2 \ 6 | --warmup_model_dir $3 \ 7 | --lr 0.1 \ 8 | --gamma 0.1 \ 9 | --weight_decay 1e-4 \ 10 | --step_size 170 \ 11 | --batch_size 128 \ 12 | --epochs 200 \ 13 | --rampup_length 150 \ 14 | --rampup_coefficient 50 \ 15 | --num_labeled_classes 80 \ 16 | --num_unlabeled_classes 20 \ 17 | --dataset_name cifar100 \ 18 | --seed 0 \ 19 | --model_name resnet_cifar100 \ 20 | --mode train -------------------------------------------------------------------------------- /scripts/auto_novel_IL_cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python auto_novel.py \ 4 | --dataset_root $1 \ 5 | --exp_root $2 \ 6 | --warmup_model_dir $3 \ 7 | --lr 0.1 \ 8 | --gamma 0.1 \ 9 | --weight_decay 1e-4 \ 10 | --step_size 340 \ 11 | --batch_size 256 \ 12 | --epochs 400 \ 13 | --rampup_length 300 \ 14 | --rampup_coefficient 25 \ 15 | --num_labeled_classes 80 \ 16 | --num_unlabeled_classes 20 \ 17 | --dataset_name cifar100 \ 18 | --IL \ 19 | --increment_coefficient 0.05 \ 20 | --seed 0 \ 21 | --model_name resnet_IL_cifar100 \ 22 | --mode train -------------------------------------------------------------------------------- /utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /data/concat.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | class ConcatDataset(Dataset): 26 | """ 27 | Dataset to concatenate multiple datasets. 28 | Purpose: useful to assemble different existing datasets, possibly 29 | large-scale datasets as the concatenation operation is done in an 30 | on-the-fly manner. 31 | 32 | Arguments: 33 | datasets (sequence): List of datasets to be concatenated 34 | """ 35 | 36 | @staticmethod 37 | def cumsum(sequence): 38 | r, s = [], 0 39 | for e in sequence: 40 | l = len(e) 41 | r.append(l + s) 42 | s += l 43 | return r 44 | 45 | def __init__(self, datasets): 46 | super(ConcatDataset, self).__init__() 47 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 48 | self.datasets = list(datasets) 49 | self.cumulative_sizes = self.cumsum(self.datasets) 50 | 51 | def __len__(self): 52 | return self.cumulative_sizes[-1] 53 | 54 | def __getitem__(self, idx): 55 | if idx < 0: 56 | if -idx > len(self): 57 | raise ValueError("absolute value of index should not exceed dataset length") 58 | idx = len(self) + idx 59 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 60 | if dataset_idx == 0: 61 | sample_idx = idx 62 | else: 63 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 64 | return (*self.datasets[dataset_idx][sample_idx][:-1], idx) 65 | 66 | @property 67 | def cummulative_sizes(self): 68 | warnings.warn("cummulative_sizes attribute is renamed to " 69 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 70 | return self.cumulative_sizes 71 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: auto_novel 2 | channels: 3 | - usgs-astrogeology 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - blas=1.0=openblas 10 | - ca-certificates=2019.11.27=0 11 | - certifi=2019.11.28=py36_0 12 | - cffi=1.12.3=py36h2e261b9_0 13 | - cuda90=1.0=h6433d27_0 14 | - cudatoolkit=10.0.130=0 15 | - cycler=0.10.0=py_1 16 | - dbus=1.13.6=he372182_0 17 | - expat=2.2.5=he1b5a44_1003 18 | - fontconfig=2.13.1=he4413a7_1000 19 | - freetype=2.10.0=he983fc9_1 20 | - gettext=0.19.8.1=hc5be6a0_1002 21 | - glib=2.58.3=h6f030ca_1002 22 | - gst-plugins-base=1.14.5=h0935bb2_0 23 | - gstreamer=1.14.5=h36ae1b5_0 24 | - icu=58.2=hf484d3e_1000 25 | - intel-openmp=2019.4=243 26 | - jpeg=9c=h14c3975_1001 27 | - kiwisolver=1.1.0=py36hc9558a2_0 28 | - libedit=3.1.20181209=hc058e9b_0 29 | - libffi=3.2.1=hd88cf55_4 30 | - libgcc-ng=9.1.0=hdf63c60_0 31 | - libgfortran-ng=7.3.0=hdf63c60_0 32 | - libiconv=1.15=h516909a_1005 33 | - libopenblas=0.2.20=h9ac9557_7 34 | - libpng=1.6.37=hed695b0_0 35 | - libstdcxx-ng=9.1.0=hdf63c60_0 36 | - libtiff=4.0.10=h2733197_2 37 | - libuuid=2.32.1=h14c3975_1000 38 | - libxcb=1.13=h14c3975_1002 39 | - libxml2=2.9.9=h13577e0_2 40 | - matplotlib=3.1.1=py36h5429711_0 41 | - mkl=2019.4=243 42 | - mkl-service=2.0.2=py36h7b6447c_0 43 | - ncurses=6.1=he6710b0_1 44 | - ninja=1.9.0=py36hfd86e86_0 45 | - numpy=1.14.2=py36_nomklh2b20989_1 46 | - olefile=0.46=py36_0 47 | - openssl=1.1.1d=h7b6447c_3 48 | - pandas=0.25.1=py36he6710b0_0 49 | - patsy=0.5.1=py36_0 50 | - pcre=8.41=hf484d3e_1003 51 | - pillow=6.1.0=py36h34e0f95_0 52 | - pip=19.2.2=py36_0 53 | - pthread-stubs=0.4=h14c3975_1001 54 | - pycparser=2.19=py36_0 55 | - pyparsing=2.4.2=py_0 56 | - pyqt=5.9.2=py36hcca6a23_4 57 | - python=3.6.9=h265db76_0 58 | - python-dateutil=2.8.0=py_0 59 | - pytorch=1.1.0=py3.6_cuda10.0.130_cudnn7.5.1_0 60 | - pytz=2019.2=py_0 61 | - qt=5.9.7=h52cfd70_2 62 | - readline=7.0=h7b6447c_5 63 | - scikit-learn=0.19.1=py36_nomklh6cfcb94_0 64 | - scipy=1.1.0=py36_nomklh9c1e066_0 65 | - seaborn=0.9.0=py36_0 66 | - setuptools=41.0.1=py36_0 67 | - sip=4.19.8=py36hf484d3e_1000 68 | - six=1.12.0=py36_0 69 | - sqlite=3.29.0=h7b6447c_0 70 | - statsmodels=0.10.1=py36hdd07704_0 71 | - tk=8.6.8=hbc83047_0 72 | - tnt=126=0 73 | - torchvision=0.3.0=py36_cu10.0.130_1 74 | - tornado=6.0.3=py36h516909a_0 75 | - tqdm=4.32.1=py_0 76 | - wheel=0.33.4=py36_0 77 | - xorg-libxau=1.0.9=h14c3975_0 78 | - xorg-libxdmcp=1.1.3=h516909a_0 79 | - xz=5.2.4=h14c3975_4 80 | - zlib=1.2.11=h7b6447c_3 81 | - zstd=1.3.7=h0b5b093_0 82 | - pip: 83 | - chardet==3.0.4 84 | - idna==2.8 85 | - jsonpatch==1.24 86 | - jsonpointer==2.0 87 | - pyzmq==18.1.1 88 | - requests==2.22.0 89 | - torchfile==0.1.0 90 | - torchnet==0.0.4 91 | - urllib3==1.25.7 92 | - visdom==0.1.8.9 93 | - websocket-client==0.57.0 -------------------------------------------------------------------------------- /data/omniglot.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | from os.path import join 4 | import os 5 | import torch.utils.data as data 6 | from .utils import download_url, check_integrity, list_dir, list_files 7 | 8 | class Omniglot(data.Dataset): 9 | """`Omniglot `_ Dataset. 10 | Args: 11 | root (string): Root directory of dataset where directory 12 | ``omniglot-py`` exists. 13 | background (bool, optional): If True, creates dataset from the "background" set, otherwise 14 | creates from the "evaluation" set. This terminology is defined by the authors. 15 | transform (callable, optional): A function/transform that takes in an PIL image 16 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 17 | target_transform (callable, optional): A function/transform that takes in the 18 | target and transforms it. 19 | download (bool, optional): If true, downloads the dataset zip files from the internet and 20 | puts it in root directory. If the zip files are already downloaded, they are not 21 | downloaded again. 22 | """ 23 | folder = 'omniglot-py' 24 | 25 | def __init__(self, root, subfolder_name='images_background', 26 | transform=None, target_transform=None): 27 | self.root = join(os.path.expanduser(root), self.folder) 28 | self.subfolder_name = subfolder_name 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | 32 | self.target_folder = join(self.root, self.subfolder_name) 33 | self._alphabets = list_dir(self.target_folder) 34 | self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] 35 | for a in self._alphabets], []) 36 | self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] 37 | for idx, character in enumerate(self._characters)] 38 | self._flat_character_images = sum(self._character_images, []) 39 | 40 | def __len__(self): 41 | return len(self._flat_character_images) 42 | 43 | def __getitem__(self, index): 44 | """ 45 | Args: 46 | index (int): Index 47 | 48 | Returns: 49 | tuple: (image, target) where target is index of the target character class. 50 | """ 51 | image_name, character_class = self._flat_character_images[index] 52 | image_path = join(self.target_folder, self._characters[character_class], image_name) 53 | image = Image.open(image_path, mode='r').convert('L') 54 | 55 | if self.transform: 56 | image = self.transform(image) 57 | 58 | if self.target_transform: 59 | character_class = self.target_transform(character_class) 60 | 61 | return image, character_class, index -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from torchvision import transforms 9 | import pickle 10 | import os.path 11 | import datetime 12 | import numpy as np 13 | 14 | class ResNet(nn.Module): 15 | def __init__(self, block, num_blocks, num_labeled_classes=5, num_unlabeled_classes=5): 16 | super(ResNet, self).__init__() 17 | self.in_planes = 64 18 | 19 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(64) 21 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 22 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 23 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 24 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 25 | self.head1 = nn.Linear(512*block.expansion, num_labeled_classes) 26 | self.head2 = nn.Linear(512*block.expansion, num_unlabeled_classes) 27 | 28 | def _make_layer(self, block, planes, num_blocks, stride): 29 | strides = [stride] + [1]*(num_blocks-1) 30 | layers = [] 31 | for stride in strides: 32 | layers.append(block(self.in_planes, planes, stride)) 33 | self.in_planes = planes * block.expansion 34 | return nn.Sequential(*layers) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.layer1(out) 39 | out = self.layer2(out) 40 | out = self.layer3(out) 41 | out = self.layer4(out) 42 | out = F.avg_pool2d(out, 4) 43 | out = out.view(out.size(0), -1) 44 | out = F.relu(out) #add ReLU to benifit ranking 45 | out1 = self.head1(out) 46 | out2 = self.head2(out) 47 | return out1, out2, out 48 | 49 | class BasicBlock(nn.Module): 50 | expansion = 1 51 | 52 | def __init__(self, in_planes, planes, stride=1): 53 | super(BasicBlock, self).__init__() 54 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.shortcut = nn.Sequential() 59 | self.is_padding = 0 60 | if stride != 1 or in_planes != self.expansion*planes: 61 | self.shortcut = nn.AvgPool2d(2) 62 | if in_planes != self.expansion*planes: 63 | self.is_padding = 1 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.bn2(self.conv2(out)) 68 | 69 | if self.is_padding: 70 | shortcut = self.shortcut(x) 71 | out += torch.cat([shortcut,torch.zeros(shortcut.shape).type(torch.cuda.FloatTensor)],1) 72 | else: 73 | out += self.shortcut(x) 74 | out = F.relu(out) 75 | return out 76 | 77 | if __name__ == '__main__': 78 | 79 | from torch.nn.parameter import Parameter 80 | device = torch.device('cuda') 81 | num_labeled_classes = 10 82 | num_unlabeled_classes = 20 83 | model = ResNet(BasicBlock, [2,2,2,2],num_labeled_classes, num_unlabeled_classes) 84 | model = model.to(device) 85 | print(model) 86 | y1, y2 = model(Variable(torch.randn(256,3,32,32).to(device))) 87 | print(y1.size(), y2.size()) 88 | 89 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from sklearn.manifold import TSNE 8 | import matplotlib 9 | matplotlib.use('agg') 10 | import seaborn as sns 11 | from matplotlib import pyplot as plt 12 | from sklearn.utils.linear_assignment_ import linear_assignment 13 | import random 14 | import os 15 | import argparse 16 | ####################################################### 17 | # Evaluate Critiron 18 | ####################################################### 19 | def cluster_acc(y_true, y_pred): 20 | """ 21 | Calculate clustering accuracy. Require scikit-learn installed 22 | 23 | # Arguments 24 | y: true labels, numpy.array with shape `(n_samples,)` 25 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 26 | 27 | # Return 28 | accuracy, in [0,1] 29 | """ 30 | y_true = y_true.astype(np.int64) 31 | assert y_pred.size == y_true.size 32 | D = max(y_pred.max(), y_true.max()) + 1 33 | w = np.zeros((D, D), dtype=np.int64) 34 | for i in range(y_pred.size): 35 | w[y_pred[i], y_true[i]] += 1 36 | ind = linear_assignment(w.max() - w) 37 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 38 | 39 | class AverageMeter(object): 40 | """Computes and stores the average and current value""" 41 | def __init__(self): 42 | self.reset() 43 | 44 | def reset(self): 45 | self.val = 0 46 | self.avg = 0 47 | self.sum = 0 48 | self.count = 0 49 | 50 | def update(self, val, n=1): 51 | self.val = val 52 | self.sum += val * n 53 | self.count += n 54 | self.avg = self.sum / self.count 55 | 56 | class Identity(nn.Module): 57 | def __init__(self): 58 | super(Identity, self).__init__() 59 | def forward(self, x): 60 | return x 61 | 62 | class BCE(nn.Module): 63 | eps = 1e-7 # Avoid calculating log(0). Use the small value of float16. 64 | def forward(self, prob1, prob2, simi): 65 | # simi: 1->similar; -1->dissimilar; 0->unknown(ignore) 66 | assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi))) 67 | P = prob1.mul_(prob2) 68 | P = P.sum(1) 69 | P.mul_(simi).add_(simi.eq(-1).type_as(P)) 70 | neglogP = -P.add_(BCE.eps).log_() 71 | return neglogP.mean() 72 | 73 | def PairEnum(x,mask=None): 74 | # Enumerate all pairs of feature in x 75 | assert x.ndimension() == 2, 'Input dimension must be 2' 76 | x1 = x.repeat(x.size(0),1) 77 | x2 = x.repeat(1,x.size(0)).view(-1,x.size(1)) 78 | if mask is not None: 79 | xmask = mask.view(-1,1).repeat(1,x.size(1)) 80 | #dim 0: #sample, dim 1:#feature 81 | x1 = x1[xmask].view(-1,x.size(1)) 82 | x2 = x2[xmask].view(-1,x.size(1)) 83 | return x1,x2 84 | 85 | def accuracy(output, target, topk=(1,)): 86 | """Computes the accuracy over the k top predictions for the specified values of k""" 87 | with torch.no_grad(): 88 | maxk = max(topk) 89 | batch_size = target.size(0) 90 | 91 | _, pred = output.topk(maxk, 1, True, True) 92 | pred = pred.t() 93 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 94 | 95 | res = [] 96 | for k in topk: 97 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 98 | res.append(correct_k.mul_(100.0 / batch_size)) 99 | return res 100 | 101 | def seed_torch(seed=1029): 102 | random.seed(seed) 103 | os.environ['PYTHONHASHSEED'] = str(seed) 104 | np.random.seed(seed) 105 | torch.manual_seed(seed) 106 | torch.cuda.manual_seed(seed) 107 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 108 | torch.backends.cudnn.benchmark = False 109 | torch.backends.cudnn.deterministic = True 110 | 111 | def str2bool(v): 112 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 113 | return True 114 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 115 | return False 116 | else: 117 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /supervised_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import SGD, lr_scheduler 5 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 6 | from sklearn.metrics import adjusted_rand_score as ari_score 7 | from utils.util import cluster_acc, Identity, AverageMeter 8 | from models.resnet import ResNet, BasicBlock 9 | from data.cifarloader import CIFAR10Loader, CIFAR100Loader 10 | from data.svhnloader import SVHNLoader 11 | from tqdm import tqdm 12 | import numpy as np 13 | import os 14 | 15 | def train(model, train_loader, labeled_eval_loader, args): 16 | optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 17 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 18 | criterion1 = nn.CrossEntropyLoss() 19 | for epoch in range(args.epochs): 20 | loss_record = AverageMeter() 21 | model.train() 22 | exp_lr_scheduler.step() 23 | for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)): 24 | x, label = x.to(device), label.to(device) 25 | output1, _, _ = model(x) 26 | loss= criterion1(output1, label) 27 | loss_record.update(loss.item(), x.size(0)) 28 | optimizer.zero_grad() 29 | loss.backward() 30 | optimizer.step() 31 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg)) 32 | print('test on labeled classes') 33 | args.head = 'head1' 34 | test(model, labeled_eval_loader, args) 35 | 36 | def test(model, test_loader, args): 37 | model.eval() 38 | preds=np.array([]) 39 | targets=np.array([]) 40 | for batch_idx, (x, label, _) in enumerate(tqdm(test_loader)): 41 | x, label = x.to(device), label.to(device) 42 | output1, output2, _ = model(x) 43 | if args.head=='head1': 44 | output = output1 45 | else: 46 | output = output2 47 | _, pred = output.max(1) 48 | targets=np.append(targets, label.cpu().numpy()) 49 | preds=np.append(preds, pred.cpu().numpy()) 50 | acc, nmi, ari = cluster_acc(targets.astype(int), preds.astype(int)), nmi_score(targets, preds), ari_score(targets, preds) 51 | print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari)) 52 | return preds 53 | 54 | if __name__ == "__main__": 55 | import argparse 56 | parser = argparse.ArgumentParser( 57 | description='cluster', 58 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 59 | parser.add_argument('--lr', type=float, default=0.1) 60 | parser.add_argument('--gamma', type=float, default=0.5) 61 | parser.add_argument('--momentum', type=float, default=0.9) 62 | parser.add_argument('--weight_decay', type=float, default=1e-4) 63 | parser.add_argument('--epochs', default=100, type=int) 64 | parser.add_argument('--step_size', default=10, type=int) 65 | parser.add_argument('--batch_size', default=128, type=int) 66 | parser.add_argument('--num_unlabeled_classes', default=5, type=int) 67 | parser.add_argument('--num_labeled_classes', default=5, type=int) 68 | parser.add_argument('--dataset_root', type=str, default='./data/datasets/CIFAR/') 69 | parser.add_argument('--exp_root', type=str, default='./data/experiments/') 70 | parser.add_argument('--rotnet_dir', type=str, default='./data/experiments/selfsupervised_learning/rotnet_cifar10.pth') 71 | parser.add_argument('--model_name', type=str, default='resnet_rotnet') 72 | parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, svhn') 73 | parser.add_argument('--mode', type=str, default='train') 74 | args = parser.parse_args() 75 | args.cuda = torch.cuda.is_available() 76 | device = torch.device("cuda" if args.cuda else "cpu") 77 | runner_name = os.path.basename(__file__).split(".")[0] 78 | model_dir= os.path.join(args.exp_root, runner_name) 79 | if not os.path.exists(model_dir): 80 | os.makedirs(model_dir) 81 | args.model_dir = model_dir+'/'+'{}.pth'.format(args.model_name) 82 | 83 | model = ResNet(BasicBlock, [2,2,2,2], args.num_labeled_classes, args.num_unlabeled_classes).to(device) 84 | 85 | num_classes = args.num_labeled_classes + args.num_unlabeled_classes 86 | 87 | state_dict = torch.load(args.rotnet_dir) 88 | del state_dict['linear.weight'] 89 | del state_dict['linear.bias'] 90 | model.load_state_dict(state_dict, strict=False) 91 | for name, param in model.named_parameters(): 92 | if 'head' not in name and 'layer4' not in name: 93 | param.requires_grad = False 94 | 95 | if args.dataset_name == 'cifar10': 96 | labeled_train_loader = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='once', shuffle=True, target_list = range(args.num_labeled_classes)) 97 | labeled_eval_loader = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes)) 98 | elif args.dataset_name == 'cifar100': 99 | labeled_train_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='once', shuffle=True, target_list = range(args.num_labeled_classes)) 100 | labeled_eval_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes)) 101 | elif args.dataset_name == 'svhn': 102 | labeled_train_loader = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='once', shuffle=True, target_list = range(args.num_labeled_classes)) 103 | labeled_eval_loader = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes)) 104 | 105 | if args.mode == 'train': 106 | train(model, labeled_train_loader, labeled_eval_loader, args) 107 | torch.save(model.state_dict(), args.model_dir) 108 | print("model saved to {}.".format(args.model_dir)) 109 | elif args.mode == 'test': 110 | print("model loaded from {}.".format(args.model_dir)) 111 | model.load_state_dict(torch.load(args.model_dir)) 112 | print('test on labeled classes') 113 | args.head = 'head1' 114 | test(model, labeled_eval_loader, args) 115 | -------------------------------------------------------------------------------- /selfsupervised_learning.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.optim import lr_scheduler 8 | from torchvision import transforms 9 | import pickle 10 | import os 11 | import os.path 12 | import datetime 13 | import numpy as np 14 | from data.rotationloader import DataLoader, GenericDataset 15 | from utils.util import AverageMeter, accuracy 16 | from models.resnet import BasicBlock 17 | from tqdm import tqdm 18 | import shutil 19 | 20 | class ResNet(nn.Module): 21 | def __init__(self, block, num_blocks, num_classes=10): 22 | super(ResNet, self).__init__() 23 | self.in_planes = 64 24 | 25 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.bn1 = nn.BatchNorm2d(64) 27 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 28 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 29 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 30 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 31 | self.linear = nn.Linear(512*block.expansion, num_classes) 32 | if is_adapters: 33 | self.parallel_conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, bias=False) 34 | 35 | def _make_layer(self, block, planes, num_blocks, stride): 36 | strides = [stride] + [1]*(num_blocks-1) 37 | layers = [] 38 | for stride in strides: 39 | layers.append(block(self.in_planes, planes, stride)) 40 | self.in_planes = planes * block.expansion 41 | return nn.Sequential(*layers) 42 | 43 | def forward(self, x): 44 | if is_adapters: 45 | out = F.relu(self.bn1(self.conv1(x)+self.parallel_conv1(x))) 46 | else: 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layer1(out) 49 | out = self.layer2(out) 50 | out = self.layer3(out) 51 | out = self.layer4(out) 52 | out = F.avg_pool2d(out, 4) 53 | out = out.view(out.size(0), -1) 54 | out = self.linear(out) 55 | return out 56 | 57 | def train(epoch, model, device, dataloader, optimizer, exp_lr_scheduler, criterion, args): 58 | loss_record = AverageMeter() 59 | acc_record = AverageMeter() 60 | exp_lr_scheduler.step() 61 | model.train() 62 | for batch_idx, (data, label) in enumerate(tqdm(dataloader(epoch))): 63 | data, label = data.to(device), label.to(device) 64 | optimizer.zero_grad() 65 | output = model(data) 66 | loss = criterion(output, label) 67 | 68 | # measure accuracy and record loss 69 | acc = accuracy(output, label) 70 | acc_record.update(acc[0].item(), data.size(0)) 71 | loss_record.update(loss.item(), data.size(0)) 72 | 73 | # compute gradient and do optimizer step 74 | optimizer.zero_grad() 75 | loss.backward() 76 | optimizer.step() 77 | 78 | print('Train Epoch: {} Avg Loss: {:.4f} \t Avg Acc: {:.4f}'.format(epoch, loss_record.avg, acc_record.avg)) 79 | 80 | return loss_record 81 | 82 | def test(model, device, dataloader, args): 83 | acc_record = AverageMeter() 84 | model.eval() 85 | for batch_idx, (data, label) in enumerate(tqdm(dataloader())): 86 | data, label = data.to(device), label.to(device) 87 | output = model(data) 88 | 89 | # measure accuracy and record loss 90 | acc = accuracy(output, label) 91 | acc_record.update(acc[0].item(), data.size(0)) 92 | 93 | print('Test Acc: {:.4f}'.format(acc_record.avg)) 94 | return acc_record 95 | 96 | def main(): 97 | # Training settings 98 | parser = argparse.ArgumentParser(description='Rot_resNet') 99 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', 100 | help='input batch size for training (default: 64)') 101 | parser.add_argument('--no_cuda', action='store_true', default=False, 102 | help='disables CUDA training') 103 | parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers') 104 | parser.add_argument('--seed', type=int, default=1, 105 | help='random seed (default: 1)') 106 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 107 | help='number of epochs to train (default: 200)') 108 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', 109 | help='learning rate (default: 0.1)') 110 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 111 | help='SGD momentum (default: 0.9)') 112 | parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, svhn') 113 | parser.add_argument('--dataset_root', type=str, default='./data/datasets/CIFAR/') 114 | parser.add_argument('--exp_root', type=str, default='./data/experiments/') 115 | parser.add_argument('--model_name', type=str, default='rotnet') 116 | 117 | args = parser.parse_args() 118 | use_cuda = not args.no_cuda and torch.cuda.is_available() 119 | device = torch.device("cuda" if use_cuda else "cpu") 120 | torch.manual_seed(args.seed) 121 | 122 | runner_name = os.path.basename(__file__).split(".")[0] 123 | model_dir= os.path.join(args.exp_root, runner_name) 124 | if not os.path.exists(model_dir): 125 | os.makedirs(model_dir) 126 | args.model_dir = model_dir+'/'+'{}.pth'.format(args.model_name) 127 | 128 | dataset_train = GenericDataset( 129 | dataset_name=args.dataset_name, 130 | split='train', 131 | dataset_root=args.dataset_root 132 | ) 133 | dataset_test = GenericDataset( 134 | dataset_name=args.dataset_name, 135 | split='test', 136 | dataset_root=args.dataset_root 137 | ) 138 | 139 | dloader_train = DataLoader( 140 | dataset=dataset_train, 141 | batch_size=args.batch_size, 142 | num_workers=args.num_workers, 143 | shuffle=True) 144 | 145 | dloader_test = DataLoader( 146 | dataset=dataset_test, 147 | batch_size=args.batch_size, 148 | num_workers=args.num_workers, 149 | shuffle=False) 150 | 151 | global is_adapters 152 | is_adapters = 0 153 | model = ResNet(BasicBlock, [2,2,2,2], num_classes=4) 154 | model = model.to(device) 155 | 156 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4, nesterov=True) 157 | exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160, 200], gamma=0.2) 158 | 159 | criterion = nn.CrossEntropyLoss() 160 | 161 | best_acc = 0 162 | for epoch in range(args.epochs +1): 163 | loss_record = train(epoch, model, device, dloader_train, optimizer, exp_lr_scheduler, criterion, args) 164 | acc_record = test(model, device, dloader_test, args) 165 | 166 | is_best = acc_record.avg > best_acc 167 | best_acc = max(acc_record.avg, best_acc) 168 | if is_best: 169 | torch.save(model.state_dict(), args.model_dir) 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoNovel 2 | 3 | **[Automatically Discovering and Learning New Visual Categories with Ranking Statistics, ICLR 2020](http://www.robots.ox.ac.uk/~vgg/research/auto_novel/)**, 4 |
5 | [Kai Han*](http://www.hankai.org), [Sylvestre-Alvise Rebuffi*](http://www.robots.ox.ac.uk/~srebuffi/), [Sebastien Ehrhardt*](), [Andrea Vedaldi](http://www.robots.ox.ac.uk/~vedaldi/), [Andrew Zisserman](http://www.robots.ox.ac.uk/~az/) 6 |
7 | 8 | 9 | ![splash](asset/splash.png) 10 | 11 | ## Dependencies 12 | 13 | All dependencies are included in `environment.yml`. To install, run 14 | 15 | ```shell 16 | conda env create -f environment.yml 17 | ``` 18 | 19 | (Make sure you have installed [Anaconda](https://www.anaconda.com/) before running.) 20 | 21 | Then, activate the installed environment by 22 | 23 | ``` 24 | conda activate auto_novel 25 | ``` 26 | 27 | ## Overview 28 | 29 | We provide code and models for our experiments on CIFAR10, CIFAR100, SVHN, OmniGlot, and ImageNet: 30 | - Code for self-supervised learning 31 | - Code for supervised learning 32 | - Code for novel category discovery 33 | - Our trained models and all other required pretrained models 34 | 35 | ## Data preparation 36 | 37 | By default, we put the datasets in `./data/datasets/` and save trained models in `./data/experiments/` (soft link is suggested). You may also use any other directories you like by setting the `--dataset_root` argument to `/your/data/path/`, and the `--exp_root` argument to `/your/experiment/path/` when running all experiments below. 38 | 39 | - For CIFAR-10, CIFAR-100, and SVHN, simply download the datasets and put into `./data/datasets/`. 40 | - For OmniGlot, after downloading, you need to put `Alphabet_of_the_Magi, Japanese_(katakana), Latin, Cyrillic, Grantha` from `imags_background` folder into `images_background_val` folder, and put the rest alphabets into `images_background_train` folder. 41 | - For ImageNet, we provide the exact split files used in the experiments following existing work. To download the split files, run the command: 42 | `` 43 | sh scripts/download_imagenet_splits.sh 44 | `` 45 | . The ImageNet dataset folder is organized in the following way: 46 | 47 | ``` 48 | ImageNet/imagenet_rand118 #downloaded by the above command 49 | ImageNet/images/train #standard ImageNet training split 50 | ImageNet/images/val #standard ImageNet validation split 51 | ``` 52 | 53 | ## Pretrained models 54 | We provide our trained models and all other required pretrained models. To download, run: 55 | ``` 56 | sh scripts/download_pretrained_models.sh 57 | ``` 58 | After downloading, you may directly jump to Step 3 below, if you only want to run our ranking based method. 59 | 60 | ## Step 1: Self-supervised learning with both labelled and unlabelled data 61 | 62 | ``` 63 | CUDA_VISIBLE_DEVICES=0 python selfsupervised_learning.py --dataset_name cifar10 --model_name rotnet_cifar10 --dataset_root ./data/datasets/CIFAR/ 64 | ``` 65 | 66 | ``--dataset_name`` can be one of ``{cifar10, cifar100, svhn}``; ``--dataset_root`` is set to ``./data/datasets/CIFAR/`` for CIFAR10/CIFAR100 and ``./data/datasets/SVHN/`` for SVHN. 67 | 68 | Our code for step 1 is based on the official code of the [RotNet paper](https://arxiv.org/pdf/1803.07728.pdf). 69 | 70 | ## Step 2: Supervised learning with labelled data 71 | 72 | ``` 73 | # For CIFAR10 74 | CUDA_VISIBLE_DEVICES=0 python supervised_learning.py --dataset_name cifar10 --model_name resnet_rotnet_cifar10 75 | 76 | # For CIFAR100 77 | CUDA_VISIBLE_DEVICES=0 python supervised_learning.py --dataset_name cifar100 --model_name resnet_rotnet_cifar100 --num_labeled_classes 80 --num_unlabeled_classes 20 78 | 79 | # For SVHN 80 | CUDA_VISIBLE_DEVICES=0 python supervised_learning.py --dataset_name svhn --model_name resnet_rotnet_svhn --dataset_root ./data/datasets/SVHN/ 81 | ``` 82 | 83 | ## Step 3: Joint training for novel category discovery 84 | 85 | ### Novel category discovery on CIFAR10/CIFAR100/SVHN 86 | 87 | ```shell 88 | # Train on CIFAR10 89 | CUDA_VISIBLE_DEVICES=0 sh scripts/auto_novel_cifar10.sh ./data/datasets/CIFAR/ ./data/experiments/ ./data/experiments/pretrained/supervised_learning/resnet_rotnet_cifar10.pth 90 | 91 | # Train on CIFAR100 92 | CUDA_VISIBLE_DEVICES=0 sh scripts/auto_novel_cifar100.sh ./data/datasets/CIFAR/ ./data/experiments/ ./data/experiments/pretrained/supervised_learning/resnet_rotnet_cifar100.pth 93 | 94 | # Train on SVHN 95 | CUDA_VISIBLE_DEVICES=0 sh scripts/auto_novel_svhn.sh ./data/datasets/SVHN/ ./data/experiments/ ./data/experiments/pretrained/supervised_learning/resnet_rotnet_svhn.pth 96 | ``` 97 | 98 | To train in the Incremental Learning (IL) mode, replace ``auto_novel_{cifar10, cifar100, svhn}.sh`` in the above commands by ``auto_novel_IL_{cifar10, cifar100, svhn}.sh``. 99 | 100 | ### Novel category discovery on OmniGlot 101 | 102 | ```shell 103 | # For OmniGlot 104 | CUDA_VISIBLE_DEVICES=0 python auto_novel_omniglot.py 105 | ``` 106 | 107 | ### Novel category discovery on ImageNet 108 | 109 | ```shell 110 | # For ImageNet subset A 111 | CUDA_VISIBLE_DEVICES=0 python auto_novel_imagenet.py --unlabeled_subset A 112 | 113 | # For ImageNet subset B 114 | CUDA_VISIBLE_DEVICES=0 python auto_novel_imagenet.py --unlabeled_subset B 115 | 116 | # For ImageNet subset C 117 | CUDA_VISIBLE_DEVICES=0 python auto_novel_imagenet.py --unlabeled_subset C 118 | ``` 119 | 120 | ### Evaluation on novel category discovery 121 | To run our code in evaluation mode, set the `--mode` to `test`. 122 | 123 | ```shell 124 | # For CIFAR10 125 | CUDA_VISIBLE_DEVICES=0 python auto_novel.py --mode test --dataset_name cifar10 --model_name resnet_cifar10 --exp_root ./data/experiments/pretrained/ 126 | 127 | # For CIFAR100 128 | CUDA_VISIBLE_DEVICES=0 python auto_novel.py --mode test --dataset_name cifar100 --model_name resnet_cifar100 --exp_root ./data/experiments/pretrained/ --num_labeled_classes 80 --num_unlabeled_classes 20 129 | 130 | # For SVHN 131 | CUDA_VISIBLE_DEVICES=0 python auto_novel.py --mode test --dataset_name svhn --model_name resnet_svhn --exp_root ./data/experiments/pretrained/ --dataset_root ./data/datasets/SVHN 132 | 133 | # For OmniGlot 134 | CUDA_VISIBLE_DEVICES=0 python auto_novel_omniglot.py --mode test --model_name vgg6_seed_0 --exp_root ./data/experiments/pretrained/ 135 | 136 | # For ImageNet subset A 137 | CUDA_VISIBLE_DEVICES=0 python auto_novel_imagenet.py --mode test --unlabeled_subset A --exp_root ./data/experiments/pretrained/ 138 | 139 | # For ImageNet subset B 140 | CUDA_VISIBLE_DEVICES=0 python auto_novel_imagenet.py --mode test --unlabeled_subset B --exp_root ./data/experiments/pretrained/ 141 | 142 | # For ImageNet subset C 143 | CUDA_VISIBLE_DEVICES=0 python auto_novel_imagenet.py --mode test --unlabeled_subset C --exp_root ./data/experiments/pretrained/ 144 | ``` 145 | To perform the evaluation in the Incremental Learning (IL) mode, add in the above commands the argument ``--IL`` and replace the model name``resnet_{cifar10, cifar100, svhn}`` by ``resnet_IL_{cifar10, cifar100, svhn}``. 146 | 147 | ## Citation 148 | If this work is helpful for your research, please cite our paper. 149 | ``` 150 | @inproceedings{Han2020automatically, 151 | author = {Kai Han and Sylvestre-Alvise Rebuffi and Sebastien Ehrhardt and Andrea Vedaldi and Andrew Zisserman}, 152 | title = {Automatically Discovering and Learning New Visual Categories with Ranking Statistics}, 153 | booktitle = {International Conference on Learning Representations (ICLR)}, 154 | year = {2020} 155 | } 156 | ``` 157 | 158 | ## Acknowledgments 159 | This work is supported by the [EPSRC Programme Grant Seebibyte EP/M013774/1](http://seebibyte.org/), [Mathworks/DTA DFR02620](), and [ERC IDIU-638009](https://cordis.europa.eu/project/rcn/196773/factsheet/en). 160 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from tqdm import tqdm 6 | from PIL import Image 7 | import numpy as np 8 | import itertools 9 | import torch 10 | from torch.utils.data.sampler import Sampler 11 | 12 | class TransformKtimes: 13 | def __init__(self, transform, k=10): 14 | self.transform = transform 15 | self.k = k 16 | 17 | def __call__(self, inp): 18 | return torch.stack([self.transform(inp) for i in range(self.k)]) 19 | 20 | class TransformTwice: 21 | def __init__(self, transform): 22 | self.transform = transform 23 | 24 | def __call__(self, inp): 25 | out1 = self.transform(inp) 26 | out2 = self.transform(inp) 27 | return out1, out2 28 | 29 | 30 | class RandomTranslateWithReflect: 31 | """Translate image randomly 32 | 33 | Translate vertically and horizontally by n pixels where 34 | n is integer drawn uniformly independently for each axis 35 | from [-max_translation, max_translation]. 36 | 37 | Fill the uncovered blank area with reflect padding. 38 | """ 39 | 40 | def __init__(self, max_translation): 41 | self.max_translation = max_translation 42 | 43 | def __call__(self, old_image): 44 | xtranslation, ytranslation = np.random.randint(-self.max_translation, 45 | self.max_translation + 1, 46 | size=2) 47 | xpad, ypad = abs(xtranslation), abs(ytranslation) 48 | xsize, ysize = old_image.size 49 | 50 | flipped_lr = old_image.transpose(Image.FLIP_LEFT_RIGHT) 51 | flipped_tb = old_image.transpose(Image.FLIP_TOP_BOTTOM) 52 | flipped_both = old_image.transpose(Image.ROTATE_180) 53 | 54 | new_image = Image.new("RGB", (xsize + 2 * xpad, ysize + 2 * ypad)) 55 | 56 | new_image.paste(old_image, (xpad, ypad)) 57 | 58 | new_image.paste(flipped_lr, (xpad + xsize - 1, ypad)) 59 | new_image.paste(flipped_lr, (xpad - xsize + 1, ypad)) 60 | 61 | new_image.paste(flipped_tb, (xpad, ypad + ysize - 1)) 62 | new_image.paste(flipped_tb, (xpad, ypad - ysize + 1)) 63 | 64 | new_image.paste(flipped_both, (xpad - xsize + 1, ypad - ysize + 1)) 65 | new_image.paste(flipped_both, (xpad + xsize - 1, ypad - ysize + 1)) 66 | new_image.paste(flipped_both, (xpad - xsize + 1, ypad + ysize - 1)) 67 | new_image.paste(flipped_both, (xpad + xsize - 1, ypad + ysize - 1)) 68 | 69 | new_image = new_image.crop((xpad - xtranslation, 70 | ypad - ytranslation, 71 | xpad + xsize - xtranslation, 72 | ypad + ysize - ytranslation)) 73 | 74 | return new_image 75 | 76 | class TwoStreamBatchSampler(Sampler): 77 | """Iterate two sets of indices 78 | 79 | An 'epoch' is one iteration through the primary indices. 80 | During the epoch, the secondary indices are iterated through 81 | as many times as needed. 82 | """ 83 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 84 | self.primary_indices = primary_indices 85 | self.secondary_indices = secondary_indices 86 | self.secondary_batch_size = secondary_batch_size 87 | self.primary_batch_size = batch_size - secondary_batch_size 88 | 89 | assert len(self.primary_indices) >= self.primary_batch_size > 0 90 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 91 | 92 | def __iter__(self): 93 | primary_iter = iterate_once(self.primary_indices) 94 | secondary_iter = iterate_eternally(self.secondary_indices) 95 | return ( 96 | primary_batch + secondary_batch 97 | for (primary_batch, secondary_batch) 98 | in zip(grouper(primary_iter, self.primary_batch_size), 99 | grouper(secondary_iter, self.secondary_batch_size)) 100 | ) 101 | 102 | def __len__(self): 103 | return len(self.primary_indices) // self.primary_batch_size 104 | 105 | 106 | def iterate_once(iterable): 107 | return np.random.permutation(iterable) 108 | 109 | 110 | def iterate_eternally(indices): 111 | def infinite_shuffles(): 112 | while True: 113 | yield np.random.permutation(indices) 114 | return itertools.chain.from_iterable(infinite_shuffles()) 115 | 116 | 117 | def grouper(iterable, n): 118 | "Collect data into fixed-length chunks or blocks" 119 | # grouper('ABCDEFG', 3) --> ABC DEF" 120 | args = [iter(iterable)] * n 121 | return zip(*args) 122 | 123 | def gen_bar_updater(pbar): 124 | def bar_update(count, block_size, total_size): 125 | if pbar.total is None and total_size: 126 | pbar.total = total_size 127 | progress_bytes = count * block_size 128 | pbar.update(progress_bytes - pbar.n) 129 | 130 | return bar_update 131 | 132 | 133 | def check_integrity(fpath, md5=None): 134 | if md5 is None: 135 | return True 136 | if not os.path.isfile(fpath): 137 | return False 138 | md5o = hashlib.md5() 139 | with open(fpath, 'rb') as f: 140 | # read in 1MB chunks 141 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 142 | md5o.update(chunk) 143 | md5c = md5o.hexdigest() 144 | if md5c != md5: 145 | return False 146 | return True 147 | 148 | 149 | def makedir_exist_ok(dirpath): 150 | """ 151 | Python2 support for os.makedirs(.., exist_ok=True) 152 | """ 153 | try: 154 | os.makedirs(dirpath) 155 | except OSError as e: 156 | if e.errno == errno.EEXIST: 157 | pass 158 | else: 159 | raise 160 | 161 | 162 | def download_url(url, root, filename, md5): 163 | from six.moves import urllib 164 | 165 | root = os.path.expanduser(root) 166 | fpath = os.path.join(root, filename) 167 | 168 | makedir_exist_ok(root) 169 | 170 | # downloads file 171 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 172 | print('Using downloaded and verified file: ' + fpath) 173 | else: 174 | try: 175 | print('Downloading ' + url + ' to ' + fpath) 176 | urllib.request.urlretrieve( 177 | url, fpath, 178 | reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) 179 | ) 180 | except: 181 | if url[:5] == 'https': 182 | url = url.replace('https:', 'http:') 183 | print('Failed download. Trying https -> http instead.' 184 | ' Downloading ' + url + ' to ' + fpath) 185 | urllib.request.urlretrieve( 186 | url, fpath, 187 | reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) 188 | ) 189 | 190 | 191 | def list_dir(root, prefix=False): 192 | """List all directories at a given root 193 | 194 | Args: 195 | root (str): Path to directory whose folders need to be listed 196 | prefix (bool, optional): If true, prepends the path to each result, otherwise 197 | only returns the name of the directories found 198 | """ 199 | root = os.path.expanduser(root) 200 | directories = list( 201 | filter( 202 | lambda p: os.path.isdir(os.path.join(root, p)), 203 | os.listdir(root) 204 | ) 205 | ) 206 | 207 | if prefix is True: 208 | directories = [os.path.join(root, d) for d in directories] 209 | 210 | return directories 211 | 212 | 213 | def list_files(root, suffix, prefix=False): 214 | """List all files ending with a suffix at a given root 215 | 216 | Args: 217 | root (str): Path to directory whose folders need to be listed 218 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 219 | It uses the Python "str.endswith" method and is passed directly 220 | prefix (bool, optional): If true, prepends the path to each result, otherwise 221 | only returns the name of the files found 222 | """ 223 | root = os.path.expanduser(root) 224 | files = list( 225 | filter( 226 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 227 | os.listdir(root) 228 | ) 229 | ) 230 | 231 | if prefix is True: 232 | files = [os.path.join(root, d) for d in files] 233 | 234 | return files 235 | -------------------------------------------------------------------------------- /data/imagenetloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | if sys.version_info[0] == 2: 8 | import cPickle as pickle 9 | else: 10 | import pickle 11 | 12 | import torch.backends.cudnn as cudnn 13 | import random 14 | import torch.utils.data as data 15 | import torch 16 | import torchvision 17 | import torchvision.transforms as transforms 18 | from torch.utils.data.dataloader import default_collate, DataLoader 19 | from .utils import TransformTwice, TransformKtimes, RandomTranslateWithReflect, TwoStreamBatchSampler 20 | from .concat import ConcatDataset 21 | 22 | def find_classes_from_folder(dir): 23 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 24 | classes.sort() 25 | class_to_idx = {classes[i]: i for i in range(len(classes))} 26 | return classes, class_to_idx 27 | 28 | def find_classes_from_file(file_path): 29 | with open(file_path) as f: 30 | classes = f.readlines() 31 | classes = [x.strip() for x in classes] 32 | classes.sort() 33 | class_to_idx = {classes[i]: i for i in range(len(classes))} 34 | return classes, class_to_idx 35 | 36 | def make_dataset(dir, classes, class_to_idx): 37 | samples = [] 38 | for target in classes: 39 | d = os.path.join(dir, target) 40 | if not os.path.isdir(d): 41 | continue 42 | 43 | for root, _, fnames in sorted(os.walk(d)): 44 | for fname in sorted(fnames): 45 | path = os.path.join(root, fname) 46 | item = (path, class_to_idx[target]) 47 | if 'JPEG' in path or 'jpg' in path: 48 | samples.append(item) 49 | 50 | return samples 51 | 52 | IMG_EXTENSIONS = [ 53 | '.jpg', '.JPG', '.jpeg', '.JPEG', 54 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 55 | ] 56 | 57 | def pil_loader(path): 58 | return Image.open(path).convert('RGB') 59 | 60 | class ImageFolder(data.Dataset): 61 | 62 | def __init__(self, transform=None, target_transform=None, samples=None, loader=pil_loader): 63 | 64 | if len(samples) == 0: 65 | raise(RuntimeError("Found 0 images in subfolders \n" 66 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 67 | 68 | self.samples=samples 69 | self.transform = transform 70 | self.target_transform = target_transform 71 | self.loader = loader 72 | 73 | def __getitem__(self, index): 74 | path = self.samples[index][0] 75 | target = self.samples[index][1] 76 | img = self.loader(path) 77 | if self.transform is not None: 78 | img = self.transform(img) 79 | if self.target_transform is not None: 80 | target = self.target_transform(target) 81 | return img, target, index 82 | 83 | def __len__(self): 84 | return len(self.samples) 85 | 86 | def ImageNet882(aug=None, subfolder='train', path='./data/datasets/ImageNet/'): 87 | img_split = 'images/'+subfolder 88 | classes_118, class_to_idx_118 = find_classes_from_file(os.path.join(path, 'imagenet_rand118/imagenet_118.txt')) 89 | samples_118 = make_dataset(path+img_split, classes_118, class_to_idx_118) 90 | classes_1000, _ = find_classes_from_folder(os.path.join(path, img_split)) 91 | classes_882 = list(set(classes_1000) - set(classes_118)) 92 | class_to_idx_882 = {classes_882[i]: i for i in range(len(classes_882))} 93 | samples_882 = make_dataset(path+img_split, classes_882, class_to_idx_882) 94 | if aug==None: 95 | transform = transforms.Compose([ 96 | transforms.Resize(256), 97 | transforms.CenterCrop(224), 98 | transforms.ToTensor(), 99 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 100 | ]) 101 | elif aug=='once': 102 | transform = transforms.Compose([ 103 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ToTensor(), 106 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 107 | ]) 108 | elif aug=='twice': 109 | transform = TransformTwice(transforms.Compose([ 110 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 111 | transforms.RandomHorizontalFlip(), 112 | transforms.ToTensor(), 113 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 114 | ])) 115 | elif aug=='ktimes': 116 | transform = TransformKtimes(transforms.Compose([ 117 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 118 | transforms.RandomHorizontalFlip(), 119 | transforms.ToTensor(), 120 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 121 | ]), k=10) 122 | dataset = ImageFolder(transform=transform, samples=samples_882) 123 | return dataset 124 | 125 | def ImageNet30(path='./data/datasets/ImageNet/', subset='A', aug=None, subfolder='train'): 126 | classes_30, class_to_idx_30 = find_classes_from_file(os.path.join(path, 'imagenet_rand118/imagenet_30_{}.txt'.format(subset))) 127 | samples_30 = make_dataset(path+'images/{}'.format(subfolder), classes_30, class_to_idx_30) 128 | if aug==None: 129 | transform = transforms.Compose([ 130 | transforms.Resize(256), 131 | transforms.CenterCrop(224), 132 | transforms.ToTensor(), 133 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 134 | ]) 135 | elif aug=='once': 136 | transform = transforms.Compose([ 137 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 138 | transforms.RandomHorizontalFlip(), 139 | transforms.ToTensor(), 140 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 141 | ]) 142 | elif aug=='twice': 143 | transform = TransformTwice(transforms.Compose([ 144 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 145 | transforms.RandomHorizontalFlip(), 146 | transforms.ToTensor(), 147 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 148 | ])) 149 | elif aug=='ktimes': 150 | transform = TransformKtimes(transforms.Compose([ 151 | transforms.RandomResizedCrop(224, scale=(0.5, 1.0)), 152 | transforms.RandomHorizontalFlip(), 153 | transforms.ToTensor(), 154 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 155 | ]), k=10) 156 | 157 | dataset = ImageFolder(transform=transform, samples=samples_30) 158 | return dataset 159 | 160 | def ImageNetLoader30(batch_size, num_workers=2, path='./data/datasets/ImageNet/', subset='A', aug=None, shuffle=False, subfolder='train'): 161 | dataset = ImageNet30(path, subset, aug, subfolder) 162 | dataloader_30 = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 163 | return dataloader_30 164 | 165 | def ImageNetLoader882(batch_size, num_workers=2, path='./data/datasets/ImageNet/', aug=None, shuffle=False, subfolder='train'): 166 | dataset = ImageNet882(aug=aug, subfolder=subfolder, path=path) 167 | dataloader_882 = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 168 | return dataloader_882 169 | 170 | def ImageNetLoader882_30Mix(batch_size, num_workers=2, path='./data/datasets/ImageNet/', unlabeled_subset='A', aug=None, shuffle=False, subfolder='train', unlabeled_batch_size=64): 171 | dataset_labeled = ImageNet882(aug=aug, subfolder=subfolder, path=path) 172 | dataset_unlabeled= ImageNet30(path, unlabeled_subset, aug, subfolder) 173 | dataset= ConcatDataset((dataset_labeled, dataset_unlabeled)) 174 | labeled_idxs = range(len(dataset_labeled)) 175 | unlabeled_idxs = range(len(dataset_labeled), len(dataset_labeled)+len(dataset_unlabeled)) 176 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, unlabeled_batch_size) 177 | loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers) 178 | loader.labeled_length = len(dataset_labeled) 179 | loader.unlabeled_length = len(dataset_unlabeled) 180 | return loader 181 | 182 | -------------------------------------------------------------------------------- /data/rotationloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | import torchnet as tnt 6 | import torchvision.datasets as datasets 7 | import torchvision.transforms as transforms 8 | # from Places205 import Places205 9 | import numpy as np 10 | import random 11 | from torch.utils.data.dataloader import default_collate 12 | from PIL import Image 13 | import os 14 | import errno 15 | import numpy as np 16 | import sys 17 | import csv 18 | from tqdm import tqdm 19 | 20 | from pdb import set_trace as breakpoint 21 | 22 | class GenericDataset(data.Dataset): 23 | def __init__(self, dataset_name, split, random_sized_crop=False, 24 | num_imgs_per_cat=None, dataset_root=None): 25 | self.split = split.lower() 26 | self.dataset_name = dataset_name.lower() 27 | self.name = self.dataset_name + '_' + self.split 28 | self.random_sized_crop = random_sized_crop 29 | 30 | # The num_imgs_per_cats input argument specifies the number 31 | # of training examples per category that would be used. 32 | # This input argument was introduced in order to be able 33 | # to use less annotated examples than what are available 34 | # in a semi-superivsed experiment. By default all the 35 | # available training examplers per category are being used. 36 | self.num_imgs_per_cat = num_imgs_per_cat 37 | 38 | if self.dataset_name=='cifar10': 39 | self.mean_pix = [x/255.0 for x in [125.3, 123.0, 113.9]] 40 | self.std_pix = [x/255.0 for x in [63.0, 62.1, 66.7]] 41 | 42 | if self.random_sized_crop: 43 | raise ValueError('The random size crop option is not supported for the CIFAR dataset') 44 | 45 | transform = [] 46 | if (split != 'test'): 47 | transform.append(transforms.RandomCrop(32, padding=4)) 48 | transform.append(transforms.RandomHorizontalFlip()) 49 | transform.append(lambda x: np.asarray(x)) 50 | self.transform = transforms.Compose(transform) 51 | self.data = datasets.__dict__[self.dataset_name.upper()]( 52 | dataset_root, train=self.split=='train', 53 | download=True, transform=self.transform) 54 | elif self.dataset_name=='cifar100': 55 | self.mean_pix = [x/255.0 for x in [129.3, 124.1, 112.4]] 56 | self.std_pix = [x/255.0 for x in [68.2, 65.4, 70.4]] 57 | 58 | if self.random_sized_crop: 59 | raise ValueError('The random size crop option is not supported for the CIFAR dataset') 60 | 61 | transform = [] 62 | if (split != 'test'): 63 | transform.append(transforms.RandomCrop(32, padding=4)) 64 | transform.append(transforms.RandomHorizontalFlip()) 65 | transform.append(lambda x: np.asarray(x)) 66 | self.transform = transforms.Compose(transform) 67 | self.data = datasets.__dict__[self.dataset_name.upper()]( 68 | dataset_root, train=self.split=='train', 69 | download=True, transform=self.transform) 70 | elif self.dataset_name=='svhn': 71 | self.mean_pix = [0.485, 0.456, 0.406] 72 | self.std_pix = [0.229, 0.224, 0.225] 73 | 74 | if self.random_sized_crop: 75 | raise ValueError('The random size crop option is not supported for the SVHN dataset') 76 | 77 | transform = [] 78 | if (split != 'test'): 79 | transform.append(transforms.RandomCrop(32, padding=4)) 80 | transform.append(lambda x: np.asarray(x)) 81 | self.transform = transforms.Compose(transform) 82 | self.data = datasets.__dict__[self.dataset_name.upper()]( 83 | dataset_root, split=self.split, 84 | download=True, transform=self.transform) 85 | else: 86 | raise ValueError('Not recognized dataset {0}'.format(dataset_name)) 87 | 88 | def __getitem__(self, index): 89 | img, label = self.data[index] 90 | return img, int(label) 91 | 92 | def __len__(self): 93 | return len(self.data) 94 | 95 | class Denormalize(object): 96 | def __init__(self, mean, std): 97 | self.mean = mean 98 | self.std = std 99 | 100 | def __call__(self, tensor): 101 | for t, m, s in zip(tensor, self.mean, self.std): 102 | t.mul_(s).add_(m) 103 | return tensor 104 | 105 | def rotate_img(img, rot): 106 | if rot == 0: # 0 degrees rotation 107 | return img 108 | elif rot == 90: # 90 degrees rotation 109 | return np.flipud(np.transpose(img, (1,0,2))) 110 | elif rot == 180: # 90 degrees rotation 111 | return np.fliplr(np.flipud(img)) 112 | elif rot == 270: # 270 degrees rotation / or -90 113 | return np.transpose(np.flipud(img), (1,0,2)) 114 | else: 115 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees') 116 | 117 | 118 | class DataLoader(object): 119 | def __init__(self, 120 | dataset, 121 | batch_size=1, 122 | unsupervised=True, 123 | epoch_size=None, 124 | num_workers=0, 125 | shuffle=True): 126 | self.dataset = dataset 127 | self.shuffle = shuffle 128 | self.epoch_size = epoch_size if epoch_size is not None else len(dataset) 129 | self.batch_size = batch_size 130 | self.unsupervised = unsupervised 131 | self.num_workers = num_workers 132 | 133 | mean_pix = self.dataset.mean_pix 134 | std_pix = self.dataset.std_pix 135 | self.transform = transforms.Compose([ 136 | transforms.ToTensor(), 137 | transforms.Normalize(mean=mean_pix, std=std_pix) 138 | ]) 139 | self.inv_transform = transforms.Compose([ 140 | Denormalize(mean_pix, std_pix), 141 | lambda x: x.numpy() * 255.0, 142 | lambda x: x.transpose(1,2,0).astype(np.uint8), 143 | ]) 144 | 145 | def get_iterator(self, epoch=0): 146 | rand_seed = epoch * self.epoch_size 147 | random.seed(rand_seed) 148 | if self.unsupervised: 149 | # if in unsupervised mode define a loader function that given the 150 | # index of an image it returns the 4 rotated copies of the image 151 | # plus the label of the rotation, i.e., 0 for 0 degrees rotation, 152 | # 1 for 90 degrees, 2 for 180 degrees, and 3 for 270 degrees. 153 | def _load_function(idx): 154 | idx = idx % len(self.dataset) 155 | img0, _ = self.dataset[idx] 156 | rotated_imgs = [ 157 | self.transform(img0), 158 | self.transform(rotate_img(img0, 90).copy()), 159 | self.transform(rotate_img(img0, 180).copy()), 160 | self.transform(rotate_img(img0, 270).copy()) 161 | ] 162 | rotation_labels = torch.LongTensor([0, 1, 2, 3]) 163 | return torch.stack(rotated_imgs, dim=0), rotation_labels 164 | def _collate_fun(batch): 165 | batch = default_collate(batch) 166 | assert(len(batch)==2) 167 | batch_size, rotations, channels, height, width = batch[0].size() 168 | batch[0] = batch[0].view([batch_size*rotations, channels, height, width]) 169 | batch[1] = batch[1].view([batch_size*rotations]) 170 | return batch 171 | else: # supervised mode 172 | # if in supervised mode define a loader function that given the 173 | # index of an image it returns the image and its categorical label 174 | def _load_function(idx): 175 | idx = idx % len(self.dataset) 176 | img, categorical_label = self.dataset[idx] 177 | img = self.transform(img) 178 | return img, categorical_label 179 | _collate_fun = default_collate 180 | 181 | tnt_dataset = tnt.dataset.ListDataset(elem_list=range(self.epoch_size), 182 | load=_load_function) 183 | data_loader = tnt_dataset.parallel(batch_size=self.batch_size, 184 | collate_fn=_collate_fun, num_workers=self.num_workers, 185 | shuffle=self.shuffle) 186 | return data_loader 187 | 188 | def __call__(self, epoch=0): 189 | return self.get_iterator(epoch) 190 | 191 | def __len__(self): 192 | return self.epoch_size / self.batch_size -------------------------------------------------------------------------------- /data/svhnloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from .utils import download_url, check_integrity 8 | from .utils import TransformTwice, TransformKtimes, RandomTranslateWithReflect, TwoStreamBatchSampler 9 | from .concat import ConcatDataset 10 | import torch 11 | import torch.utils.data as data 12 | import torchvision.transforms as transforms 13 | 14 | class SVHN(data.Dataset): 15 | """`SVHN `_ Dataset. 16 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, 17 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which 18 | expect the class labels to be in the range `[0, C-1]` 19 | 20 | Args: 21 | root (string): Root directory of dataset where directory 22 | ``SVHN`` exists. 23 | split (string): One of {'train', 'test', 'extra'}. 24 | Accordingly dataset is selected. 'extra' is Extra training set. 25 | transform (callable, optional): A function/transform that takes in an PIL image 26 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 27 | target_transform (callable, optional): A function/transform that takes in the 28 | target and transforms it. 29 | download (bool, optional): If true, downloads the dataset from the internet and 30 | puts it in root directory. If dataset is already downloaded, it is not 31 | downloaded again. 32 | 33 | """ 34 | url = "" 35 | filename = "" 36 | file_md5 = "" 37 | 38 | split_list = { 39 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 40 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 41 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 42 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 43 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 44 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 45 | 46 | def __init__(self, root, split='train', 47 | transform=None, target_transform=None, download=False, target_list=range(5)): 48 | self.root = os.path.expanduser(root) 49 | self.transform = transform 50 | self.target_transform = target_transform 51 | self.split = split # training set or test set or extra set 52 | 53 | if self.split not in self.split_list: 54 | raise ValueError('Wrong split entered! Please use split="train" ' 55 | 'or split="extra" or split="test"') 56 | 57 | self.url = self.split_list[split][0] 58 | self.filename = self.split_list[split][1] 59 | self.file_md5 = self.split_list[split][2] 60 | 61 | if download: 62 | self.download() 63 | 64 | if not self._check_integrity(): 65 | raise RuntimeError('Dataset not found or corrupted.' + 66 | ' You can use download=True to download it') 67 | 68 | # import here rather than at top of file because this is 69 | # an optional dependency for torchvision 70 | import scipy.io as sio 71 | 72 | # reading(loading) mat file as array 73 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 74 | 75 | self.data = loaded_mat['X'] 76 | # loading from the .mat file gives an np array of type np.uint8 77 | # converting to np.int64, so that we have a LongTensor after 78 | # the conversion from the numpy array 79 | # the squeeze is needed to obtain a 1D tensor 80 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 81 | 82 | # the svhn dataset assigns the class label "10" to the digit 0 83 | # this makes it inconsistent with several loss functions 84 | # which expect the class labels to be in the range [0, C-1] 85 | np.place(self.labels, self.labels == 10, 0) 86 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 87 | 88 | ind = [i for i in range(len(self.labels)) if int(self.labels[i]) in target_list] 89 | 90 | self.data = self.data[ind] 91 | self.labels= self.labels[ind] 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | 98 | Returns: 99 | tuple: (image, target) where target is index of the target class. 100 | """ 101 | img, target = self.data[index], int(self.labels[index]) 102 | 103 | # doing this so that it is consistent with all other datasets 104 | # to return a PIL Image 105 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 106 | 107 | if self.transform is not None: 108 | img = self.transform(img) 109 | 110 | if self.target_transform is not None: 111 | target = self.target_transform(target) 112 | 113 | return img, target, index 114 | 115 | def __len__(self): 116 | return len(self.data) 117 | 118 | def _check_integrity(self): 119 | root = self.root 120 | md5 = self.split_list[self.split][2] 121 | fpath = os.path.join(root, self.filename) 122 | return check_integrity(fpath, md5) 123 | 124 | def download(self): 125 | md5 = self.split_list[self.split][2] 126 | download_url(self.url, self.root, self.filename, md5) 127 | 128 | def __repr__(self): 129 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 130 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 131 | fmt_str += ' Split: {}\n'.format(self.split) 132 | fmt_str += ' Root Location: {}\n'.format(self.root) 133 | tmp = ' Transforms (if any): ' 134 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 135 | tmp = ' Target Transforms (if any): ' 136 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 137 | return fmt_str 138 | 139 | def SVHNData(root, split='train', aug=None, target_list=range(5)): 140 | if aug==None: 141 | transform = transforms.Compose([ 142 | transforms.ToTensor(), 143 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 144 | ]) 145 | elif aug=='once': 146 | transform = transforms.Compose([ 147 | transforms.RandomCrop(32, padding=4), 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 150 | ]) 151 | elif aug=='twice': 152 | transform = TransformTwice(transforms.Compose([ 153 | transforms.RandomCrop(32, padding=4), 154 | transforms.ToTensor(), 155 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 156 | ])) 157 | 158 | dataset = SVHN(root=root, split=split, transform=transform, target_list=target_list) 159 | return dataset 160 | 161 | 162 | def SVHNLoader(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True, target_list=range(5)): 163 | dataset = SVHNData(root, split, aug,target_list) 164 | loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 165 | return loader 166 | 167 | def SVHNLoaderMix(root, batch_size, split='train',num_workers=2, aug=None, shuffle=True, labeled_list=range(5), unlabeled_list=range(5, 10)): 168 | dataset_labeled = SVHNData(root, split, aug, labeled_list) 169 | dataset_unlabeled = SVHNData(root, split, aug, unlabeled_list) 170 | dataset_labeled.labels = np.concatenate((dataset_labeled.labels,dataset_unlabeled.labels)) 171 | dataset_labeled.data = np.concatenate((dataset_labeled.data,dataset_unlabeled.data),0) 172 | loader = data.DataLoader(dataset_labeled, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 173 | return loader 174 | 175 | def SVHNLoaderTwoStream(root, batch_size, split='train',num_workers=2, aug=None, shuffle=True, labeled_list=range(5), unlabeled_list=range(5, 10), unlabeled_batch_size=64): 176 | dataset_labeled = SVHNData(root, split, aug, labeled_list) 177 | dataset_unlabeled = SVHNData(root, split, aug, unlabeled_list) 178 | dataset = ConcatDataset((dataset_labeled, dataset_unlabeled)) 179 | labeled_idxs = range(len(dataset_labeled)) 180 | unlabeled_idxs = range(len(dataset_labeled), len(dataset_labeled)+len(dataset_unlabeled)) 181 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, unlabeled_batch_size) 182 | loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers) 183 | loader.labeled_length = len(dataset_labeled) 184 | loader.unlabeled_length = len(dataset_unlabeled) 185 | return loader 186 | 187 | 188 | -------------------------------------------------------------------------------- /auto_novel_omniglot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import Adam, lr_scheduler 5 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 6 | from sklearn.metrics import adjusted_rand_score as ari_score 7 | from utils.util import BCE, PairEnum, cluster_acc, Identity, AverageMeter, seed_torch, str2bool 8 | from utils import ramps 9 | from data.omniglotloader import OmniglotLoaderMix, alphabetLoader, omniglot_evaluation_alphabets_mapping 10 | from tqdm import tqdm 11 | import numpy as np 12 | import os 13 | 14 | class VGG(nn.Module): 15 | 16 | def __init__(self, num_labeled_classes=5, num_unlabeled_classes=5): 17 | super(VGG, self).__init__() 18 | self.layer1 = nn.Sequential( 19 | nn.Conv2d(1, 64, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(64), 21 | nn.ReLU(inplace=True), 22 | nn.MaxPool2d(kernel_size=2, stride=2) 23 | ) 24 | self.layer2 = nn.Sequential( 25 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 26 | nn.BatchNorm2d(128), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=2, stride=2) 29 | ) 30 | self.layer3 = nn.Sequential( 31 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 32 | nn.BatchNorm2d(256), 33 | nn.ReLU(inplace=True), 34 | nn.MaxPool2d(kernel_size=2, stride=2) 35 | ) 36 | self.layer4 = nn.Sequential( 37 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 38 | nn.BatchNorm2d(256), 39 | nn.ReLU(inplace=True), 40 | nn.MaxPool2d(kernel_size=2, stride=2) 41 | ) 42 | self.head1 = nn.Sequential( 43 | nn.Linear(1024, 512), 44 | nn.BatchNorm1d(512), 45 | nn.ReLU(inplace=True), 46 | nn.Linear(512, num_labeled_classes) 47 | ) 48 | self.head2 = nn.Sequential( 49 | nn.Linear(1024, 512), 50 | nn.BatchNorm1d(512), 51 | nn.ReLU(inplace=True), 52 | nn.Linear(512, num_unlabeled_classes) 53 | ) 54 | 55 | def forward(self, x): 56 | x = self.layer1(x) 57 | x = self.layer2(x) 58 | x = self.layer3(x) 59 | x = self.layer4(x) 60 | x = x.view(x.size(0), -1) 61 | out1 = self.head1(x) 62 | out2 = self.head2(x) 63 | return out1, out2, x 64 | 65 | def train(model, train_loader, unlabeled_eval_loader, args): 66 | optimizer = Adam(model.parameters(), lr=args.lr) 67 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 68 | criterion1 = nn.CrossEntropyLoss() 69 | criterion2 = BCE() 70 | for epoch in range(args.epochs): 71 | loss_record = AverageMeter() 72 | model.train() 73 | exp_lr_scheduler.step() 74 | w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 75 | for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)): 76 | x, x_bar, label = x.to(device), x_bar.to(device), label.to(device) 77 | output1, output2, feat = model(x) 78 | output1_bar, output2_bar, _ = model(x_bar) 79 | prob1, prob1_bar, prob2, prob2_bar=F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1) 80 | 81 | mask_lb = idx0] = -1 95 | 96 | prob1_ulb, _= PairEnum(prob2[~mask_lb]) 97 | _, prob2_ulb = PairEnum(prob2_bar[~mask_lb]) 98 | 99 | loss_ce = criterion1(output1[mask_lb], label[mask_lb]) 100 | loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb) 101 | 102 | consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar) 103 | 104 | loss = loss_ce + loss_bce + w * consistency_loss 105 | 106 | loss_record.update(loss.item(), x.size(0)) 107 | optimizer.zero_grad() 108 | loss.backward() 109 | optimizer.step() 110 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg)) 111 | print('test on unlabeled classes') 112 | args.head='head2' 113 | test(model, unlabeled_eval_loader, args) 114 | 115 | 116 | def test(model, test_loader, args): 117 | model.eval() 118 | acc_record = AverageMeter() 119 | preds=np.array([]) 120 | targets=np.array([]) 121 | for batch_idx, (x, label, _) in enumerate(tqdm(test_loader)): 122 | x, label = x.to(device), label.to(device) 123 | output1, output2, _ = model(x) 124 | if args.head=='head1': 125 | output = output1 126 | else: 127 | output = output2 128 | _, pred = output.max(1) 129 | targets=np.append(targets, label.cpu().numpy()) 130 | preds=np.append(preds, pred.cpu().numpy()) 131 | acc, nmi, ari = cluster_acc(targets.astype(int), preds.astype(int)), nmi_score(targets, preds), ari_score(targets, preds) 132 | print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari)) 133 | return acc, nmi, ari 134 | 135 | def copy_param(model, pre_dict): 136 | new=list(pre_dict.items()) 137 | dict_len = len(pre_dict.items()) 138 | model_kvpair=model.state_dict() 139 | count=0 140 | for key, value in model_kvpair.items(): 141 | if count < dict_len: 142 | layer_name,weights=new[count] 143 | model_kvpair[key]=weights 144 | count+=1 145 | else: 146 | break 147 | model.load_state_dict(model_kvpair) 148 | return model 149 | 150 | 151 | if __name__ == "__main__": 152 | import argparse 153 | parser = argparse.ArgumentParser( 154 | description='cluster', 155 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 156 | parser.add_argument('--lr', type=float, default=0.01) 157 | parser.add_argument('--gamma', type=float, default=0.1) 158 | parser.add_argument('--epochs', default=100, type=int) 159 | parser.add_argument('--rampup_length', default=1, type=int) 160 | parser.add_argument('--rampup_coefficient', type=float, default=100.0) 161 | parser.add_argument('--step_size', default=100, type=int) 162 | parser.add_argument('--batch_size', default=100, type=int) 163 | parser.add_argument('--num_labeled_classes', default=964, type=int) 164 | parser.add_argument('--num_unlabeled_classes', default=20, type=int) 165 | parser.add_argument('--topk', default=5, type=int) 166 | parser.add_argument('--dataset_root', type=str, default='./data/datasets') 167 | parser.add_argument('--exp_root', type=str, default='./data/experiments') 168 | parser.add_argument('--warmup_model_dir', type=str, default='./data/experiments/pretrained/vgg6_omniglot_proto.pth') 169 | parser.add_argument('--model_name', type=str, default='vgg6') 170 | parser.add_argument('--seed', default=0, type=int) 171 | parser.add_argument('--mode', type=str, default='train') 172 | 173 | args = parser.parse_args() 174 | args.cuda = torch.cuda.is_available() 175 | device = torch.device("cuda" if args.cuda else "cpu") 176 | 177 | seed_torch(args.seed) 178 | runner_name = os.path.basename(__file__).split(".")[0] 179 | model_dir= os.path.join(args.exp_root, runner_name) 180 | if not os.path.exists(model_dir): 181 | os.makedirs(model_dir) 182 | 183 | if args.mode == 'train': 184 | state_dict = torch.load(args.warmup_model_dir) 185 | 186 | acc = {} 187 | nmi = {} 188 | ari = {} 189 | 190 | for _, alphabetStr in omniglot_evaluation_alphabets_mapping.items(): 191 | 192 | mix_train_loader= OmniglotLoaderMix(alphabet=alphabetStr, batch_size=args.batch_size, aug='twice', shuffle=True, num_workers=2, root=args.dataset_root, unlabeled_batch_size=32) 193 | unlabeled_eval_loader= alphabetLoader(root=args.dataset_root, batch_size=args.batch_size, alphabet=alphabetStr, subfolder_name='images_evaluation', aug=None, num_workers=2, shuffle=False) 194 | args.num_unlabeled_classes = unlabeled_eval_loader.num_classes 195 | args.model_dir = model_dir+'/'+'{}_{}.pth'.format(args.model_name, alphabetStr) 196 | 197 | model = VGG(num_labeled_classes=args.num_labeled_classes, num_unlabeled_classes=args.num_unlabeled_classes).to(device) 198 | 199 | if args.mode == 'train': 200 | model = copy_param(model, state_dict) 201 | for name, param in model.named_parameters(): 202 | if 'head' not in name and 'layer4' not in name: 203 | param.requires_grad = False 204 | train(model, mix_train_loader, unlabeled_eval_loader, args) 205 | torch.save(model.state_dict(), args.model_dir) 206 | print("model saved to {}.".format(args.model_dir)) 207 | elif args.mode == 'test': 208 | print("model loaded from {}.".format(args.model_dir)) 209 | model.load_state_dict(torch.load(args.model_dir)) 210 | print('test on unlabeled classes') 211 | args.head = 'head2' 212 | acc[alphabetStr], nmi[alphabetStr], ari[alphabetStr] = test(model, unlabeled_eval_loader, args) 213 | print('ACC for all alphabets:',acc) 214 | print('NMI for all alphabets:',nmi) 215 | print('ARI for all alphabets:',ari) 216 | avg_acc, avg_nmi, avg_ari = sum(acc.values())/float(len(acc)), sum(nmi.values())/float(len(nmi)), sum(ari.values())/float(len(ari)) 217 | print('AVG: acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(avg_acc, avg_nmi, avg_ari)) 218 | 219 | -------------------------------------------------------------------------------- /auto_novel_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import SGD, lr_scheduler 5 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 6 | from sklearn.metrics import adjusted_rand_score as ari_score 7 | from utils.util import BCE, PairEnum, cluster_acc, Identity, AverageMeter, seed_torch 8 | from utils import ramps 9 | from torchvision.models.resnet import BasicBlock 10 | from data.imagenetloader import ImageNetLoader30, ImageNetLoader882_30Mix, ImageNetLoader882 11 | from tqdm import tqdm 12 | import numpy as np 13 | import math 14 | import os 15 | import warnings 16 | warnings.filterwarnings("ignore", category=UserWarning) 17 | 18 | class ResNet(nn.Module): 19 | 20 | def __init__(self, block, layers, num_labeled_classes=10, num_unlabeled_classes=10): 21 | self.inplanes = 64 22 | super(ResNet, self).__init__() 23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 24 | bias=False) 25 | self.bn1 = nn.BatchNorm2d(64) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | self.layer1 = self._make_layer(block, 64, layers[0]) 29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 30 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 31 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 32 | self.avgpool = nn.AvgPool2d(7, stride=1) 33 | self.head1= nn.Linear(512 * block.expansion, num_labeled_classes) 34 | self.head2= nn.Linear(512 * block.expansion, num_unlabeled_classes) 35 | 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 39 | m.weight.data.normal_(0, math.sqrt(2. / n)) 40 | elif isinstance(m, nn.BatchNorm2d): 41 | m.weight.data.fill_(1) 42 | m.bias.data.zero_() 43 | 44 | def _make_layer(self, block, planes, blocks, stride=1): 45 | downsample = None 46 | if stride != 1 or self.inplanes != planes * block.expansion: 47 | downsample = nn.Sequential( 48 | nn.Conv2d(self.inplanes, planes * block.expansion, 49 | kernel_size=1, stride=stride, bias=False), 50 | nn.BatchNorm2d(planes * block.expansion), 51 | ) 52 | 53 | layers = [] 54 | layers.append(block(self.inplanes, planes, stride, downsample)) 55 | self.inplanes = planes * block.expansion 56 | for i in range(1, blocks): 57 | layers.append(block(self.inplanes, planes)) 58 | 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | x = self.conv1(x) 63 | x = self.bn1(x) 64 | x = self.relu(x) 65 | x = self.maxpool(x) 66 | 67 | x = self.layer1(x) 68 | x = self.layer2(x) 69 | x = self.layer3(x) 70 | x = self.layer4(x) 71 | 72 | x = self.avgpool(x) 73 | x = x.view(x.size(0), -1) 74 | out1 = self.head1(x) 75 | out2 = self.head2(x) 76 | 77 | return out1, out2, x 78 | 79 | 80 | def train(model, train_loader, labeled_eval_loader, unlabeled_eval_loader, args): 81 | optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 82 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 83 | criterion1 = nn.CrossEntropyLoss() 84 | criterion2 = BCE() 85 | for epoch in range(args.epochs): 86 | loss_record = AverageMeter() 87 | model.train() 88 | exp_lr_scheduler.step() 89 | w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 90 | for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)): 91 | x, x_bar, label = x.to(device), x_bar.to(device), label.to(device) 92 | output1, output2, feat = model(x) 93 | output1_bar, output2_bar, _ = model(x_bar) 94 | prob1, prob1_bar, prob2, prob2_bar=F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1) 95 | 96 | mask_lb = idx0] = -1 109 | 110 | prob1_ulb, _= PairEnum(prob2[~mask_lb]) 111 | _, prob2_ulb = PairEnum(prob2_bar[~mask_lb]) 112 | 113 | loss_ce = criterion1(output1[mask_lb], label[mask_lb]) 114 | loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb) 115 | 116 | consistency_loss = (F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)) 117 | 118 | loss = loss_ce + loss_bce + w * consistency_loss 119 | 120 | loss_record.update(loss.item(), x.size(0)) 121 | optimizer.zero_grad() 122 | loss.backward() 123 | optimizer.step() 124 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg)) 125 | print('test on labeled classes') 126 | args.head = 'head1' 127 | test(model, labeled_eval_loader, args) 128 | print('test on unlabeled classes') 129 | args.head='head2' 130 | test(model, unlabeled_eval_loader, args) 131 | 132 | 133 | def test(model, test_loader, args): 134 | model.eval() 135 | preds=np.array([]) 136 | targets=np.array([]) 137 | for batch_idx, (x, label, _) in enumerate(tqdm(test_loader)): 138 | x, label = x.to(device), label.to(device) 139 | output1, output2, _ = model(x) 140 | if args.head=='head1': 141 | output = output1 142 | else: 143 | output = output2 144 | _, pred = output.max(1) 145 | targets=np.append(targets, label.cpu().numpy()) 146 | preds=np.append(preds, pred.cpu().numpy()) 147 | acc, nmi, ari = cluster_acc(targets.astype(int), preds.astype(int)), nmi_score(targets, preds), ari_score(targets, preds) 148 | print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari)) 149 | 150 | def copy_param(model, pretrain_dir): 151 | pre_dict = torch.load(pretrain_dir) 152 | new=list(pre_dict.items()) 153 | dict_len = len(pre_dict.items()) 154 | model_kvpair=model.state_dict() 155 | count=0 156 | for key, value in model_kvpair.items(): 157 | if count < dict_len: 158 | layer_name,weights=new[count] 159 | model_kvpair[key]=weights 160 | count+=1 161 | else: 162 | break 163 | model.load_state_dict(model_kvpair) 164 | return model 165 | 166 | if __name__ == "__main__": 167 | import argparse 168 | parser = argparse.ArgumentParser( 169 | description='cluster', 170 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 171 | parser.add_argument('--device_ids', default=[0], type=int, nargs='+', 172 | help='device ids assignment (e.g 0 1 2 3)') 173 | parser.add_argument('--lr', type=float, default=0.1) 174 | parser.add_argument('--gamma', type=float, default=0.1) 175 | parser.add_argument('--momentum', type=float, default=0.9) 176 | parser.add_argument('--weight_decay', type=float, default=1e-4) 177 | parser.add_argument('--epochs', default=90, type=int) 178 | parser.add_argument('--rampup_length', default=50, type=int) 179 | parser.add_argument('--rampup_coefficient', type=float, default=10.0) 180 | parser.add_argument('--step_size', default=30, type=int) 181 | parser.add_argument('--batch_size', default=512, type=int) 182 | parser.add_argument('--unlabeled_batch_size', default=128, type=int) 183 | parser.add_argument('--num_labeled_classes', default=882, type=int) 184 | parser.add_argument('--num_unlabeled_classes', default=30, type=int) 185 | parser.add_argument('--dataset_root', type=str, default='./data/datasets/ImageNet/') 186 | parser.add_argument('--exp_root', type=str, default='./data/experiments/') 187 | parser.add_argument('--warmup_model_dir', type=str, default='./data/experiments/pretrained/resnet18_imagenet_classif_882_ICLR18.pth') 188 | parser.add_argument('--topk', default=5, type=int) 189 | parser.add_argument('--model_name', type=str, default='resnet') 190 | parser.add_argument('--seed', default=1, type=int) 191 | parser.add_argument('--unlabeled_subset', type=str, default='A') 192 | parser.add_argument('--mode', type=str, default='train') 193 | 194 | args = parser.parse_args() 195 | args.cuda = torch.cuda.is_available() 196 | device = torch.device("cuda" if args.cuda else "cpu") 197 | seed_torch(args.seed) 198 | runner_name = os.path.basename(__file__).split(".")[0] 199 | model_dir= os.path.join(args.exp_root, runner_name) 200 | if not os.path.exists(model_dir): 201 | os.makedirs(model_dir) 202 | args.model_dir = model_dir+'/'+'{}_{}.pth'.format(args.model_name, args.unlabeled_subset) 203 | 204 | model = ResNet(BasicBlock, [2,2,2,2], args.num_labeled_classes, args.num_unlabeled_classes) 205 | model = nn.DataParallel(model, args.device_ids).to(device) 206 | model = copy_param(model, args.warmup_model_dir) 207 | 208 | for name, param in model.named_parameters(): 209 | if 'head' not in name and 'layer4' not in name: 210 | param.requires_grad = False 211 | 212 | mix_train_loader = ImageNetLoader882_30Mix(args.batch_size, num_workers=8, path=args.dataset_root, unlabeled_subset=args.unlabeled_subset, aug='twice', shuffle=True, subfolder='train', unlabeled_batch_size=args.unlabeled_batch_size) 213 | labeled_eval_loader = ImageNetLoader882(args.batch_size, num_workers=8, path=args.dataset_root, aug=None, shuffle=False, subfolder='val') 214 | unlabeled_eval_loader = ImageNetLoader30(args.batch_size, num_workers=8, path=args.dataset_root, subset=args.unlabeled_subset, aug=None, shuffle=False, subfolder='train') 215 | 216 | if args.mode == 'train': 217 | train(model, mix_train_loader, labeled_eval_loader, unlabeled_eval_loader, args) 218 | torch.save(model.state_dict(), args.model_dir) 219 | print("model saved to {}.".format(args.model_dir)) 220 | else: 221 | print("model loaded from {}.".format(args.model_dir)) 222 | model.load_state_dict(torch.load(args.model_dir)) 223 | 224 | print('test on labeled classes') 225 | args.head = 'head1' 226 | test(model, labeled_eval_loader, args) 227 | print('test on unlabeled classes') 228 | args.head = 'head2' 229 | test(model, unlabeled_eval_loader, args) 230 | -------------------------------------------------------------------------------- /data/omniglotloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import transforms 4 | import torch.utils.data as data 5 | from .omniglot import Omniglot 6 | from .utils import TransformTwice, TransformKtimes, RandomTranslateWithReflect, TwoStreamBatchSampler 7 | from .concat import ConcatDataset 8 | 9 | 10 | def OmniglotLoader(root, batch_size, subfolder_name='images_background', num_workers=2, aug=None, shuffle=True): 11 | binary_flip = transforms.Lambda(lambda x: 1 - x) 12 | normalize = transforms.Normalize((0.086,), (0.235,)) 13 | if aug==None: 14 | transform=transforms.Compose([ 15 | transforms.Resize(32), 16 | transforms.ToTensor(), 17 | binary_flip, 18 | normalize 19 | ]) 20 | elif aug=='once': 21 | transform=transforms.Compose([ 22 | transforms.RandomResizedCrop(32, (0.85, 1.)), 23 | transforms.ToTensor(), 24 | binary_flip, 25 | normalize 26 | ]) 27 | elif aug=='twice': 28 | transform = TransformTwice(transforms.Compose([ 29 | transforms.RandomResizedCrop(32, (0.85, 1.)), 30 | transforms.ToTensor(), 31 | binary_flip, 32 | normalize 33 | ])) 34 | elif aug=='ktimes': 35 | transform = TransformKtimes(transforms.Compose([ 36 | transforms.RandomResizedCrop(32, (0.85, 1.)), 37 | transforms.RandomAffine(degrees = (-5, 5), translate=(0.1, 0.1), scale=(0.8, 1.2), shear = (-10, 10), fillcolor=255), 38 | transforms.ToTensor(), 39 | binary_flip, 40 | normalize 41 | ]), k=10) 42 | 43 | dataset = Omniglot(root=root, subfolder_name=subfolder_name, transform=transform) 44 | loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 45 | return loader 46 | 47 | def alphabetLoader(root, alphabet, batch_size, subfolder_name='images_evaluation', aug=None, num_workers=2, shuffle=False): 48 | binary_flip = transforms.Lambda(lambda x: 1 - x) 49 | normalize = transforms.Normalize((0.086,), (0.235,)) 50 | if aug==None: 51 | transform=transforms.Compose([ 52 | transforms.Resize(32), 53 | transforms.ToTensor(), 54 | binary_flip, 55 | normalize 56 | ]) 57 | elif aug=='once': 58 | transform=transforms.Compose([ 59 | transforms.RandomResizedCrop(32, (0.85, 1.)), 60 | transforms.ToTensor(), 61 | binary_flip, 62 | normalize 63 | ]) 64 | elif aug=='twice': 65 | transform = TransformTwice(transforms.Compose([ 66 | transforms.RandomResizedCrop(32, (0.85, 1.)), 67 | transforms.ToTensor(), 68 | binary_flip, 69 | normalize 70 | ])) 71 | elif aug=='ktimes': 72 | transform = TransformKtimes(transforms.Compose([ 73 | transforms.RandomResizedCrop(32, (0.85, 1.)), 74 | transforms.RandomAffine(degrees = (-5, 5), translate=(0.1, 0.1), scale=(0.8, 1.2), shear = (-10, 10), fillcolor=255), 75 | transforms.ToTensor(), 76 | binary_flip, 77 | normalize 78 | ]), k=10) 79 | 80 | dataset = Omniglot(root=root, subfolder_name=subfolder_name, transform=transform) 81 | # Only use the images which has alphabet-name in their path name (_characters[cid]) 82 | valid_flat_character_images = [(imgname,cid) for imgname,cid in dataset._flat_character_images if alphabet in dataset._characters[cid]] 83 | ndata = len(valid_flat_character_images) # The number of data after filtering 84 | imgid2cid = [valid_flat_character_images[i][1] for i in range(ndata)] # The tuple (valid_flat_character_images[i]) are (img, cid) 85 | cid_set = set(imgid2cid) # The labels are not 0..c-1 here. 86 | cid2ncid = {cid:ncid for ncid,cid in enumerate(cid_set)} # Create the mapping table for New cid (ncid) 87 | valid_characters = {cid2ncid[cid]:dataset._characters[cid] for cid in cid_set} 88 | for i in range(ndata): # Convert the labels to make sure it has the value {0..c-1} 89 | valid_flat_character_images[i] = (valid_flat_character_images[i][0],cid2ncid[valid_flat_character_images[i][1]]) 90 | # Apply surgery to the dataset 91 | dataset._flat_character_images = valid_flat_character_images 92 | dataset._characters = valid_characters 93 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 94 | loader.num_classes = len(cid_set) 95 | print('=> Alphabet %s has %d characters and %d images.'%(alphabet, loader.num_classes, len(dataset))) 96 | return loader 97 | 98 | def OmniglotLoaderMix(root, alphabet, batch_size, num_workers=2, aug=None, shuffle=False, unlabeled_batch_size=32): 99 | binary_flip = transforms.Lambda(lambda x: 1 - x) 100 | normalize = transforms.Normalize((0.086,), (0.235,)) 101 | if aug==None: 102 | transform=transforms.Compose([ 103 | transforms.Resize(32), 104 | transforms.ToTensor(), 105 | binary_flip, 106 | normalize 107 | ]) 108 | elif aug=='once': 109 | transform=transforms.Compose([ 110 | transforms.RandomResizedCrop(32, (0.85, 1.)), 111 | transforms.ToTensor(), 112 | binary_flip, 113 | normalize 114 | ]) 115 | elif aug=='twice': 116 | transform = TransformTwice(transforms.Compose([ 117 | transforms.RandomResizedCrop(32, (0.85, 1.)), 118 | transforms.ToTensor(), 119 | binary_flip, 120 | normalize 121 | ])) 122 | elif aug=='ktimes': 123 | transform = TransformKtimes(transforms.Compose([ 124 | transforms.RandomResizedCrop(32, (0.85, 1.)), 125 | transforms.RandomAffine(degrees = (-5, 5), translate=(0.1, 0.1), scale=(0.8, 1.2), shear = (-10, 10), fillcolor=255), 126 | transforms.ToTensor(), 127 | binary_flip, 128 | normalize 129 | ]), k=10) 130 | 131 | dataset_labeled = Omniglot(root=root, subfolder_name='images_background', transform=transform) 132 | dataset_unlabeled = Omniglot(root=root, subfolder_name='images_evaluation', transform=transform) 133 | # Only use the images which has alphabet-name in their path name (_characters[cid]) 134 | valid_flat_character_images = [(imgname, cid) for imgname,cid in dataset_unlabeled._flat_character_images if alphabet in dataset_unlabeled._characters[cid]] 135 | ndata = len(valid_flat_character_images) # The number of data after filtering 136 | imgid2cid = [valid_flat_character_images[i][1] for i in range(ndata)] # The tuple (valid_flat_character_images[i]) are (img, cid) 137 | cid_set = set(imgid2cid) # The labels are not 0..c-1 here. 138 | cid2ncid = {cid:ncid for ncid,cid in enumerate(cid_set)} # Create the mapping table for New cid (ncid) 139 | valid_characters = {cid2ncid[cid]:dataset_unlabeled._characters[cid] for cid in cid_set} 140 | for i in range(ndata): # Convert the labels to make sure it has the value {0..c-1} 141 | valid_flat_character_images[i] = (valid_flat_character_images[i][0],cid2ncid[valid_flat_character_images[i][1]]) 142 | # Apply surgery to the dataset 143 | dataset_unlabeled._flat_character_images = valid_flat_character_images 144 | dataset_unlabeled._characters = valid_characters 145 | dataset= ConcatDataset((dataset_labeled, dataset_unlabeled)) 146 | labeled_idxs = range(len(dataset_labeled)) 147 | unlabeled_idxs = range(len(dataset_labeled), len(dataset_labeled)+len(dataset_unlabeled)) 148 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, unlabeled_batch_size) 149 | loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers) 150 | loader.labeled_length = len(dataset_labeled) 151 | loader.unlabeled_length = len(dataset_unlabeled) 152 | return loader 153 | 154 | 155 | def alphabetData(root, alphabet, batch_size, subfolder_name='images_evaluation', aug=None): 156 | binary_flip = transforms.Lambda(lambda x: 1 - x) 157 | normalize = transforms.Normalize((0.086,), (0.235,)) 158 | if aug==None: 159 | transform=transforms.Compose([ 160 | transforms.Resize(32), 161 | transforms.ToTensor(), 162 | binary_flip, 163 | normalize 164 | ]) 165 | elif aug=='once': 166 | transform=transforms.Compose([ 167 | transforms.RandomResizedCrop(32, (0.85, 1.)), 168 | transforms.ToTensor(), 169 | binary_flip, 170 | normalize 171 | ]) 172 | elif aug=='twice': 173 | transform = TransformTwice(transforms.Compose([ 174 | transforms.RandomResizedCrop(32, (0.85, 1.)), 175 | transforms.ToTensor(), 176 | binary_flip, 177 | normalize 178 | ])) 179 | elif aug=='ktimes': 180 | transform = TransformKtimes(transforms.Compose([ 181 | transforms.RandomResizedCrop(32, (0.85, 1.)), 182 | transforms.RandomAffine(degrees = (-5, 5), translate=(0.1, 0.1), scale=(0.8, 1.2), shear = (-10, 10), fillcolor=255), 183 | transforms.ToTensor(), 184 | binary_flip, 185 | normalize 186 | ]), k=10) 187 | 188 | dataset = Omniglot(root=root, subfolder_name=subfolder_name, transform=transform) 189 | # Only use the images which has alphabet-name in their path name (_characters[cid]) 190 | valid_flat_character_images = [(imgname,cid) for imgname,cid in dataset._flat_character_images if alphabet in dataset._characters[cid]] 191 | ndata = len(valid_flat_character_images) # The number of data after filtering 192 | imgid2cid = [valid_flat_character_images[i][1] for i in range(ndata)] # The tuple (valid_flat_character_images[i]) are (img, cid) 193 | cid_set = set(imgid2cid) # The labels are not 0..c-1 here. 194 | cid2ncid = {cid:ncid for ncid,cid in enumerate(cid_set)} # Create the mapping table for New cid (ncid) 195 | valid_characters = {cid2ncid[cid]:dataset._characters[cid] for cid in cid_set} 196 | for i in range(ndata): # Convert the labels to make sure it has the value {0..c-1} 197 | valid_flat_character_images[i] = (valid_flat_character_images[i][0],cid2ncid[valid_flat_character_images[i][1]]) 198 | # Apply surgery to the dataset 199 | dataset._flat_character_images = valid_flat_character_images 200 | dataset._characters = valid_characters 201 | num_classes = len(cid_set) 202 | print('=> Alphabet %s has %d characters and %d images.'%(alphabet, num_classes, len(dataset))) 203 | return dataset, num_classes 204 | 205 | 206 | def alphabetLoaderMix(root, labeled_alphabet, unlabeled_alphabet, batch_size, num_workers=2, aug=None, shuffle=False, unlabeled_batch_size=64): 207 | dataset_labeled, num_labeled_classes = alphabetData(root, labeled_alphabet, batch_size, subfolder_name='images_background', aug=aug) 208 | dataset_unlabeled, num_unlabeled_classes = alphabetData(root, unlabeled_alphabet, batch_size, subfolder_name='images_evaluation', aug=aug) 209 | dataset= ConcatDataset((dataset_labeled, dataset_unlabeled)) 210 | labeled_idxs = range(len(dataset_labeled)) 211 | unlabeled_idxs = range(len(dataset_labeled), len(dataset_labeled)+len(dataset_unlabeled)) 212 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, unlabeled_batch_size) 213 | loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers) 214 | loader.num_labeled_classes = num_labeled_classes 215 | loader.num_unlabeled_classes = num_unlabeled_classes 216 | loader.labeled_length = len(dataset_labeled) 217 | loader.unlabeled_length = len(dataset_unlabeled) 218 | return loader 219 | 220 | 221 | omniglot_background_alphabets=[ 222 | 'Alphabet_of_the_Magi', 223 | 'Gujarati', 224 | 'Anglo-Saxon_Futhorc', 225 | 'Hebrew', 226 | 'Arcadian', 227 | 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 228 | 'Armenian', 229 | 'Japanese_(hiragana)', 230 | 'Asomtavruli_(Georgian)', 231 | 'Japanese_(katakana)', 232 | 'Balinese', 233 | 'Korean', 234 | 'Bengali', 235 | 'Latin', 236 | 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 237 | 'Malay_(Jawi_-_Arabic)', 238 | 'Braille', 239 | 'Mkhedruli_(Georgian)', 240 | 'Burmese_(Myanmar)', 241 | 'N_Ko', 242 | 'Cyrillic', 243 | 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 244 | 'Early_Aramaic', 245 | 'Sanskrit', 246 | 'Futurama', 247 | 'Syriac_(Estrangelo)', 248 | 'Grantha', 249 | 'Tagalog', 250 | 'Greek', 251 | 'Tifinagh' 252 | ] 253 | 254 | omniglot_evaluation_alphabets_mapping = { 255 | 'Malayalam':'Malayalam', 256 | 'Kannada':'Kannada', 257 | 'Syriac':'Syriac_(Serto)', 258 | 'Atemayar_Qelisayer':'Atemayar_Qelisayer', 259 | 'Gurmukhi':'Gurmukhi', 260 | 'Old_Church_Slavonic':'Old_Church_Slavonic_(Cyrillic)', 261 | 'Manipuri':'Manipuri', 262 | 'Atlantean':'Atlantean', 263 | 'Sylheti':'Sylheti', 264 | 'Mongolian':'Mongolian', 265 | 'Aurek':'Aurek-Besh', 266 | 'Angelic':'Angelic', 267 | 'ULOG':'ULOG', 268 | 'Oriya':'Oriya', 269 | 'Avesta':'Avesta', 270 | 'Tibetan':'Tibetan', 271 | 'Tengwar':'Tengwar', 272 | 'Keble':'Keble', 273 | 'Ge_ez':'Ge_ez', 274 | 'Glagolitic':'Glagolitic' 275 | } 276 | -------------------------------------------------------------------------------- /data/cifarloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | if sys.version_info[0] == 2: 8 | import cPickle as pickle 9 | else: 10 | import pickle 11 | 12 | import random 13 | import torch 14 | import torch.utils.data as data 15 | from .utils import download_url, check_integrity 16 | from .utils import TransformTwice, TransformKtimes, RandomTranslateWithReflect, TwoStreamBatchSampler 17 | from .concat import ConcatDataset 18 | import torchvision.transforms as transforms 19 | 20 | class CIFAR10(data.Dataset): 21 | """`CIFAR10 `_ Dataset. 22 | 23 | Args: 24 | root (string): Root directory of dataset where directory 25 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 26 | train (bool, optional): If True, creates dataset from training set, otherwise 27 | creates from test set. 28 | transform (callable, optional): A function/transform that takes in an PIL image 29 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 30 | target_transform (callable, optional): A function/transform that takes in the 31 | target and transforms it. 32 | download (bool, optional): If true, downloads the dataset from the internet and 33 | puts it in root directory. If dataset is already downloaded, it is not 34 | downloaded again. 35 | 36 | """ 37 | base_folder = 'cifar-10-batches-py' 38 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 39 | filename = "cifar-10-python.tar.gz" 40 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 41 | train_list = [ 42 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 43 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 44 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 45 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 46 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 47 | ] 48 | 49 | test_list = [ 50 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 51 | ] 52 | meta = { 53 | 'filename': 'batches.meta', 54 | 'key': 'label_names', 55 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 56 | } 57 | 58 | def __init__(self, root, split='train+test', 59 | transform=None, target_transform=None, 60 | download=False, target_list = range(5)): 61 | self.root = os.path.expanduser(root) 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | downloaded_list = [] 72 | if split=='train': 73 | downloaded_list = self.train_list 74 | elif split=='test': 75 | downloaded_list = self.test_list 76 | elif split=='train+test': 77 | downloaded_list.extend(self.train_list) 78 | downloaded_list.extend(self.test_list) 79 | 80 | self.data = [] 81 | self.targets = [] 82 | 83 | # now load the picked numpy arrays 84 | for file_name, checksum in downloaded_list: 85 | file_path = os.path.join(self.root, self.base_folder, file_name) 86 | with open(file_path, 'rb') as f: 87 | if sys.version_info[0] == 2: 88 | entry = pickle.load(f) 89 | else: 90 | entry = pickle.load(f, encoding='latin1') 91 | self.data.append(entry['data']) 92 | if 'labels' in entry: 93 | self.targets.extend(entry['labels']) 94 | else: 95 | # self.targets.extend(entry['coarse_labels']) 96 | self.targets.extend(entry['fine_labels']) 97 | 98 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 99 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 100 | self._load_meta() 101 | 102 | ind = [i for i in range(len(self.targets)) if self.targets[i] in target_list] 103 | 104 | self.data = self.data[ind] 105 | self.targets = np.array(self.targets) 106 | self.targets = self.targets[ind].tolist() 107 | 108 | 109 | 110 | def _load_meta(self): 111 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 112 | if not check_integrity(path, self.meta['md5']): 113 | raise RuntimeError('Dataset metadata file not found or corrupted.' + 114 | ' You can use download=True to download it') 115 | with open(path, 'rb') as infile: 116 | if sys.version_info[0] == 2: 117 | data = pickle.load(infile) 118 | else: 119 | data = pickle.load(infile, encoding='latin1') 120 | self.classes = data[self.meta['key']] 121 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 122 | # x = self.class_to_idx 123 | # sorted_x = sorted(x.items(), key=lambda kv: kv[1]) 124 | # print(sorted_x) 125 | 126 | def __getitem__(self, index): 127 | """ 128 | Args: 129 | index (int): Index 130 | 131 | Returns: 132 | tuple: (image, target) where target is index of the target class. 133 | """ 134 | img, target = self.data[index], self.targets[index] 135 | 136 | # doing this so that it is consistent with all other datasets 137 | # to return a PIL Image 138 | img = Image.fromarray(img) 139 | 140 | if self.transform is not None: 141 | img = self.transform(img) 142 | 143 | if self.target_transform is not None: 144 | target = self.target_transform(target) 145 | 146 | return img, target, index 147 | 148 | def __len__(self): 149 | return len(self.data) 150 | 151 | def _check_integrity(self): 152 | root = self.root 153 | for fentry in (self.train_list + self.test_list): 154 | filename, md5 = fentry[0], fentry[1] 155 | fpath = os.path.join(root, self.base_folder, filename) 156 | if not check_integrity(fpath, md5): 157 | return False 158 | return True 159 | 160 | def download(self): 161 | import tarfile 162 | 163 | if self._check_integrity(): 164 | print('Files already downloaded and verified') 165 | return 166 | 167 | download_url(self.url, self.root, self.filename, self.tgz_md5) 168 | 169 | # extract file 170 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 171 | tar.extractall(path=self.root) 172 | 173 | def __repr__(self): 174 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 175 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 176 | tmp = 'train' if self.train is True else 'test' 177 | fmt_str += ' Split: {}\n'.format(tmp) 178 | fmt_str += ' Root Location: {}\n'.format(self.root) 179 | tmp = ' Transforms (if any): ' 180 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 181 | tmp = ' Target Transforms (if any): ' 182 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 183 | return fmt_str 184 | 185 | 186 | class CIFAR100(CIFAR10): 187 | """`CIFAR100 `_ Dataset. 188 | 189 | This is a subclass of the `CIFAR10` Dataset. 190 | """ 191 | base_folder = 'cifar-100-python' 192 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 193 | filename = "cifar-100-python.tar.gz" 194 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 195 | train_list = [ 196 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 197 | ] 198 | 199 | test_list = [ 200 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 201 | ] 202 | meta = { 203 | 'filename': 'meta', 204 | 'key': 'fine_label_names', 205 | # 'key': 'coarse_label_names', 206 | 'md5': '7973b15100ade9c7d40fb424638fde48', 207 | } 208 | 209 | def CIFAR10Data(root, split='train', aug=None, target_list=range(5)): 210 | if aug==None: 211 | transform = transforms.Compose([ 212 | transforms.ToTensor(), 213 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 214 | ]) 215 | elif aug=='once': 216 | transform = transforms.Compose([ 217 | transforms.RandomCrop(32, padding=4), 218 | transforms.RandomHorizontalFlip(), 219 | transforms.ToTensor(), 220 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 221 | ]) 222 | elif aug=='twice': 223 | transform = TransformTwice(transforms.Compose([ 224 | RandomTranslateWithReflect(4), 225 | transforms.RandomHorizontalFlip(), 226 | transforms.ToTensor(), 227 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 228 | ])) 229 | dataset = CIFAR10(root=root, split=split, transform=transform, target_list=target_list) 230 | return dataset 231 | 232 | def CIFAR10Loader(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True, target_list=range(5)): 233 | dataset = CIFAR10Data(root, split, aug,target_list) 234 | loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 235 | return loader 236 | 237 | def CIFAR10LoaderMix(root, batch_size, split='train',num_workers=2, aug=None, shuffle=True, labeled_list=range(5), unlabeled_list=range(5, 10), new_labels=None): 238 | if aug==None: 239 | transform = transforms.Compose([ 240 | transforms.ToTensor(), 241 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 242 | ]) 243 | elif aug=='once': 244 | transform = transforms.Compose([ 245 | transforms.RandomCrop(32, padding=4), 246 | transforms.RandomHorizontalFlip(), 247 | transforms.ToTensor(), 248 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 249 | ]) 250 | elif aug=='twice': 251 | transform = TransformTwice(transforms.Compose([ 252 | RandomTranslateWithReflect(4), 253 | transforms.RandomHorizontalFlip(), 254 | transforms.ToTensor(), 255 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 256 | ])) 257 | dataset_labeled = CIFAR10(root=root, split=split, transform=transform, target_list=labeled_list) 258 | dataset_unlabeled = CIFAR10(root=root, split=split, transform=transform, target_list=unlabeled_list) 259 | if new_labels is not None: 260 | dataset_unlabeled.targets = new_labels 261 | dataset_labeled.targets = np.concatenate((dataset_labeled.targets,dataset_unlabeled.targets)) 262 | dataset_labeled.data = np.concatenate((dataset_labeled.data,dataset_unlabeled.data),0) 263 | loader = data.DataLoader(dataset_labeled, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 264 | return loader 265 | 266 | def CIFAR10LoaderTwoStream(root, batch_size, split='train',num_workers=2, aug=None, shuffle=True, labeled_list=range(5), unlabeled_list=range(5, 10), unlabeled_batch_size=64): 267 | dataset_labeled = CIFAR10Data(root, split, aug, labeled_list) 268 | dataset_unlabeled = CIFAR10Data(root, split, aug, unlabeled_list) 269 | dataset = ConcatDataset((dataset_labeled, dataset_unlabeled)) 270 | labeled_idxs = range(len(dataset_labeled)) 271 | unlabeled_idxs = range(len(dataset_labeled), len(dataset_labeled)+len(dataset_unlabeled)) 272 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, unlabeled_batch_size) 273 | loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers) 274 | loader.labeled_length = len(dataset_labeled) 275 | loader.unlabeled_length = len(dataset_unlabeled) 276 | return loader 277 | 278 | 279 | def CIFAR100Data(root, split='train', aug=None, target_list=range(80)): 280 | if aug==None: 281 | transform = transforms.Compose([ 282 | transforms.ToTensor(), 283 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), 284 | ]) 285 | elif aug=='once': 286 | transform = transforms.Compose([ 287 | transforms.RandomCrop(32, padding=4), 288 | transforms.RandomHorizontalFlip(), 289 | transforms.ToTensor(), 290 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), 291 | ]) 292 | elif aug=='twice': 293 | transform = TransformTwice(transforms.Compose([ 294 | transforms.RandomCrop(32, padding=4), 295 | transforms.RandomHorizontalFlip(), 296 | transforms.ToTensor(), 297 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), 298 | ])) 299 | dataset = CIFAR100(root=root, split=split, transform=transform, target_list=target_list) 300 | return dataset 301 | 302 | def CIFAR100Loader(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True, target_list=range(80)): 303 | dataset = CIFAR100Data(root, split, aug,target_list) 304 | loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 305 | return loader 306 | 307 | def CIFAR100LoaderMix(root, batch_size, split='train',num_workers=2, aug=None, shuffle=True, labeled_list=range(80), unlabeled_list=range(90, 100)): 308 | dataset_labeled = CIFAR100Data(root, split, aug, labeled_list) 309 | dataset_unlabeled = CIFAR100Data(root, split, aug, unlabeled_list) 310 | dataset_labeled.targets = np.concatenate((dataset_labeled.targets,dataset_unlabeled.targets)) 311 | dataset_labeled.data = np.concatenate((dataset_labeled.data,dataset_unlabeled.data),0) 312 | loader = data.DataLoader(dataset_labeled, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 313 | return loader 314 | 315 | def CIFAR100LoaderTwoStream(root, batch_size, split='train',num_workers=2, aug=None, shuffle=True, labeled_list=range(80), unlabeled_list=range(90, 100), unlabeled_batch_size=32): 316 | dataset_labeled = CIFAR100Data(root, split, aug, labeled_list) 317 | dataset_unlabeled = CIFAR100Data(root, split, aug, unlabeled_list) 318 | dataset = ConcatDataset((dataset_labeled, dataset_unlabeled)) 319 | labeled_idxs = range(len(dataset_labeled)) 320 | unlabeled_idxs = range(len(dataset_labeled), len(dataset_labeled)+len(dataset_unlabeled)) 321 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, unlabeled_batch_size) 322 | loader = data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers) 323 | loader.labeled_length = len(dataset_labeled) 324 | loader.unlabeled_length = len(dataset_unlabeled) 325 | return loader -------------------------------------------------------------------------------- /auto_novel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import SGD, lr_scheduler 5 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 6 | from sklearn.metrics import adjusted_rand_score as ari_score 7 | from sklearn.cluster import KMeans 8 | from utils.util import BCE, PairEnum, cluster_acc, Identity, AverageMeter, seed_torch 9 | from utils import ramps 10 | from models.resnet import ResNet, BasicBlock 11 | from data.cifarloader import CIFAR10Loader, CIFAR10LoaderMix, CIFAR100Loader, CIFAR100LoaderMix 12 | from data.svhnloader import SVHNLoader, SVHNLoaderMix 13 | from tqdm import tqdm 14 | import numpy as np 15 | import os 16 | 17 | def train(model, train_loader, labeled_eval_loader, unlabeled_eval_loader, args): 18 | optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 19 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 20 | criterion1 = nn.CrossEntropyLoss() 21 | criterion2 = BCE() 22 | for epoch in range(args.epochs): 23 | loss_record = AverageMeter() 24 | model.train() 25 | exp_lr_scheduler.step() 26 | w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 27 | for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)): 28 | x, x_bar, label = x.to(device), x_bar.to(device), label.to(device) 29 | output1, output2, feat = model(x) 30 | output1_bar, output2_bar, _ = model(x_bar) 31 | prob1, prob1_bar, prob2, prob2_bar=F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1) 32 | 33 | mask_lb = label0] = -1 46 | 47 | prob1_ulb, _= PairEnum(prob2[~mask_lb]) 48 | _, prob2_ulb = PairEnum(prob2_bar[~mask_lb]) 49 | 50 | loss_ce = criterion1(output1[mask_lb], label[mask_lb]) 51 | loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb) 52 | consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar) 53 | loss = loss_ce + loss_bce + w * consistency_loss 54 | 55 | loss_record.update(loss.item(), x.size(0)) 56 | optimizer.zero_grad() 57 | loss.backward() 58 | optimizer.step() 59 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg)) 60 | print('test on labeled classes') 61 | args.head = 'head1' 62 | test(model, labeled_eval_loader, args) 63 | print('test on unlabeled classes') 64 | args.head='head2' 65 | test(model, unlabeled_eval_loader, args) 66 | 67 | 68 | def train_IL(model, train_loader, labeled_eval_loader, unlabeled_eval_loader, args): 69 | optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 70 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 71 | criterion1 = nn.CrossEntropyLoss() 72 | criterion2 = BCE() 73 | for epoch in range(args.epochs): 74 | loss_record = AverageMeter() 75 | model.train() 76 | exp_lr_scheduler.step() 77 | w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 78 | for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)): 79 | x, x_bar, label = x.to(device), x_bar.to(device), label.to(device) 80 | output1, output2, feat = model(x) 81 | output1_bar, output2_bar, _ = model(x_bar) 82 | prob1, prob1_bar, prob2, prob2_bar = F.softmax(output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1) 83 | 84 | mask_lb = label < args.num_labeled_classes 85 | 86 | rank_feat = (feat[~mask_lb]).detach() 87 | 88 | rank_idx = torch.argsort(rank_feat, dim=1, descending=True) 89 | rank_idx1, rank_idx2 = PairEnum(rank_idx) 90 | rank_idx1, rank_idx2 = rank_idx1[:, :args.topk], rank_idx2[:, :args.topk] 91 | 92 | rank_idx1, _ = torch.sort(rank_idx1, dim=1) 93 | rank_idx2, _ = torch.sort(rank_idx2, dim=1) 94 | 95 | rank_diff = rank_idx1 - rank_idx2 96 | rank_diff = torch.sum(torch.abs(rank_diff), dim=1) 97 | target_ulb = torch.ones_like(rank_diff).float().to(device) 98 | target_ulb[rank_diff > 0] = -1 99 | 100 | prob1_ulb, _ = PairEnum(prob2[~mask_lb]) 101 | _, prob2_ulb = PairEnum(prob2_bar[~mask_lb]) 102 | 103 | loss_ce = criterion1(output1[mask_lb], label[mask_lb]) 104 | 105 | label[~mask_lb] = (output2[~mask_lb]).detach().max(1)[1] + args.num_labeled_classes 106 | 107 | loss_ce_add = w * criterion1(output1[~mask_lb], label[~mask_lb]) / args.rampup_coefficient * args.increment_coefficient 108 | loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb) 109 | consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar) 110 | 111 | loss = loss_ce + loss_bce + loss_ce_add + w * consistency_loss 112 | 113 | loss_record.update(loss.item(), x.size(0)) 114 | optimizer.zero_grad() 115 | loss.backward() 116 | optimizer.step() 117 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg)) 118 | print('test on labeled classes') 119 | args.head = 'head1' 120 | test(model, labeled_eval_loader, args) 121 | print('test on unlabeled classes') 122 | args.head='head2' 123 | test(model, unlabeled_eval_loader, args) 124 | 125 | def test(model, test_loader, args): 126 | model.eval() 127 | preds=np.array([]) 128 | targets=np.array([]) 129 | for batch_idx, (x, label, _) in enumerate(tqdm(test_loader)): 130 | x, label = x.to(device), label.to(device) 131 | output1, output2, _ = model(x) 132 | if args.head=='head1': 133 | output = output1 134 | else: 135 | output = output2 136 | _, pred = output.max(1) 137 | targets=np.append(targets, label.cpu().numpy()) 138 | preds=np.append(preds, pred.cpu().numpy()) 139 | acc, nmi, ari = cluster_acc(targets.astype(int), preds.astype(int)), nmi_score(targets, preds), ari_score(targets, preds) 140 | print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari)) 141 | 142 | if __name__ == "__main__": 143 | import argparse 144 | parser = argparse.ArgumentParser( 145 | description='cluster', 146 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 147 | parser.add_argument('--lr', type=float, default=0.1) 148 | parser.add_argument('--gamma', type=float, default=0.1) 149 | parser.add_argument('--momentum', type=float, default=0.9) 150 | parser.add_argument('--weight_decay', type=float, default=1e-4) 151 | parser.add_argument('--epochs', default=200, type=int) 152 | parser.add_argument('--rampup_length', default=150, type=int) 153 | parser.add_argument('--rampup_coefficient', type=float, default=50) 154 | parser.add_argument('--increment_coefficient', type=float, default=0.05) 155 | parser.add_argument('--step_size', default=170, type=int) 156 | parser.add_argument('--batch_size', default=128, type=int) 157 | parser.add_argument('--num_unlabeled_classes', default=5, type=int) 158 | parser.add_argument('--num_labeled_classes', default=5, type=int) 159 | parser.add_argument('--dataset_root', type=str, default='./data/datasets/CIFAR/') 160 | parser.add_argument('--exp_root', type=str, default='./data/experiments/') 161 | parser.add_argument('--warmup_model_dir', type=str, default='./data/experiments/pretrain/auto_novel/resnet_rotnet_cifar10.pth') 162 | parser.add_argument('--topk', default=5, type=int) 163 | parser.add_argument('--IL', action='store_true', default=False, help='w/ incremental learning') 164 | parser.add_argument('--model_name', type=str, default='resnet') 165 | parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, svhn') 166 | parser.add_argument('--seed', default=1, type=int) 167 | parser.add_argument('--mode', type=str, default='train') 168 | args = parser.parse_args() 169 | args.cuda = torch.cuda.is_available() 170 | device = torch.device("cuda" if args.cuda else "cpu") 171 | seed_torch(args.seed) 172 | runner_name = os.path.basename(__file__).split(".")[0] 173 | model_dir= os.path.join(args.exp_root, runner_name) 174 | if not os.path.exists(model_dir): 175 | os.makedirs(model_dir) 176 | args.model_dir = model_dir+'/'+'{}.pth'.format(args.model_name) 177 | 178 | model = ResNet(BasicBlock, [2,2,2,2], args.num_labeled_classes, args.num_unlabeled_classes).to(device) 179 | 180 | num_classes = args.num_labeled_classes + args.num_unlabeled_classes 181 | 182 | if args.mode=='train': 183 | state_dict = torch.load(args.warmup_model_dir) 184 | model.load_state_dict(state_dict, strict=False) 185 | for name, param in model.named_parameters(): 186 | if 'head' not in name and 'layer4' not in name: 187 | param.requires_grad = False 188 | 189 | if args.dataset_name == 'cifar10': 190 | mix_train_loader = CIFAR10LoaderMix(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='twice', shuffle=True, labeled_list=range(args.num_labeled_classes), unlabeled_list=range(args.num_labeled_classes, num_classes)) 191 | labeled_train_loader = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='once', shuffle=True, target_list = range(args.num_labeled_classes)) 192 | unlabeled_eval_loader = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug=None, shuffle=False, target_list = range(args.num_labeled_classes, num_classes)) 193 | unlabeled_eval_loader_test = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes, num_classes)) 194 | labeled_eval_loader = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes)) 195 | all_eval_loader = CIFAR10Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(num_classes)) 196 | elif args.dataset_name == 'cifar100': 197 | mix_train_loader = CIFAR100LoaderMix(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='twice', shuffle=True, labeled_list=range(args.num_labeled_classes), unlabeled_list=range(args.num_labeled_classes, num_classes)) 198 | labeled_train_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='once', shuffle=True, target_list = range(args.num_labeled_classes)) 199 | unlabeled_eval_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug=None, shuffle=False, target_list = range(args.num_labeled_classes, num_classes)) 200 | unlabeled_eval_loader_test = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes, num_classes)) 201 | labeled_eval_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes)) 202 | all_eval_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(num_classes)) 203 | elif args.dataset_name == 'svhn': 204 | mix_train_loader = SVHNLoaderMix(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='twice', shuffle=True, labeled_list=range(args.num_labeled_classes), unlabeled_list=range(args.num_labeled_classes, num_classes)) 205 | labeled_train_loader = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug='once', shuffle=True, target_list = range(args.num_labeled_classes)) 206 | unlabeled_eval_loader = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='train', aug=None, shuffle=False, target_list = range(args.num_labeled_classes, num_classes)) 207 | unlabeled_eval_loader_test = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes, num_classes)) 208 | labeled_eval_loader = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(args.num_labeled_classes)) 209 | all_eval_loader = SVHNLoader(root=args.dataset_root, batch_size=args.batch_size, split='test', aug=None, shuffle=False, target_list = range(num_classes)) 210 | 211 | if args.mode == 'train': 212 | if args.IL: 213 | save_weight = model.head1.weight.data.clone() 214 | save_bias = model.head1.bias.data.clone() 215 | model.head1 = nn.Linear(512, num_classes).to(device) 216 | model.head1.weight.data[:args.num_labeled_classes] = save_weight 217 | model.head1.bias.data[:] = torch.min(save_bias) - 1. 218 | model.head1.bias.data[:args.num_labeled_classes] = save_bias 219 | train_IL(model, mix_train_loader, labeled_eval_loader, unlabeled_eval_loader, args) 220 | else: 221 | train(model, mix_train_loader, labeled_eval_loader, unlabeled_eval_loader, args) 222 | torch.save(model.state_dict(), args.model_dir) 223 | print("model saved to {}.".format(args.model_dir)) 224 | else: 225 | print("model loaded from {}.".format(args.model_dir)) 226 | if args.IL: 227 | model.head1 = nn.Linear(512, num_classes).to(device) 228 | model.load_state_dict(torch.load(args.model_dir)) 229 | 230 | print('Evaluating on Head1') 231 | args.head = 'head1' 232 | print('test on labeled classes (test split)') 233 | test(model, labeled_eval_loader, args) 234 | if args.IL: 235 | print('test on unlabeled classes (test split)') 236 | test(model, unlabeled_eval_loader_test, args) 237 | print('test on all classes (test split)') 238 | test(model, all_eval_loader, args) 239 | print('Evaluating on Head2') 240 | args.head = 'head2' 241 | print('test on unlabeled classes (train split)') 242 | test(model, unlabeled_eval_loader, args) 243 | print('test on unlabeled classes (test split)') 244 | test(model, unlabeled_eval_loader_test, args) --------------------------------------------------------------------------------