├── .gitignore ├── LICENSE ├── README.md ├── check_dataset.py ├── check_model.py ├── cub200.py ├── dog.py ├── logs └── tinyimagenet-200-resnet32 │ └── 0 │ └── model_best.pth.tar ├── models ├── modules.py ├── resnet_cifar.py ├── resnet_ilsvrc.py └── vgg_cifar.py ├── split ├── cifar100-test ├── cifar100-train ├── cifar100-val ├── cub200-test ├── cub200-train ├── cub200-val ├── dog-test ├── dog-train ├── dog-val ├── stl10-test ├── stl10-train └── stl10-val ├── train └── meta_optimizers.py ├── train_l2t_ww.py └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 alinlab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning What and Where to Transfer (ICML 2019)
2 | PyTorch implementation of "Learning What and Where to Transfer" (ICML 2019)
3 | https://arxiv.org/abs/1905.05901
4 | 
5 | ## Requirements
6 | 
7 | - `python>=3.6`
8 | - `pytorch>=1.0`
9 | - `torchvision`
10 | - `cuda>=9.0`
11 | 
12 | **Note.** The results reported in our paper were obtained with an older environment (`pytorch=1.0`, `cuda=9.0`). We recently re-ran the experiment commands below with a more recent setup (`pytorch=1.6.0`, `torchvision=0.7.0`, `cuda=10.1`) and obtained results similar to those reported in the paper.
13 | 
14 | ## Prepare Datasets
15 | 
16 | You can download the CUB-200 and Stanford Dogs datasets from:
17 | - CUB-200: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
18 | - Stanford Dogs: http://vision.stanford.edu/aditya86/ImageNetDogs/
19 | 
20 | Then run the pre-processing scripts below to reorganize each dataset into the train/test folder layout expected by the DataLoader.
21 | 
22 | ```bash
23 | python cub200.py /data/CUB_200_2011
24 | python dog.py /data/dog
25 | ```
26 | 
27 | ## Train L2T-ww
28 | 
29 | You can train L2T-ww models with the same settings as in our paper.
30 | 
31 | ```bash
32 | python train_l2t_ww.py --dataset cub200 --datasplit cub200 --dataroot /data/CUB_200_2011
33 | python train_l2t_ww.py --dataset dog --datasplit dog --dataroot /data/dog
34 | python train_l2t_ww.py --dataset cifar100 --datasplit cifar100 --dataroot /data/ --experiment logs/cifar100_0/ --source-path logs --source-model resnet32 --source-domain tinyimagenet-200 --target-model vgg9_bn --pairs 4-0,4-1,4-2,4-3,4-4,9-0,9-1,9-2,9-3,9-4,14-0,14-1,14-2,14-3,14-4 --batchSize 128
35 | python train_l2t_ww.py --dataset stl10 --datasplit stl10 --dataroot /data/ --experiment logs/stl10_0/ --source-path logs --source-model resnet32 --source-domain tinyimagenet-200 --target-model vgg9_bn --pairs 4-0,4-1,4-2,4-3,4-4,9-0,9-1,9-2,9-3,9-4,14-0,14-1,14-2,14-3,14-4 --batchSize 128
36 | ```
37 | 
38 | Each `s-t` entry in `--pairs` matches the `s`-th feature map of the source network to the `t`-th feature map of the target network.
39 | 
--------------------------------------------------------------------------------
/check_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as dset
5 | import torch.utils.data as data
6 | 
7 | 
8 | class FolderSubset(data.Dataset):
9 |     def __init__(self, dataset, classes, indices):
10 |         self.dataset = dataset
11 |         self.classes = classes
12 |         self.indices = indices
13 | 
14 |         self.update_classes()
15 | 
16 |     def update_classes(self):
17 |         for i in self.indices:
18 |             img_path, cls = self.dataset.samples[i]
19 |             cls = self.classes.index(cls)
20 |             self.dataset.samples[i] = (img_path, cls)
21 | 
22 |     def __getitem__(self, idx):
23 |         return self.dataset[self.indices[idx]]
24 | 
25 |     def __len__(self):
26 |         return len(self.indices)
27 | 
28 | class STL10Subset(data.Dataset):
29 |     def __init__(self, dataset, classes, indices):
30 |         self.dataset = dataset
31 |         self.classes = classes
32 |         self.indices = indices
33 | 
34 |     def __getitem__(self, idx):
35 |         return self.dataset[self.indices[idx]]
36 | 
37 |     def __len__(self):
38 |         return
len(self.indices) 39 | 40 | class CIFARSubset(data.Dataset): 41 | def __init__(self, dataset, classes, indices): 42 | self.dataset = dataset 43 | self.classes = classes 44 | self.indices = indices 45 | 46 | # self.update_classes() 47 | 48 | # def update_classes(self): 49 | # for i in self.indices: 50 | # if self.dataset.train: 51 | # self.dataset.train_labels[i] = self.classes.index(self.dataset.train_labels[i]) 52 | # else: 53 | # self.dataset.test_labels[i] = self.classes.index(self.dataset.test_labels[i]) 54 | 55 | def __getitem__(self, idx): 56 | return self.dataset[self.indices[idx]] 57 | 58 | def __len__(self): 59 | return len(self.indices) 60 | 61 | 62 | def check_split(opt): 63 | splits = [] 64 | for split in ['train', 'val', 'test']: 65 | splits.append(torch.load('split/' + opt.datasplit + '-' + split)) 66 | 67 | return splits 68 | 69 | 70 | def check_dataset(opt): 71 | normalize_transform = transforms.Compose([transforms.ToTensor(), 72 | transforms.Normalize((0.485, 0.456, 0.406), 73 | (0.229, 0.224, 0.225))]) 74 | train_large_transform = transforms.Compose([transforms.RandomResizedCrop(224), 75 | transforms.RandomHorizontalFlip()]) 76 | val_large_transform = transforms.Compose([transforms.Resize(256), 77 | transforms.CenterCrop(224)]) 78 | train_small_transform = transforms.Compose([transforms.Pad(4), 79 | transforms.RandomCrop(32), 80 | transforms.RandomHorizontalFlip()]) 81 | 82 | splits = check_split(opt) 83 | 84 | if opt.dataset in ['cub200', 'indoor', 'stanford40', 'dog']: 85 | train, val = 'train', 'test' 86 | train_transform = transforms.Compose([train_large_transform, normalize_transform]) 87 | val_transform = transforms.Compose([val_large_transform, normalize_transform]) 88 | sets = [dset.ImageFolder(root=os.path.join(opt.dataroot, train), transform=train_transform), 89 | dset.ImageFolder(root=os.path.join(opt.dataroot, train), transform=val_transform), 90 | dset.ImageFolder(root=os.path.join(opt.dataroot, val), transform=val_transform)] 91 | sets = [FolderSubset(dataset, *split) for dataset, split in zip(sets, splits)] 92 | 93 | opt.num_classes = len(splits[0][0]) 94 | 95 | elif opt.dataset == 'stl10': 96 | train_transform = transforms.Compose([transforms.Resize(32), 97 | train_small_transform, normalize_transform]) 98 | val_transform = transforms.Compose([transforms.Resize(32), normalize_transform]) 99 | sets = [dset.STL10(opt.dataroot, split='train', transform=train_transform, download=True), 100 | dset.STL10(opt.dataroot, split='train', transform=val_transform, download=True), 101 | dset.STL10(opt.dataroot, split='test', transform=val_transform, download=True)] 102 | sets = [STL10Subset(dataset, *split) for dataset, split in zip(sets, splits)] 103 | 104 | opt.num_classes = len(splits[0][0]) 105 | 106 | elif opt.dataset in ['cifar10', 'cifar100']: 107 | train_transform = transforms.Compose([train_small_transform, normalize_transform]) 108 | val_transform = normalize_transform 109 | CIFAR = dset.CIFAR10 if opt.dataset == 'cifar10' else dset.CIFAR100 110 | 111 | sets = [CIFAR(opt.dataroot, download=True, train=True, transform=train_transform), 112 | CIFAR(opt.dataroot, download=True, train=True, transform=val_transform), 113 | CIFAR(opt.dataroot, download=True, train=False, transform=val_transform)] 114 | sets = [CIFARSubset(dataset, *split) for dataset, split in zip(sets, splits)] 115 | 116 | opt.num_classes = len(splits[0][0]) 117 | 118 | else: 119 | raise Exception('Unknown dataset') 120 | 121 | loaders = [torch.utils.data.DataLoader(dataset, 122 | 
                                           batch_size=opt.batchSize,
123 |                                            shuffle=True,
124 |                                            num_workers=opt.workers) for dataset in sets]
125 |     return loaders
126 | 
--------------------------------------------------------------------------------
/check_model.py:
--------------------------------------------------------------------------------
1 | from models import resnet_ilsvrc
2 | from models import resnet_cifar as cresnet, vgg_cifar as cvgg
3 | 
4 | 
5 | def check_model(opt):
6 |     if opt.model.startswith('resnet'):
7 |         if opt.dataset in ['cub200', 'indoor', 'stanford40', 'flowers102', 'dog', 'tinyimagenet']:
8 |             ResNet = resnet_ilsvrc.__dict__[opt.model]
9 |             model = ResNet(num_classes=opt.num_classes)
10 |         else:
11 |             ResNet = cresnet.__dict__[opt.model]
12 |             model = ResNet(num_classes=opt.num_classes)
13 | 
14 |         return model
15 | 
16 |     elif opt.model.startswith('vgg'):
17 |         VGG = cvgg.__dict__[opt.model]
18 |         model = VGG(num_classes=opt.num_classes)
19 | 
20 |         return model
21 | 
22 |     else:
23 |         raise Exception('Unknown model')
24 | 
--------------------------------------------------------------------------------
/cub200.py:
--------------------------------------------------------------------------------
1 | import os, sys, shutil
2 | 
3 | """
4 | Usage:
5 |     python cub200.py /data/CUB_200_2011
6 | """
7 | 
8 | def read(filename):
9 |     with open(filename) as f:
10 |         return f.readlines()
11 | 
12 | def main():
13 |     datadir = sys.argv[1]
14 |     images = read(os.path.join(datadir, 'images.txt'))
15 |     splits = read(os.path.join(datadir, 'train_test_split.txt'))
16 |     assert len(images) == len(splits)
17 |     paths = {'train': [], 'test': []}
18 |     for filename, split in zip(images, splits):
19 |         idx1, filename = filename.split()
20 |         idx2, split = split.split()
21 | 
22 |         assert idx1 == idx2
23 |         if split == '1':
24 |             paths['train'].append(filename)
25 |         else:
26 |             paths['test'].append(filename)
27 |     print('# of training images:', len(paths['train']))
28 |     print('# of test images:', len(paths['test']))
29 | 
30 |     counter = 0
31 |     for split in ['train', 'test']:
32 |         for d in sorted(os.listdir(os.path.join(datadir, 'images'))):
33 |             os.makedirs(os.path.join(datadir, split, d))
34 | 
35 |         for p in paths[split]:
36 |             shutil.copy(os.path.join(datadir, 'images', p),
37 |                         os.path.join(datadir, split, p))
38 |             counter += 1
39 |             if counter % 100 == 0:
40 |                 print('.', end='')
41 |     print('Done')
42 | 
43 | if __name__ == '__main__':
44 |     main()
45 | 
--------------------------------------------------------------------------------
/dog.py:
--------------------------------------------------------------------------------
1 | import os, sys, shutil
2 | from scipy import io
3 | 
4 | """
5 | Usage:
6 |     python dog.py /data/dog
7 | """
8 | 
9 | def read(filename):
10 |     with open(filename) as f:
11 |         return f.readlines()
12 | 
13 | def main():
14 |     datadir = sys.argv[1]
15 |     count = 0
16 |     for split in ['train', 'test']:
17 |         for c in os.listdir(os.path.join(datadir, 'Images')):
18 |             os.makedirs(os.path.join(datadir, split, c))
19 |         files = io.loadmat(os.path.join(datadir, split + '_list.mat'))['file_list']
20 |         for f in files:
21 |             shutil.copy(os.path.join(datadir, 'Images', f[0][0]),
22 |                         os.path.join(datadir, split, f[0][0]))
23 |             count += 1
24 |     print(count, 'Done')
25 | 
26 | if __name__ == '__main__':
27 |     main()
28 | 
--------------------------------------------------------------------------------
/logs/tinyimagenet-200-resnet32/0/model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/logs/tinyimagenet-200-resnet32/0/model_best.pth.tar
--------------------------------------------------------------------------------
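The checkpoint above holds the TinyImageNet-pretrained ResNet-32 used as the source model by the CIFAR-100 and STL-10 commands in the README. A minimal sketch of how `train_l2t_ww.py` restores it — the `num_classes` and `state_dict` keys and the `strict=False` flag mirror that script; `map_location='cpu'` is added here only as a precaution for machines without a GPU:

```python
import torch
from models.resnet_cifar import resnet32

# The checkpoint stores the source class count alongside the weights.
ckpt = torch.load('logs/tinyimagenet-200-resnet32/0/model_best.pth.tar',
                  map_location='cpu')
source_model = resnet32(num_classes=ckpt['num_classes'])
# strict=False, as in train_l2t_ww.py, tolerates keys this model does not define.
source_model.load_state_dict(ckpt['state_dict'], strict=False)
source_model.eval()
```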
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | class View(nn.Module):
4 |     def __init__(self, *size):
5 |         super(View, self).__init__()
6 |         self.size = size
7 | 
8 |     def forward(self, x):
9 |         return x.view(x.size()[0], *self.size)
10 | 
11 | class BasicResBlock(nn.Module):
12 |     expansion = 1
13 | 
14 |     def __init__(self, inplanes, planes, stride=1, downsample=None, batchnorm_affine=True):
15 |         super(BasicResBlock, self).__init__()
16 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
17 |         self.bn1 = nn.BatchNorm2d(planes, affine=batchnorm_affine)
18 |         self.relu = nn.ReLU(inplace=True)
19 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
20 |         self.bn2 = nn.BatchNorm2d(planes, affine=batchnorm_affine)
21 |         self.downsample = downsample
22 |         self.stride = stride
23 | 
24 |     def forward(self, x):
25 |         residual = x
26 | 
27 |         out = self.conv1(x)
28 |         #out = self.bn1(out)
29 |         out = self.relu(out)
30 | 
31 |         out = self.conv2(out)
32 |         #out = self.bn2(out)
33 | 
34 |         if self.downsample is not None:
35 |             residual = self.downsample(x)
36 | 
37 |         out += residual
38 |         out = self.relu(out)
39 | 
40 |         return out
41 | 
42 | def conv3x3(in_planes, out_planes, stride=1):
43 |     """3x3 convolution with padding"""
44 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
45 |                      padding=1, bias=False)
46 | 
47 | 
48 | class Bottleneck(nn.Module):
49 |     expansion = 4
50 | 
51 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
52 |         super(Bottleneck, self).__init__()
53 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
54 |         self.bn1 = nn.BatchNorm2d(planes)
55 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
56 |                                padding=1, bias=False)
57 |         self.bn2 = nn.BatchNorm2d(planes)
58 |         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
59 |         self.bn3 = nn.BatchNorm2d(planes * 4)
60 |         self.relu = nn.ReLU(inplace=True)
61 |         self.downsample = downsample
62 |         self.stride = stride
63 | 
64 |     def forward(self, x):
65 |         residual = x
66 | 
67 |         out = self.conv1(x)
68 |         out = self.bn1(out)
69 |         out = self.relu(out)
70 | 
71 |         out = self.conv2(out)
72 |         out = self.bn2(out)
73 |         out = self.relu(out)
74 | 
75 |         out = self.conv3(out)
76 |         out = self.bn3(out)
77 | 
78 |         if self.downsample is not None:
79 |             residual = self.downsample(x)
80 | 
81 |         out += residual
82 |         out = self.relu(out)
83 | 
84 |         return out
85 | 
--------------------------------------------------------------------------------
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torchvision.models.resnet import BasicBlock
6 | from .modules import View, Bottleneck
7 | 
8 | __all__ = ['CResNet', 'resnet32']
9 | 
10 | 
11 | 
12 | class CResNet(nn.Module):
13 | 
14 |     def __init__(self, n, block, num_classes=10, lwf=False, num_source_cls=200, growing=False):
15 |         self.inplanes = 16
16 |         self.growing = growing
17 |         super(CResNet, self).__init__()
18 |         self.conv1 = nn.Conv2d(3, 16, 3,
padding=1) 19 | self.bn1 = nn.BatchNorm2d(16) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.block1 = BasicBlock(16, 16) 22 | self.blocks2 = [] 23 | for i in range(n): 24 | self.blocks2.append(block(16)) 25 | self.blocks2 = nn.Sequential(*self.blocks2) 26 | downsample = nn.Sequential(nn.Conv2d(16, 32, kernel_size=1, stride=2, bias=False), 27 | nn.BatchNorm2d(32)) 28 | self.block3 = BasicBlock(16, 32, 2, downsample) 29 | self.blocks4 = [] 30 | for i in range(n): 31 | self.blocks4.append(block(32)) 32 | self.blocks4 = nn.Sequential(*self.blocks4) 33 | downsample = nn.Sequential(nn.Conv2d(32, 64, kernel_size=1, stride=2, bias=False), 34 | nn.BatchNorm2d(64)) 35 | self.block5 = BasicBlock(32, 64, 2, downsample) 36 | self.blocks6 = [] 37 | for i in range(n): 38 | self.blocks6.append(block(64)) 39 | self.blocks6 = nn.Sequential(*self.blocks6) 40 | if self.growing: 41 | self.block5_add = BasicBlock(32, 64, 2, downsample) 42 | self.blocks6_add = [] 43 | for i in range(n): 44 | self.blocks6_add.append(block(64)) 45 | self.blocks6_add = nn.Sequential(*self.blocks6_add) 46 | self.gamma = nn.Parameter(torch.zeros(1).fill_(10)) 47 | self.avgpool = nn.AvgPool2d(8) 48 | self.view = View(-1) 49 | 50 | self.fc = nn.Linear(64,num_classes) 51 | 52 | self.lwf = lwf 53 | if self.lwf: 54 | self.lwf_lyr = nn.Linear(64, num_source_cls) 55 | 56 | self.alphas = nn.ParameterList([nn.Parameter(torch.rand(3, 1, 1, 1)*0.1), 57 | nn.Parameter(torch.rand(3, 1, 1, 1)*0.1), 58 | nn.Parameter(torch.rand(3, 1, 1, 1)*0.1)]) 59 | 60 | if self.growing: 61 | self.fc = nn.Linear(64*2, num_classes) 62 | 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | m.weight.data.normal_(0, math.sqrt(2. / n)) 67 | if m.bias is not None: 68 | m.bias.data.zero_() 69 | elif isinstance(m, nn.BatchNorm2d): 70 | m.weight.data.fill_(1) 71 | m.bias.data.zero_() 72 | 73 | 74 | def forward(self, x, i=None): 75 | c1 = self.conv1(x) 76 | b1 = self.bn1(c1) 77 | r1 = self.relu(b1) 78 | f0 = self.block1(r1) 79 | 80 | # for i in range(len(self.blocks2)): 81 | # f = self.blocks2[i](f) 82 | f1 = self.blocks2(f0) 83 | f2 = self.block3(f1) 84 | # for i in range(len(self.blocks4)): 85 | # if i == 0: 86 | # f3 = self.blocks4[i](f2) 87 | # else: 88 | # f3 = self.blocks4[i](f3) 89 | f3 = self.blocks4(f2) 90 | f4 = self.block5(f3) 91 | # for i in range(len(self.blocks6)): 92 | # if i == 0: 93 | # f5 = self.blocks6[i](f4) 94 | # else: 95 | # f5 = self.blocks6[i](f5) 96 | f5 = self.blocks6(f4) 97 | 98 | f6 = self.avgpool(f5) 99 | f7 = self.view(f6) 100 | if self.lwf: 101 | old_out = self.lwf_lyr(f7) 102 | 103 | f7 = self.fc(f7) 104 | #f8 = self.fc(f7) 105 | 106 | if self.lwf: 107 | return x, [r1, f1, f3, f5], old_out 108 | else: 109 | return f7, [r1, f1, f3, f5] 110 | 111 | def forward_with_features(self, x): 112 | feat = [] 113 | x = self.conv1(x) 114 | x = self.bn1(x) 115 | x = self.relu(x) 116 | 117 | x = self.block1(x) 118 | feat = [x] 119 | for i in range(len(self.blocks2)): 120 | x = self.blocks2[i](x) 121 | feat.append(x) 122 | 123 | x = self.block3(x) 124 | feat.append(x) 125 | for i in range(len(self.blocks4)): 126 | x = self.blocks4[i](x) 127 | feat.append(x) 128 | 129 | x = self.block5(x) 130 | feat.append(x) 131 | for i in range(len(self.blocks6)): 132 | x = self.blocks6[i](x) 133 | feat.append(x) 134 | 135 | x = self.avgpool(x) 136 | x = self.view(x) 137 | x = self.fc(x) 138 | return x, feat 139 | 140 | def resnet32(num_classes=10, growing=False): 141 | model = CResNet(4, 
block=lambda k: BasicBlock(k, k),num_classes=num_classes, growing=growing) 142 | 143 | return model 144 | 145 | -------------------------------------------------------------------------------- /models/resnet_ilsvrc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, num_classes=1000, meta=None): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = 
self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AvgPool2d(7, stride=1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | self.lwf = False 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | f1 = self.conv1(x) 141 | b1 = self.bn1(f1) 142 | r1 = self.relu(b1) 143 | p1 = self.maxpool(r1) 144 | 145 | f2 = self.layer1(p1) 146 | f3 = self.layer2(f2) 147 | f4 = self.layer3(f3) 148 | f5 = self.layer4(f4) 149 | 150 | f6 = self.avgpool(f5) 151 | f6 = f6.view(f6.size(0), -1) 152 | f7 = self.fc(f6) 153 | 154 | return f7, [r1, f2, f3, f4, f5] 155 | 156 | def forward_with_features(self, x): 157 | return self.forward(x) 158 | 159 | 160 | def resnet18(pretrained=False, meta=False, **kwargs): 161 | """Constructs a ResNet-18 model. 162 | 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | if meta: 167 | model = ResNet_meta(BasicBlock, [2, 2, 2, 2], **kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 170 | else: 171 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 174 | return model 175 | 176 | 177 | def resnet34(pretrained=False, meta=False, **kwargs): 178 | """Constructs a ResNet-34 model. 179 | 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | if meta: 184 | model = ResNet_meta(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 187 | else: 188 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 189 | if pretrained: 190 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 191 | return model 192 | 193 | 194 | def resnet50(pretrained=False, **kwargs): 195 | """Constructs a ResNet-50 model. 196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | """ 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 201 | if pretrained: 202 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 203 | return model 204 | 205 | 206 | def resnet101(pretrained=False, **kwargs): 207 | """Constructs a ResNet-101 model. 
208 | 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 213 | if pretrained: 214 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 215 | return model 216 | 217 | 218 | def resnet152(pretrained=False, **kwargs): 219 | """Constructs a ResNet-152 model. 220 | 221 | Args: 222 | pretrained (bool): If True, returns a model pre-trained on ImageNet 223 | """ 224 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 225 | if pretrained: 226 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 227 | return model 228 | -------------------------------------------------------------------------------- /models/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | import torch.nn.functional as F 8 | import math 9 | 10 | 11 | __all__ = [ 12 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 13 | 'vgg19_bn', 'vgg19', 'vgg9_bn' 14 | ] 15 | 16 | 17 | ## For models pre-trained on ImageNet 18 | #model_urls = { 19 | # 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 20 | # 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 21 | # 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 22 | # 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 23 | # 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 24 | # 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 25 | # 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 26 | # 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 27 | #} 28 | 29 | 30 | class VGG(nn.Module): 31 | 32 | def __init__(self, features, num_classes=10, init_weights=True): 33 | super(VGG, self).__init__() 34 | self.features = features 35 | self.classifier = nn.Sequential( 36 | nn.Linear(512, 512), 37 | nn.ReLU(True), 38 | nn.Dropout(), 39 | nn.Linear(512, 512), 40 | nn.ReLU(True), 41 | nn.Dropout(), 42 | nn.Linear(512, num_classes), 43 | ) 44 | if init_weights: 45 | self._initialize_weights() 46 | 47 | def forward(self, x): 48 | feat = [] 49 | for layer in self.features: 50 | if isinstance(layer, nn.MaxPool2d): 51 | feat.append(x) 52 | x = layer(x) 53 | x = x.view(x.size(0), -1) 54 | x = self.classifier(x) 55 | return x, feat 56 | 57 | def _initialize_weights(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 61 | if m.bias is not None: 62 | nn.init.constant_(m.bias, 0) 63 | elif isinstance(m, nn.BatchNorm2d): 64 | nn.init.constant_(m.weight, 1) 65 | nn.init.constant_(m.bias, 0) 66 | elif isinstance(m, nn.Linear): 67 | nn.init.normal_(m.weight, 0, 0.01) 68 | nn.init.constant_(m.bias, 0) 69 | 70 | 71 | # The customized model from https://arxiv.org/abs/1803.00443 72 | class VGG_small(nn.Module): 73 | 74 | def __init__(self, features, num_classes=10, init_weights=True, lwf=False, num_source_cls=200, no_ft=False): 75 | super(VGG_small, self).__init__() 76 | self.features = features 77 | 78 | self.num_classes = num_classes 79 | if isinstance(num_classes, list): 80 | fcs = [] 81 | for i in range(len(num_classes)): 82 | fcs.append(nn.Linear(512, num_classes[i])) 83 | 
self.classifier = nn.ModuleList(fcs) 84 | else: 85 | self.classifier = nn.Linear(512, num_classes) 86 | 87 | #self.classifier = nn.Linear(512, num_classes) 88 | self.lwf = lwf 89 | if self.lwf: 90 | self.lwf_lyr = nn.Linear(512, num_source_cls) 91 | 92 | self.no_ft = no_ft 93 | if self.no_ft: 94 | self.outputs_branch = nn.ModuleList( 95 | [nn.Linear(64, num_classes), 96 | nn.Linear(128, num_classes), 97 | nn.Linear(256, num_classes)]) 98 | 99 | self.alphas = nn.ParameterList([nn.Parameter(torch.rand(3, 1, 1, 1)*0.1), 100 | nn.Parameter(torch.rand(3, 1, 1, 1)*0.1), 101 | nn.Parameter(torch.rand(3, 1, 1, 1)*0.1)]) 102 | 103 | self.new_classifier = nn.Linear(512, num_classes) 104 | self.new_bn = nn.ModuleList() 105 | for layer in self.features: 106 | if isinstance(layer, nn.BatchNorm2d): 107 | self.new_bn.append(nn.BatchNorm2d(layer.num_features)) 108 | 109 | self.w1 = nn.Linear(64, 16) 110 | self.w2 = nn.Linear(128, 32) 111 | self.w3 = nn.Linear(256, 64) 112 | self.w = nn.ModuleList([self.w1, self.w2, self.w3]) 113 | if init_weights: 114 | self._initialize_weights() 115 | 116 | def forward(self, x, idx=-1): 117 | feat = [] 118 | for layer in self.features: 119 | if isinstance(layer, nn.MaxPool2d): 120 | feat.append(x) 121 | x = layer(x) 122 | x = F.avg_pool2d(x, x.size(3)) 123 | x = x.view(x.size(0), -1) 124 | 125 | if self.lwf: 126 | old_out = self.lwf_lyr(x) 127 | 128 | if isinstance(self.num_classes, list): 129 | x = self.classifier[idx](x) 130 | else: 131 | x = self.classifier(x) 132 | 133 | 134 | if self.lwf: 135 | return x, feat, old_out 136 | else: 137 | return x, feat 138 | 139 | def forward_with_features(self, x): 140 | return self.forward(x) 141 | 142 | def forward_with_combine_features(self, x, fs, metanet): 143 | return self.combine_forward(x, fs, metanet, 0) 144 | 145 | def _initialize_weights(self): 146 | for m in self.modules(): 147 | if isinstance(m, nn.Conv2d): 148 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 149 | if m.bias is not None: 150 | nn.init.constant_(m.bias, 0) 151 | elif isinstance(m, nn.BatchNorm2d): 152 | nn.init.constant_(m.weight, 1) 153 | nn.init.constant_(m.bias, 0) 154 | elif isinstance(m, nn.Linear): 155 | nn.init.normal_(m.weight, 0, 0.01) 156 | nn.init.constant_(m.bias, 0) 157 | 158 | 159 | def make_layers(cfg, batch_norm=False): 160 | layers = [] 161 | in_channels = 3 162 | for v in cfg: 163 | if v == 'M': 164 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 165 | else: 166 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 167 | if batch_norm: 168 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 169 | else: 170 | layers += [conv2d, nn.ReLU(inplace=True)] 171 | in_channels = v 172 | return nn.Sequential(*layers) 173 | 174 | 175 | cfg = { 176 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 177 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 178 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 179 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 180 | } 181 | 182 | cfg_small = { 183 | 'A': [64, 'M', 128, 'M', 512, 'M'], 184 | 'B': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 185 | } 186 | 187 | 188 | def vgg4(**kwargs): 189 | """ 190 | VGG 4-layer model (configuration_small "A") 191 | """ 192 | model = VGG_small(make_layers(cfg_small['A']), **kwargs) 193 | return model 194 | 195 | 196 | def vgg4_bn(**kwargs): 
197 |     """
198 |     VGG 4-layer model (configuration_small "A") with batch normalization
199 |     """
200 |     model = VGG_small(make_layers(cfg_small['A'], batch_norm=True), **kwargs)
201 |     return model
202 | 
203 | 
204 | def vgg9(**kwargs):
205 |     """
206 |     VGG 9-layer model (configuration_small "B")
207 |     """
208 |     model = VGG_small(make_layers(cfg_small['B']), **kwargs)
209 |     return model
210 | 
211 | 
212 | def vgg9_bn(**kwargs):
213 |     """
214 |     VGG 9-layer model (configuration_small "B") with batch normalization
215 |     """
216 |     model = VGG_small(make_layers(cfg_small['B'], batch_norm=True), **kwargs)
217 |     return model
218 | 
219 | 
220 | 
221 | def vgg11(**kwargs):
222 |     """
223 |     VGG 11-layer model (configuration "A")
224 |     """
225 |     model = VGG(make_layers(cfg['A']), **kwargs)
226 |     return model
227 | 
228 | 
229 | def vgg11_bn(**kwargs):
230 |     """
231 |     VGG 11-layer model (configuration "A") with batch normalization
232 |     """
233 |     model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
234 |     return model
235 | 
236 | 
237 | def vgg13(**kwargs):
238 |     """
239 |     VGG 13-layer model (configuration "B")
240 |     """
241 |     model = VGG(make_layers(cfg['B']), **kwargs)
242 |     return model
243 | 
244 | 
245 | def vgg13_bn(**kwargs):
246 |     """
247 |     VGG 13-layer model (configuration "B") with batch normalization
248 |     """
249 |     model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
250 |     return model
251 | 
252 | 
253 | def vgg16(**kwargs):
254 |     """
255 |     VGG 16-layer model (configuration "D")
256 |     """
257 |     model = VGG(make_layers(cfg['D']), **kwargs)
258 |     return model
259 | 
260 | 
261 | def vgg16_bn(**kwargs):
262 |     """
263 |     VGG 16-layer model (configuration "D") with batch normalization
264 |     """
265 |     model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
266 |     return model
267 | 
268 | 
269 | def vgg19(**kwargs):
270 |     """
271 |     VGG 19-layer model (configuration "E")
272 |     """
273 |     model = VGG(make_layers(cfg['E']), **kwargs)
274 |     return model
275 | 
276 | 
277 | def vgg19_bn(**kwargs):
278 |     """
279 |     VGG 19-layer model (configuration "E") with batch normalization
280 |     """
281 |     model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
282 |     return model
283 | 
284 | 
285 | 
286 | 
287 | if __name__ == "__main__":
288 |     pass
289 |     # x = torch.Tensor(5,3,32,32)
290 |     # net = vgg4_bn()
291 |     # y, feat = net(x)
292 |     #
293 |     # print (y.size())
294 |     # print()
295 |     # for i in range(len(feat)):
296 |     #     print (feat[i].size())
297 | 
298 | 
299 | 
--------------------------------------------------------------------------------
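The `split/` files below are the torch-serialized train/val/test splits that `check_split` in `check_dataset.py` loads. Given how `FolderSubset(dataset, *split)` unpacks each one, a split file should deserialize to a `(classes, indices)` pair; a quick inspection sketch under that assumption:

```python
import torch

# Each split file is assumed to unpack into the class list and the example
# indices that the *Subset wrappers in check_dataset.py expect.
classes, indices = torch.load('split/cub200-train')
print('{} classes, {} examples'.format(len(classes), len(indices)))
```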
/split/cifar100-test:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/cifar100-test
--------------------------------------------------------------------------------
/split/cifar100-train:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/cifar100-train
--------------------------------------------------------------------------------
/split/cifar100-val:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/cifar100-val
--------------------------------------------------------------------------------
/split/cub200-test:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/cub200-test
--------------------------------------------------------------------------------
/split/cub200-train:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/cub200-train
--------------------------------------------------------------------------------
/split/cub200-val:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/cub200-val
--------------------------------------------------------------------------------
/split/dog-test:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/dog-test
--------------------------------------------------------------------------------
/split/dog-train:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/dog-train
--------------------------------------------------------------------------------
/split/dog-val:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/dog-val
--------------------------------------------------------------------------------
/split/stl10-test:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/stl10-test
--------------------------------------------------------------------------------
/split/stl10-train:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/stl10-train
--------------------------------------------------------------------------------
/split/stl10-val:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/L2T-ww/f4dde04e8d5d5725dc3bff2f59cb0d0c26d0bcbe/split/stl10-val
--------------------------------------------------------------------------------
/train/meta_optimizers.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | 
7 | def _copy(state):
8 |     if isinstance(state, torch.Tensor):
9 |         return state.cpu().clone()
10 | 
11 |     elif isinstance(state, dict):
12 |         new_state = {}
13 |         for key in state:
14 |             new_state[key] = _copy(state[key])
15 | 
16 |     elif isinstance(state, list):
17 |         new_state = []
18 |         for item in state:
19 |             new_state.append(_copy(item))
20 | 
21 |     else:
22 |         new_state = copy.deepcopy(state)
23 | 
24 |     return new_state
25 | 
26 | class MetaSGD(optim.SGD):
27 |     def __init__(self, params, modules, lr=0.1, momentum=0, weight_decay=0, rollback=False, cpu=False):
28 |         super(MetaSGD, self).__init__(params, lr, momentum=momentum, weight_decay=weight_decay)
29 |         self.prev_states = []
30 |         self.modules = modules + [self]
31 |         self.rollback = rollback
32 |         self.cpu = cpu
33 | 
34 |     def parameters(self):
35 |         for pg in
self.param_groups: 36 | for p in pg['params']: 37 | yield p 38 | 39 | def get_state(self): 40 | if self.cpu: 41 | return _copy([m.state_dict() for m in self.modules]) 42 | else: 43 | return copy.deepcopy([m.state_dict() for m in self.modules]) 44 | 45 | def set_state(self, state): 46 | for m, s in zip(self.modules, state): 47 | m.load_state_dict(s) 48 | 49 | def step(self, objective, *args, **kwargs): 50 | if objective is not None: 51 | self.prev_states.append((self.get_state(), objective, args, kwargs)) 52 | loss = objective(*args, **kwargs) 53 | loss.backward() 54 | super(MetaSGD, self).step() 55 | 56 | def meta_backward(self): 57 | alpha_groups = [] 58 | for pg in self.param_groups: 59 | alpha_groups.append([]) 60 | for p in pg['params']: 61 | if p.grad is None: 62 | p.grad = torch.zeros_like(p.data) 63 | grad = p.grad.data.clone() 64 | alpha_groups[-1].append((grad, torch.zeros_like(grad))) 65 | 66 | curr_state = self.get_state() 67 | for prev_state in reversed(self.prev_states): 68 | state, objective, args, kwargs = prev_state 69 | self.set_state(state) 70 | loss = objective(*args, **kwargs) 71 | grad = torch.autograd.grad(loss, list(self.parameters()), 72 | create_graph=True, allow_unused=True) 73 | grad = {p: g for p, g in zip(self.parameters(), grad)} 74 | X = 0.0 75 | for pg, ag in zip(self.param_groups, alpha_groups): 76 | lr = pg['lr'] 77 | wd = pg['weight_decay'] 78 | momentum = pg['momentum'] 79 | for p, a in zip(pg['params'], ag): 80 | g = grad[p] 81 | if g is not None: 82 | X = X+g.mul(a[0].mul(-lr)+a[1]).sum() 83 | self.zero_grad() 84 | X.backward() 85 | for pg, ag in zip(self.param_groups, alpha_groups): 86 | lr = pg['lr'] 87 | wd = pg['weight_decay'] 88 | momentum = pg['momentum'] 89 | for p, a in zip(pg['params'], ag): 90 | a_new = (a[0].mul(1-lr*wd).add_(wd, a[1]).add_(p.grad.data), 91 | a[1].mul(momentum).add_(-lr*momentum, a[0])) 92 | a[0].copy_(a_new[0]) 93 | a[1].copy_(a_new[1]) 94 | self.prev_states = [] 95 | if not self.rollback: 96 | self.set_state(curr_state) 97 | 98 | def __len__(self): 99 | return len(self.prev_states) 100 | 101 | def meta_backward_all(self, objective, outer_args): 102 | alpha_groups = [] 103 | for pg in self.param_groups: 104 | alpha_groups.append([]) 105 | for p in pg['params']: 106 | if p.grad is None: 107 | p.grad = torch.zeros_like(p.data) 108 | grad = p.grad.data 109 | alpha_groups[-1].append((torch.zeros_like(grad), torch.zeros_like(grad))) 110 | 111 | curr_state = self.get_state() 112 | for prev_state, o_args in zip(reversed(self.prev_states), outer_args): 113 | grad = torch.autograd.grad(objective(*o_args), list(self.parameters()), allow_unused=True) 114 | grad = {p: g for p, g in zip(self.parameters(), grad)} 115 | for pg, ag in zip(self.param_groups, alpha_groups): 116 | for i, p in enumerate(pg['params']): 117 | if grad[p] is not None: 118 | ag[i][0].add_(grad[p]) 119 | 120 | state, objective, args, kwargs = prev_state 121 | self.set_state(state) 122 | loss = objective(*args, **kwargs) 123 | grad = torch.autograd.grad(loss, list(self.parameters()), 124 | create_graph=True, allow_unused=True) 125 | grad = {p: g for p, g in zip(self.parameters(), grad)} 126 | X = 0.0 127 | for pg, ag in zip(self.param_groups, alpha_groups): 128 | lr = pg['lr'] 129 | wd = pg['weight_decay'] 130 | momentum = pg['momentum'] 131 | for p, a in zip(pg['params'], ag): 132 | g = grad[p] 133 | if g is not None: 134 | X = X+g.mul(a[0].mul(-lr)+a[1]).sum() 135 | self.zero_grad() 136 | X.backward() 137 | for pg, ag in zip(self.param_groups, alpha_groups): 138 | 
lr = pg['lr'] 139 | wd = pg['weight_decay'] 140 | momentum = pg['momentum'] 141 | for p, a in zip(pg['params'], ag): 142 | a_new = (a[0].mul(1-lr*wd).add_(wd, a[1]).add_(p.grad.data), 143 | a[1].mul(momentum).add_(-lr*momentum, a[0])) 144 | a[0].copy_(a_new[0]) 145 | a[1].copy_(a_new[1]) 146 | self.prev_states = [] 147 | self.set_state(curr_state) 148 | 149 | def test_metaSGD(): 150 | v1 = torch.nn.Parameter(torch.Tensor([1., 3.])) 151 | v2 = torch.nn.Parameter(torch.Tensor([[-1., -2.], [1., 0.]])) 152 | module = nn.Module() 153 | module.v1 = v1 154 | module.v2 = v2 155 | 156 | lmbd = torch.nn.Parameter(torch.zeros(2, 2)) 157 | 158 | sgd = MetaSGD([v1, v2], [module], lr=0.1, momentum=0.9, weight_decay=0.01) 159 | 160 | def inner_objective(): 161 | return v1.pow(2).mean() + (lmbd*(v2**2)).sum() 162 | 163 | def outer_objective(): 164 | return (v1*v2).mean() 165 | 166 | for _ in range(10): 167 | sgd.zero_grad() 168 | sgd.step(inner_objective) 169 | 170 | sgd.zero_grad() 171 | lmbd.grad.zero_() 172 | outer_objective().backward() 173 | sgd.meta_backward() 174 | 175 | print(lmbd.grad) 176 | -------------------------------------------------------------------------------- /train_l2t_ww.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | 10 | from check_dataset import check_dataset 11 | from check_model import check_model 12 | from utils.utils import AverageMeter, accuracy, set_logging_config 13 | from train.meta_optimizers import MetaSGD 14 | 15 | torch.backends.cudnn.benchmark = True 16 | 17 | 18 | def _get_num_features(model): 19 | if model.startswith('resnet'): 20 | n = int(model[6:]) 21 | if n in [18, 34, 50, 101, 152]: 22 | return [64, 64, 128, 256, 512] 23 | else: 24 | n = (n-2) // 6 25 | return [16]*n+[32]*n+[64]*n 26 | elif model.startswith('vgg'): 27 | n = int(model[3:].split('_')[0]) 28 | if n == 9: 29 | return [64, 128, 256, 512, 512] 30 | elif n == 11: 31 | return [64, 128, 256, 512, 512] 32 | 33 | raise NotImplementedError 34 | 35 | 36 | class FeatureMatching(nn.ModuleList): 37 | def __init__(self, source_model, target_model, pairs): 38 | super(FeatureMatching, self).__init__() 39 | self.src_list = _get_num_features(source_model) 40 | self.tgt_list = _get_num_features(target_model) 41 | self.pairs = pairs 42 | 43 | for src_idx, tgt_idx in pairs: 44 | self.append(nn.Conv2d(self.tgt_list[tgt_idx], self.src_list[src_idx], 1)) 45 | 46 | def forward(self, source_features, target_features, 47 | weight, beta, loss_weight): 48 | 49 | matching_loss = 0.0 50 | for i, (src_idx, tgt_idx) in enumerate(self.pairs): 51 | sw = source_features[src_idx].size(3) 52 | tw = target_features[tgt_idx].size(3) 53 | if sw == tw: 54 | diff = source_features[src_idx] - self[i](target_features[tgt_idx]) 55 | else: 56 | diff = F.interpolate( 57 | source_features[src_idx], 58 | scale_factor=tw / sw, 59 | mode='bilinear' 60 | ) - self[i](target_features[tgt_idx]) 61 | diff = diff.pow(2).mean(3).mean(2) 62 | if loss_weight is None and weight is None: 63 | diff = diff.mean(1).mean(0).mul(beta[i]) 64 | elif loss_weight is None: 65 | diff = diff.mul(weight[i]).sum(1).mean(0).mul(beta[i]) 66 | elif weight is None: 67 | diff = (diff.sum(1)*(loss_weight[i].squeeze())).mean(0).mul(beta[i]) 68 | else: 69 | diff = (diff.mul(weight[i]).sum(1)*(loss_weight[i].squeeze())).mean(0).mul(beta[i]) 70 | matching_loss = matching_loss 
+ diff 71 | return matching_loss 72 | 73 | 74 | class WeightNetwork(nn.ModuleList): 75 | def __init__(self, source_model, pairs): 76 | super(WeightNetwork, self).__init__() 77 | n = _get_num_features(source_model) 78 | for i, _ in pairs: 79 | self.append(nn.Linear(n[i], n[i])) 80 | self[-1].weight.data.zero_() 81 | self[-1].bias.data.zero_() 82 | self.pairs = pairs 83 | 84 | def forward(self, source_features): 85 | outputs = [] 86 | for i, (idx, _) in enumerate(self.pairs): 87 | f = source_features[idx] 88 | f = F.avg_pool2d(f, f.size(2)).view(-1, f.size(1)) 89 | outputs.append(F.softmax(self[i](f), 1)) 90 | return outputs 91 | 92 | 93 | class LossWeightNetwork(nn.ModuleList): 94 | def __init__(self, source_model, pairs, weight_type='relu', init=None): 95 | super(LossWeightNetwork, self).__init__() 96 | n = _get_num_features(source_model) 97 | if weight_type == 'const': 98 | self.weights = nn.Parameter(torch.zeros(len(pairs))) 99 | else: 100 | for i, _ in pairs: 101 | l = nn.Linear(n[i], 1) 102 | if init is not None: 103 | nn.init.constant_(l.bias, init) 104 | self.append(l) 105 | self.pairs = pairs 106 | self.weight_type = weight_type 107 | 108 | def forward(self, source_features): 109 | outputs = [] 110 | if self.weight_type == 'const': 111 | for w in F.softplus(self.weights.mul(10)): 112 | outputs.append(w.view(1, 1)) 113 | else: 114 | for i, (idx, _) in enumerate(self.pairs): 115 | f = source_features[idx] 116 | f = F.avg_pool2d(f, f.size(2)).view(-1, f.size(1)) 117 | if self.weight_type == 'relu': 118 | outputs.append(F.relu(self[i](f))) 119 | elif self.weight_type == 'relu-avg': 120 | outputs.append(F.relu(self[i](f.div(f.size(1))))) 121 | elif self.weight_type == 'relu6': 122 | outputs.append(F.relu6(self[i](f))) 123 | return outputs 124 | 125 | 126 | def main(): 127 | parser = argparse.ArgumentParser(add_help=False) 128 | parser.add_argument('--dataroot', required=True, help='Path to the dataset') 129 | parser.add_argument('--dataset', default='cub200') 130 | parser.add_argument('--datasplit', default='cub200') 131 | parser.add_argument('--batchSize', type=int, default=64, help='Input batch size') 132 | parser.add_argument('--workers', type=int, default=4) 133 | 134 | parser.add_argument('--source-model', default='resnet34', type=str) 135 | parser.add_argument('--source-domain', default='imagenet', type=str) 136 | parser.add_argument('--source-path', type=str, default=None) 137 | parser.add_argument('--target-model', default='resnet18', type=str) 138 | parser.add_argument('--weight-path', type=str, default=None) 139 | parser.add_argument('--wnet-path', type=str, default=None) 140 | 141 | parser.add_argument('--epochs', type=int, default=200) 142 | parser.add_argument('--lr', type=float, default=0.1,help='Initial learning rate') 143 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 144 | parser.add_argument('--wd', type=float, default=0.0001, help='Weight decay') 145 | parser.add_argument('--nesterov', action='store_true') 146 | parser.add_argument('--schedule', action='store_true', default=True) 147 | parser.add_argument('--beta', type=float, default=0.5) 148 | parser.add_argument('--pairs', type=str, default='4-4,4-3,4-2,4-1,3-4,3-3,3-2,3-1,2-4,2-3,2-2,2-1,1-4,1-3,1-2,1-1') 149 | 150 | parser.add_argument('--meta-lr', type=float, default=1e-4, help='Initial learning rate for meta networks') 151 | parser.add_argument('--meta-wd', type=float, default=1e-4) 152 | parser.add_argument('--loss-weight', action='store_true', default=True) 153 | 
parser.add_argument('--loss-weight-type', type=str, default='relu6') 154 | parser.add_argument('--loss-weight-init', type=float, default=1.0) 155 | parser.add_argument('--T', type=int, default=2) 156 | parser.add_argument('--optimizer', type=str, default='adam') 157 | 158 | parser.add_argument('--experiment', default='logs', help='Where to store models') 159 | 160 | # default settings 161 | opt = parser.parse_args() 162 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 163 | os.makedirs(opt.experiment) 164 | set_logging_config(opt.experiment) 165 | logger = logging.getLogger('main') 166 | logger.info(' '.join(os.sys.argv)) 167 | logger.info(opt) 168 | 169 | # load source model 170 | if opt.source_domain == 'imagenet': 171 | from models import resnet_ilsvrc 172 | source_model = resnet_ilsvrc.__dict__[opt.source_model](pretrained=True).to(device) 173 | else: 174 | opt.model = opt.source_model 175 | weights = [] 176 | source_gen_params = [] 177 | source_path = os.path.join( 178 | opt.source_path, '{}-{}'.format(opt.source_domain, opt.source_model), 179 | '0', 180 | 'model_best.pth.tar' 181 | ) 182 | ckpt = torch.load(source_path) 183 | opt.num_classes = ckpt['num_classes'] 184 | source_model = check_model(opt).to(device) 185 | source_model.load_state_dict(ckpt['state_dict'], strict=False) 186 | 187 | pairs = [] 188 | for pair in opt.pairs.split(','): 189 | pairs.append((int(pair.split('-')[0]), 190 | int(pair.split('-')[1]))) 191 | 192 | wnet = WeightNetwork(opt.source_model, pairs).to(device) 193 | weight_params = list(wnet.parameters()) 194 | if opt.loss_weight: 195 | lwnet = LossWeightNetwork(opt.source_model, pairs, opt.loss_weight_type, opt.loss_weight_init).to(device) 196 | weight_params = weight_params + list(lwnet.parameters()) 197 | 198 | if opt.wnet_path is not None: 199 | ckpt = torch.load(opt.wnet_path) 200 | wnet.load_state_dict(ckpt['w']) 201 | if opt.loss_weight: 202 | lwnet.load_state_dict(ckpt['lw']) 203 | 204 | if opt.optimizer == 'sgd': 205 | source_optimizer = optim.SGD(weight_params, lr=opt.meta_lr, weight_decay=opt.meta_wd, momentum=opt.momentum, nesterov=opt.nesterov) 206 | else: 207 | source_optimizer = optim.Adam(weight_params, lr=opt.meta_lr, weight_decay=opt.meta_wd) 208 | 209 | # load dataloaders 210 | loaders = check_dataset(opt) 211 | 212 | # load target model 213 | opt.model = opt.target_model 214 | target_model = check_model(opt).to(device) 215 | target_branch = FeatureMatching(opt.source_model, 216 | opt.target_model, 217 | pairs).to(device) 218 | target_params = list(target_model.parameters()) + list(target_branch.parameters()) 219 | if opt.meta_lr == 0: 220 | target_optimizer = optim.SGD(target_params, lr=opt.lr, momentum=opt.momentum, weight_decay=opt.wd) 221 | else: 222 | target_optimizer = MetaSGD(target_params, 223 | [target_model, target_branch], 224 | lr=opt.lr, 225 | momentum=opt.momentum, 226 | weight_decay=opt.wd, rollback=True, cpu=opt.T>2) 227 | 228 | state = { 229 | 'target_model': target_model.state_dict(), 230 | 'target_branch': target_branch.state_dict(), 231 | 'target_optimizer': target_optimizer.state_dict(), 232 | 'w': wnet.state_dict(), 233 | 'best': (0.0, 0.0) 234 | } 235 | if opt.loss_weight: 236 | state['lw'] = lwnet.state_dict() 237 | 238 | scheduler = optim.lr_scheduler.CosineAnnealingLR(target_optimizer, opt.epochs) 239 | 240 | def validate(model, loader): 241 | acc = AverageMeter() 242 | model.eval() 243 | for x, y in loader: 244 | x, y = x.to(device), y.to(device) 245 | y_pred, _ = model(x) 246 | 
acc.update(accuracy(y_pred.data, y, topk=(1,))[0].item(), x.size(0)) 247 | return acc.avg 248 | 249 | def inner_objective(data, matching_only=False): 250 | x, y = data[0].to(device), data[1].to(device) 251 | y_pred, target_features = target_model.forward_with_features(x) 252 | 253 | with torch.no_grad(): 254 | s_pred, source_features = source_model.forward_with_features(x) 255 | 256 | weights = wnet(source_features) 257 | state['loss_weights'] = '' 258 | if opt.loss_weight: 259 | loss_weights = lwnet(source_features) 260 | state['loss_weights'] = ' '.join(['{:.2f}'.format(lw.mean().item()) for lw in loss_weights]) 261 | else: 262 | loss_weights = None 263 | beta = [opt.beta] * len(wnet) 264 | 265 | matching_loss = target_branch(source_features, 266 | target_features, 267 | weights, beta, loss_weights) 268 | 269 | state['accuracy'] = accuracy(y_pred.data, y, topk=(1,))[0].item() 270 | 271 | if matching_only: 272 | return matching_loss 273 | 274 | loss = F.cross_entropy(y_pred, y) 275 | state['loss'] = loss.item() 276 | return loss + matching_loss 277 | 278 | def outer_objective(data): 279 | x, y = data[0].to(device), data[1].to(device) 280 | y_pred, _ = target_model(x) 281 | state['accuracy'] = accuracy(y_pred.data, y, topk=(1,))[0].item() 282 | loss = F.cross_entropy(y_pred, y) 283 | state['loss'] = loss.item() 284 | return loss 285 | 286 | # source generator training 287 | state['iter'] = 0 288 | for epoch in range(opt.epochs): 289 | if opt.schedule: 290 | scheduler.step() 291 | 292 | state['epoch'] = epoch 293 | target_model.train() 294 | source_model.eval() 295 | for i, data in enumerate(loaders[0]): 296 | target_optimizer.zero_grad() 297 | inner_objective(data).backward() 298 | target_optimizer.step(None) 299 | 300 | logger.info('[Epoch {:3d}] [Iter {:3d}] [Loss {:.4f}] [Acc {:.4f}] [LW {}]'.format( 301 | state['epoch'], state['iter'], 302 | state['loss'], state['accuracy'], state['loss_weights'])) 303 | state['iter'] += 1 304 | 305 | for _ in range(opt.T): 306 | target_optimizer.zero_grad() 307 | target_optimizer.step(inner_objective, data, True) 308 | 309 | target_optimizer.zero_grad() 310 | target_optimizer.step(outer_objective, data) 311 | 312 | target_optimizer.zero_grad() 313 | source_optimizer.zero_grad() 314 | outer_objective(data).backward() 315 | target_optimizer.meta_backward() 316 | source_optimizer.step() 317 | 318 | acc = (validate(target_model, loaders[1]), 319 | validate(target_model, loaders[2])) 320 | 321 | if state['best'][0] < acc[0]: 322 | state['best'] = acc 323 | 324 | if state['epoch'] % 10 == 0: 325 | torch.save(state, os.path.join(opt.experiment, 'ckpt-{}.pth'.format(state['epoch']+1))) 326 | 327 | logger.info('[Epoch {}] [val {:.4f}] [test {:.4f}] [best {:.4f}]'.format(epoch, acc[0], acc[1], state['best'][1])) 328 | 329 | 330 | if __name__ == '__main__': 331 | main() 332 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # from nested_dict import nested_dict 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | 
self.avg = self.sum / self.count
26 | 
27 | 
28 | def accuracy(output, target, topk=(1,)):
29 |     """Computes the precision@k for the specified values of k"""
30 |     maxk = max(topk)
31 |     batch_size = target.size(0)
32 | 
33 |     _, pred = output.topk(maxk, 1, True, True)
34 |     pred = pred.t()
35 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
36 | 
37 |     res = []
38 |     for k in topk:
39 |         correct_k = correct[:k].view(-1).float().sum(0)
40 |         res.append(correct_k.mul_(100.0 / batch_size))
41 |     return res
42 | 
43 | 
44 | def set_logging_config(logdir):
45 |     logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s",
46 |                         level=logging.INFO,
47 |                         handlers=[logging.FileHandler(os.path.join(logdir, 'log.txt')),
48 |                                   logging.StreamHandler(os.sys.stdout)])
49 | 
--------------------------------------------------------------------------------
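A minimal sketch of how the `AverageMeter` and `accuracy` helpers above combine during validation, as in the `validate` closure of `train_l2t_ww.py` — random tensors stand in for real model outputs, so the printed number is meaningless:

```python
import torch
from utils.utils import AverageMeter, accuracy

top1 = AverageMeter()
for _ in range(3):
    logits = torch.randn(8, 10)           # stand-in predictions (batch of 8, 10 classes)
    labels = torch.randint(0, 10, (8,))   # stand-in targets
    # accuracy() returns one percentage tensor per requested k.
    top1.update(accuracy(logits, labels, topk=(1,))[0].item(), logits.size(0))
print('running top-1: {:.2f}%'.format(top1.avg))
```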