├── README.md ├── dataloader.py ├── eval.py ├── figs └── reconstruction.jpg ├── models ├── builer.py ├── resnet.py └── vgg.py ├── requirements.txt ├── run ├── eval.sh ├── evalall.sh └── train.sh ├── tools ├── decode.py ├── encode.py ├── generate_list.py └── reconstruct.py ├── train.py └── utils.py /README.md:
--------------------------------------------------------------------------------
1 | # Auto-Encoder trained on ImageNet
2 | 
3 | Train VGG-like and ResNet-like auto-encoders on image datasets such as ImageNet
4 | 
5 | 
6 | ![Reconstruction results](https://github.com/Horizon2333/imagenet-autoencoder/blob/main/figs/reconstruction.jpg)
7 | 
8 | 
9 | 1. [Project Structure](#project-structure)
10 | 2. [Install](#install)
11 | 3. [Data preparing](#data-preparing)
12 | 4. [Train and Evaluate](#train-and-evaluate)
13 | 5. [Tools](#tools)
14 | 6. [Model zoo](#model-zoo)
15 | 
16 | ## Project Structure
17 | 
18 | ```
19 | $imagenet-autoencoder
20 | |──figs                   # result images
21 |   |── *.jpg
22 | |──models
23 |   |──builer.py           # build autoencoder models
24 |   |──resnet.py           # resnet-like autoencoder
25 |   |──vgg.py              # vgg-like autoencoder
26 | |──run
27 |   |──eval.sh             # command to evaluate a single checkpoint
28 |   |──evalall.sh          # command to evaluate all checkpoints in a specific folder
29 |   |──train.sh            # command to train the auto-encoder
30 | |──tools
31 |   |──decode.py           # decode random latent codes to images
32 |   |──encode.py           # encode a single image to a latent code
33 |   |──generate_list.py    # generate the image list for training
34 |   |──reconstruct.py      # reconstruct images to see the difference
35 | |──dataloader.py          # dataset and dataloader
36 | |──eval.py                # evaluate checkpoints
37 | |──train.py               # train models
38 | |──utils.py               # other utility functions
39 | |──requirements.txt
40 | |──README.md
41 | ```
42 | 
43 | ## Install
44 | 
45 | 
46 | 1. Clone the project
47 | ```shell
48 | git clone https://github.com/Horizon2333/imagenet-autoencoder
49 | cd imagenet-autoencoder
50 | ```
51 | 2. Install dependencies
52 | ```shell
53 | pip install -r requirements.txt
54 | ```
55 | 
56 | ## Data Preparing
57 | 
58 | Your dataset should look like:
59 | 
60 | ```
61 | $your_dataset_path
62 | |──class1
63 |   |──xxxx.jpg
64 |   |──...
65 | |──class2
66 |   |──xxxx.jpg
67 |   |──...
68 | |──...
69 | |──classN
70 |   |──xxxx.jpg
71 |   |──...
72 | ```
73 | 
74 | Then you can use ```tools/generate_list.py``` to generate the list of training samples. We do not use ```torchvision.datasets.ImageFolder``` here because it is very slow when the dataset is large. You can run
75 | 
76 | ```shell
77 | python tools/generate_list.py --name {name of your dataset, such as caltech256} --path {path to your dataset}
78 | ```
79 | 
80 | Then two files will be generated under the ```list``` folder: ```*_list.txt``` stores every image path and its class index (the class is not actually used for training), and ```*_name.txt``` maps each class index to its class name.
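For example, with a dataset named ```caltech256```, the two generated files would look roughly like this (the paths and class names below are only illustrative):

```
list/caltech256_list.txt:
/path/to/caltech256/001.ak47/001_0001.jpg 0
/path/to/caltech256/001.ak47/001_0002.jpg 0
...

list/caltech256_name.txt:
0 001.ak47
1 002.american-flag
...
```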
81 | 
82 | ## Train and Evaluate
83 | 
84 | For training
85 | 
86 | ```shell
87 | bash run/train.sh {model architecture such as vgg16} {your dataset name}
88 | # For example
89 | bash run/train.sh vgg16 caltech256
90 | ```
91 | 
92 | For evaluating a single checkpoint:
93 | 
94 | ```shell
95 | bash run/eval.sh {model architecture} {checkpoint path} {dataset name}
96 | # For example
97 | bash run/eval.sh vgg16 results/caltech256-vgg16/099.pth caltech101
98 | ```
99 | 
100 | For evaluating all checkpoints under a specific folder:
101 | 
102 | ```shell
103 | bash run/evalall.sh {model architecture} {checkpoints path} {dataset name}
104 | # For example
105 | bash run/evalall.sh vgg16 results/caltech256-vgg16/ caltech101
106 | ```
107 | When all checkpoints have been evaluated, a scatter diagram ```figs/evalall.jpg``` will be generated to show the evaluation loss trend.
108 | 
109 | For the model architecture, we currently support ```vgg11, vgg13, vgg16, vgg19``` and ```resnet18, resnet34, resnet50, resnet101, resnet152```.
110 | 
111 | ## Tools
112 | 
113 | We provide several tools to better visualize the auto-encoder results.
114 | 
115 | ```reconstruct.py```
116 | 
117 | Reconstruct images from the original ones. This script samples 64 images and saves the comparison results to ```figs/reconstruction.jpg```.
118 | 
119 | ```shell
120 | python tools/reconstruct.py --arch {model architecture} --resume {checkpoint path} --val_list {*_list.txt of your dataset}
121 | # For example
122 | python tools/reconstruct.py --arch vgg16 --resume results/caltech256-vgg16/099.pth --val_list caltech101_list.txt
123 | ```
124 | 
125 | ```encode.py``` and ```decode.py```
126 | 
127 | Encode an image to a latent code, or decode latent codes to images.
128 | 
129 | ```encode.py``` transforms a single image into its latent code.
130 | 
131 | ```shell
132 | python tools/encode.py --arch {model architecture} --resume {checkpoint path} --img_path {image path}
133 | # For example
134 | python tools/encode.py --arch vgg16 --resume results/caltech256-vgg16/099.pth --img_path figs/reconstruction.jpg
135 | ```
136 | 
137 | ```decode.py``` transforms 128 random latent codes into images.
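The shape of the latent code depends on the backbone: ```(512, 7, 7)``` for VGG and ResNet18/34, and ```(2048, 7, 7)``` for ResNet50/101/152 (see ```random_sample``` in ```tools/decode.py```). Below is a minimal sketch of the decoding step that ```decode.py``` performs, assuming ```model``` has already been built by ```builder.BuildAutoEncoder(args)``` and the checkpoint loaded, so the raw autoencoder sits behind ```model.module```:

```python
import torch

# decode a single random latent code with a vgg16-based autoencoder;
# use shape (1, 2048, 7, 7) instead for resnet50/101/152
latent = torch.randn((1, 512, 7, 7)).cuda()

with torch.no_grad():
    image = model.module.decoder(latent)

# image: (1, 3, 224, 224), values in [0, 1] thanks to the sigmoid output gate
```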
138 | 
139 | ```shell
140 | python tools/decode.py --arch {model architecture} --resume {checkpoint path}
141 | # For example
142 | python tools/decode.py --arch vgg16 --resume results/caltech256-vgg16/099.pth
143 | ```
144 | 
145 | The decoded results will be saved as ```figs/generation.jpg```.
146 | 
147 | ## Model zoo
148 | 
149 | | Dataset | VGG11 | VGG13 | VGG16 | VGG19 | ResNet18 | ResNet34 | ResNet50 | ResNet101 | ResNet152 |
150 | | :--------: | :---: | :---: | :---: | :---: | :------: | :------: | :------: | :-------: | :--------: |
151 | | Caltech256 | [link](https://drive.google.com/file/d/1gebnzAnFDpT9mmzr2dDVZ39FxqZHSuD4/view?usp=sharing) | [link](https://drive.google.com/file/d/1JRooEtKw2-2R_u-pswX2C8mAl_GgAlhH/view?usp=sharing) | [link](https://drive.google.com/file/d/12ysuL1rzIedcL_KD3VNDcZn9lGwxCWFu/view?usp=sharing) | [link](https://drive.google.com/file/d/1ydCY3llYJLL3asZ45-EGPUYxB-jlLVFo/view?usp=sharing) | [link](https://drive.google.com/file/d/1vokB8J17t34qk8qN37cVrEes06wzJzzG/view?usp=sharing) | [link](https://drive.google.com/file/d/1EMfNI6uAMdx-T1QmYg-UQHNWLxkaub6c/view?usp=sharing) | [link](https://drive.google.com/file/d/1-lA1dtP9q9ABom7c3qbMy7JYnnQvsI9H/view?usp=sharing) | [link](https://drive.google.com/file/d/1yNzkPhf2LAzu0mVm3ZedTObl_s-2pg1J/view?usp=sharing) | [link](https://drive.google.com/file/d/1HX7zaMK4ug6GjdUljG8Jqc4OvT8aLTrD/view?usp=sharing) |
152 | | Objects365 | | | [link](https://drive.google.com/file/d/16ozLClq8_Kpoc1Ln8dgIkQC4v7a1OTyz/view?usp=sharing) | [link](https://drive.google.com/file/d/1nR_9_WsYXGzBvzLdsxlEba9XyBwg1aD7/view?usp=sharing) | | | [link](https://drive.google.com/file/d/1FLPcRcAKaYBZrJQ7uYz0ST0WPrgacwm6/view?usp=sharing) | [link](https://drive.google.com/file/d/1pVtZpQn2kT1e2ZhG1MBvLLAMEVI30mVL/view?usp=sharing) | |
153 | | ImageNet | | | [link](https://drive.google.com/file/d/1WwJiQ1kBcNCZ37F6PJ_0bIL0ZeU3_sV8/view?usp=sharing) | | | | | | |
154 | 
155 | Note that the Objects365 dataset is about half the size of the ImageNet dataset (which has about 1.28 million images), and both are much larger than Caltech256, so the performance may be comparable.
156 | 
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import random
4 | 
5 | from PIL import Image, ImageFilter
6 | 
7 | import torch
8 | import torch.utils.data as data
9 | 
10 | from torchvision.transforms import transforms
11 | 
12 | class ImageDataset(data.Dataset):
13 |     def __init__(self, ann_file, transform=None):
14 |         self.ann_file = ann_file
15 |         self.transform = transform
16 |         self.init()
17 | 
18 |     def init(self):
19 | 
20 |         self.im_names = []
21 |         self.targets = []
22 |         with open(self.ann_file, 'r') as f:
23 |             lines = f.readlines()
24 |             for line in lines:
25 |                 data = line.strip().split(' ')
26 |                 self.im_names.append(data[0])
27 |                 self.targets.append(int(data[1]))
28 | 
29 |     def __getitem__(self, index):
30 |         im_name = self.im_names[index]
31 |         target = self.targets[index]
32 | 
33 |         img = Image.open(im_name).convert('RGB')
34 |         if img is None:
35 |             print(im_name)
36 | 
37 |         img = self.transform(img)
38 | 
39 |         return img, img    # the autoencoder target is the input image itself
40 | 
41 |     def __len__(self):
42 |         return len(self.im_names)
43 | 
44 | def train_loader(args):
45 | 
46 |     # [NO] do not use normalization here since it makes training very hard to converge
47 |     # [NO] do not use color jitter since it leads to a performance drop on both the train and val sets
48 | 
49 |     # [?] 
Gaussian blur leads to a significant drop in train loss while the val loss remains the same
50 | 
51 |     augmentation = [
52 |         transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
53 |         #transforms.RandomGrayscale(p=0.2),
54 |         #transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
55 |         transforms.RandomHorizontalFlip(),
56 |         transforms.ToTensor(),
57 |     ]
58 | 
59 |     train_trans = transforms.Compose(augmentation)
60 | 
61 |     train_dataset = ImageDataset(args.train_list, transform=train_trans)
62 | 
63 |     if args.parallel == 1:
64 |         train_sampler = torch.utils.data.distributed.DistributedSampler(
65 |             train_dataset,
66 |             rank=args.rank,
67 |             num_replicas=args.world_size,
68 |             shuffle=True)
69 |     else:
70 |         train_sampler = None
71 | 
72 |     train_loader = torch.utils.data.DataLoader(
73 |         train_dataset,
74 |         shuffle=(train_sampler is None),
75 |         batch_size=args.batch_size,
76 |         num_workers=args.workers,
77 |         pin_memory=True,
78 |         sampler=train_sampler,
79 |         drop_last=(train_sampler is None))
80 | 
81 |     return train_loader
82 | 
83 | def val_loader(args):
84 | 
85 |     val_trans = transforms.Compose([
86 |         transforms.Resize(256),
87 |         transforms.CenterCrop(224),
88 |         transforms.ToTensor()
89 |     ])
90 | 
91 |     val_dataset = ImageDataset(args.val_list, transform=val_trans)
92 | 
93 |     val_loader = torch.utils.data.DataLoader(
94 |         val_dataset,
95 |         shuffle=False,
96 |         batch_size=args.batch_size,
97 |         num_workers=args.workers,
98 |         pin_memory=True)
99 | 
100 |     return val_loader
101 | 
102 | class GaussianBlur(object):
103 |     """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
104 | 
105 |     def __init__(self, sigma=[.1, 2.]):
106 |         self.sigma = sigma
107 | 
108 |     def __call__(self, x):
109 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
110 |         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
111 |         return x
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import os
4 | import time
5 | import argparse
6 | 
7 | import matplotlib.pyplot as plt
8 | 
9 | import torch
10 | import torch.nn as nn
11 | 
12 | import utils
13 | import models.builer as builder
14 | import dataloader
15 | 
16 | def get_args():
17 |     # parse the args
18 |     print('=> parse the args ...')
19 |     parser = argparse.ArgumentParser(description='Evaluate for auto encoder')
20 |     parser.add_argument('--arch', default='vgg16', type=str,
21 |                         help='backbone architecture')
22 |     parser.add_argument('--val_list', type=str)
23 |     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
24 |                         help='number of data loading workers (default: 16)')
25 |     parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
26 |                         help='mini-batch size (default: 256), this is the total '
27 |                              'batch size of all GPUs on the current node when '
28 |                              'using Data Parallel or Distributed Data Parallel')
29 | 
30 |     parser.add_argument('-p', '--print-freq', default=20, type=int,
31 |                         metavar='N', help='print frequency (default: 20)')
32 | 
33 |     parser.add_argument('--resume', type=str)
34 |     parser.add_argument('--folder', type=str)
35 |     parser.add_argument('--start_epoch', default=0, type=int)
36 |     parser.add_argument('--epochs', default=100, type=int)
37 | 
38 |     args = parser.parse_args()
39 | 
40 |     args.parallel = 0
41 | 
42 |     return args
43 | 
44 | def main(args):
45 |     print('=> torch version : {}'.format(torch.__version__))
46 |     ngpus_per_node = torch.cuda.device_count()
47 |     print('=> 
ngpus : {}'.format(ngpus_per_node)) 48 | 49 | utils.init_seeds(1, cuda_deterministic=False) 50 | 51 | print('=> modeling the network ...') 52 | model = builder.BuildAutoEncoder(args) 53 | total_params = sum(p.numel() for p in model.parameters()) 54 | print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024)))) 55 | 56 | print('=> building the dataloader ...') 57 | val_loader = dataloader.val_loader(args) 58 | 59 | print('=> building the criterion ...') 60 | criterion = nn.MSELoss() 61 | 62 | print('=> starting evaluating engine ...') 63 | if args.folder: 64 | best_loss = None 65 | best_epoch = 1 66 | losses = [] 67 | for epoch in range(args.start_epoch, args.epochs): 68 | print() 69 | print("Epoch {}".format(epoch+1)) 70 | resume_path = os.path.join(args.folder, "%03d.pth" % epoch) 71 | print('=> loading pth from {} ...'.format(resume_path)) 72 | utils.load_dict(resume_path, model) 73 | loss = do_evaluate(val_loader, model, criterion, args) 74 | print("Evaluate loss : {:.4f}".format(loss)) 75 | 76 | losses.append(loss) 77 | if best_loss: 78 | if loss < best_loss: 79 | best_loss = loss 80 | best_epoch = epoch + 1 81 | else: 82 | best_loss = loss 83 | print() 84 | print("Best loss : {:.4f} Appears in {}".format(best_loss, best_epoch)) 85 | 86 | max_loss = max(losses) 87 | 88 | plt.figure(figsize=(7,7)) 89 | 90 | plt.xlabel("epoch") 91 | plt.ylabel("loss") 92 | plt.xlim((0,args.epochs+1)) 93 | plt.ylim([0, float('%.1g' % (1.22*max_loss))]) 94 | 95 | plt.scatter(range(1, args.epochs+1), losses, s=9) 96 | 97 | plt.savefig("figs/evalall.jpg") 98 | 99 | else: 100 | print('=> loading pth from {} ...'.format(args.resume)) 101 | utils.load_dict(args.resume, model) 102 | loss = do_evaluate(val_loader, model, criterion, args) 103 | print("Evaluate loss : {:.4f}".format(loss)) 104 | 105 | 106 | def do_evaluate(val_loader, model, criterion, args): 107 | batch_time = utils.AverageMeter('Time', ':6.2f') 108 | data_time = utils.AverageMeter('Data', ':2.2f') 109 | losses = utils.AverageMeter('Loss', ':.4f') 110 | 111 | progress = utils.ProgressMeter( 112 | len(val_loader), 113 | [batch_time, data_time, losses], 114 | prefix="Evaluate ") 115 | end = time.time() 116 | 117 | model.eval() 118 | with torch.no_grad(): 119 | for i, (input, target) in enumerate(val_loader): 120 | # measure data loading time 121 | data_time.update(time.time() - end) 122 | 123 | input = input.cuda(non_blocking=True) 124 | target = target.cuda(non_blocking=True) 125 | 126 | output = model(input) 127 | 128 | loss = criterion(output, target) 129 | 130 | # record loss 131 | losses.update(loss.item(), input.size(0)) 132 | batch_time.update(time.time() - end) 133 | end = time.time() 134 | 135 | if i % args.print_freq == 0: 136 | progress.display(i) 137 | 138 | return losses.avg 139 | 140 | if __name__ == '__main__': 141 | 142 | args = get_args() 143 | 144 | main(args) 145 | 146 | 147 | -------------------------------------------------------------------------------- /figs/reconstruction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Horizon2333/imagenet-autoencoder/536633ec1c0e9afe2dd91ce74b56e6e13479b6bd/figs/reconstruction.jpg -------------------------------------------------------------------------------- /models/builer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.parallel as parallel 3 | 4 | from . 
import vgg, resnet
5 | 
6 | def BuildAutoEncoder(args):
7 | 
8 |     if args.arch in ["vgg11", "vgg13", "vgg16", "vgg19"]:
9 |         configs = vgg.get_configs(args.arch)
10 |         model = vgg.VGGAutoEncoder(configs)
11 | 
12 |     elif args.arch in ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]:
13 |         configs, bottleneck = resnet.get_configs(args.arch)
14 |         model = resnet.ResNetAutoEncoder(configs, bottleneck)
15 | 
16 |     else:
17 |         return None
18 | 
19 |     if args.parallel == 1:
20 |         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
21 |         model = parallel.DistributedDataParallel(
22 |                         model.to(args.gpu),
23 |                         device_ids=[args.gpu],
24 |                         output_device=args.gpu
25 |                     )
26 | 
27 |     else:
28 |         model = nn.DataParallel(model).cuda()
29 | 
30 |     return model
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def get_configs(arch='resnet50'):
5 | 
6 |     # True or False means whether to use BottleNeck
7 | 
8 |     if arch == 'resnet18':
9 |         return [2, 2, 2, 2], False
10 |     elif arch == 'resnet34':
11 |         return [3, 4, 6, 3], False
12 |     elif arch == 'resnet50':
13 |         return [3, 4, 6, 3], True
14 |     elif arch == 'resnet101':
15 |         return [3, 4, 23, 3], True
16 |     elif arch == 'resnet152':
17 |         return [3, 8, 36, 3], True
18 |     else:
19 |         raise ValueError("Undefined model")
20 | 
21 | class ResNetAutoEncoder(nn.Module):
22 | 
23 |     def __init__(self, configs, bottleneck):
24 | 
25 |         super(ResNetAutoEncoder, self).__init__()
26 | 
27 |         self.encoder = ResNetEncoder(configs=configs, bottleneck=bottleneck)
28 |         self.decoder = ResNetDecoder(configs=configs[::-1], bottleneck=bottleneck)
29 | 
30 |     def forward(self, x):
31 | 
32 |         x = self.encoder(x)
33 |         x = self.decoder(x)
34 | 
35 |         return x
36 | 
37 | class ResNet(nn.Module):
38 | 
39 |     def __init__(self, configs, bottleneck=False, num_classes=1000):
40 |         super(ResNet, self).__init__()
41 | 
42 |         self.encoder = ResNetEncoder(configs, bottleneck)
43 | 
44 |         self.avpool = nn.AdaptiveAvgPool2d((1,1))
45 | 
46 |         if bottleneck:
47 |             self.fc = nn.Linear(in_features=2048, out_features=num_classes)
48 |         else:
49 |             self.fc = nn.Linear(in_features=512, out_features=num_classes)
50 | 
51 |         for m in self.modules():
52 |             if isinstance(m, nn.Conv2d):
53 |                 nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
54 |                 if m.bias is not None:
55 |                     nn.init.constant_(m.bias, 0)
56 |             elif isinstance(m, nn.BatchNorm2d):
57 |                 nn.init.constant_(m.weight, 1)
58 |                 nn.init.constant_(m.bias, 0)
59 |             elif isinstance(m, nn.Linear):
60 |                 nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
61 |                 nn.init.constant_(m.bias, 0)
62 | 
63 |     def forward(self, x):
64 | 
65 |         x = self.encoder(x)
66 | 
67 |         x = self.avpool(x)
68 | 
69 |         x = torch.flatten(x, 1)
70 | 
71 |         x = self.fc(x)
72 | 
73 |         return x
74 | 
75 | 
76 | class ResNetEncoder(nn.Module):
77 | 
78 |     def __init__(self, configs, bottleneck=False):
79 |         super(ResNetEncoder, self).__init__()
80 | 
81 |         if len(configs) != 4:
82 |             raise ValueError("Only 4 layers can be configured")
83 | 
84 |         self.conv1 = nn.Sequential(
85 |             nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False),
86 |             nn.BatchNorm2d(num_features=64),
87 |             nn.ReLU(inplace=True),
88 |         )
89 | 
90 |         if bottleneck:
91 | 
92 |             self.conv2 = EncoderBottleneckBlock(in_channels=64, hidden_channels=64, up_channels=256, layers=configs[0], downsample_method="pool")
93 |             self.conv3 = 
EncoderBottleneckBlock(in_channels=256, hidden_channels=128, up_channels=512, layers=configs[1], downsample_method="conv")
94 |             self.conv4 = EncoderBottleneckBlock(in_channels=512, hidden_channels=256, up_channels=1024, layers=configs[2], downsample_method="conv")
95 |             self.conv5 = EncoderBottleneckBlock(in_channels=1024, hidden_channels=512, up_channels=2048, layers=configs[3], downsample_method="conv")
96 | 
97 |         else:
98 | 
99 |             self.conv2 = EncoderResidualBlock(in_channels=64, hidden_channels=64, layers=configs[0], downsample_method="pool")
100 |             self.conv3 = EncoderResidualBlock(in_channels=64, hidden_channels=128, layers=configs[1], downsample_method="conv")
101 |             self.conv4 = EncoderResidualBlock(in_channels=128, hidden_channels=256, layers=configs[2], downsample_method="conv")
102 |             self.conv5 = EncoderResidualBlock(in_channels=256, hidden_channels=512, layers=configs[3], downsample_method="conv")
103 | 
104 |     def forward(self, x):
105 | 
106 |         x = self.conv1(x)
107 |         x = self.conv2(x)
108 |         x = self.conv3(x)
109 |         x = self.conv4(x)
110 |         x = self.conv5(x)
111 | 
112 |         return x
113 | 
114 | class ResNetDecoder(nn.Module):
115 | 
116 |     def __init__(self, configs, bottleneck=False):
117 |         super(ResNetDecoder, self).__init__()
118 | 
119 |         if len(configs) != 4:
120 |             raise ValueError("Only 4 layers can be configured")
121 | 
122 |         if bottleneck:
123 | 
124 |             self.conv1 = DecoderBottleneckBlock(in_channels=2048, hidden_channels=512, down_channels=1024, layers=configs[0])
125 |             self.conv2 = DecoderBottleneckBlock(in_channels=1024, hidden_channels=256, down_channels=512, layers=configs[1])
126 |             self.conv3 = DecoderBottleneckBlock(in_channels=512, hidden_channels=128, down_channels=256, layers=configs[2])
127 |             self.conv4 = DecoderBottleneckBlock(in_channels=256, hidden_channels=64, down_channels=64, layers=configs[3])
128 | 
129 | 
130 |         else:
131 | 
132 |             self.conv1 = DecoderResidualBlock(hidden_channels=512, output_channels=256, layers=configs[0])
133 |             self.conv2 = DecoderResidualBlock(hidden_channels=256, output_channels=128, layers=configs[1])
134 |             self.conv3 = DecoderResidualBlock(hidden_channels=128, output_channels=64, layers=configs[2])
135 |             self.conv4 = DecoderResidualBlock(hidden_channels=64, output_channels=64, layers=configs[3])
136 | 
137 |         self.conv5 = nn.Sequential(
138 |             nn.BatchNorm2d(num_features=64),
139 |             nn.ReLU(inplace=True),
140 |             nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
141 |         )
142 | 
143 |         self.gate = nn.Sigmoid()
144 | 
145 |     def forward(self, x):
146 | 
147 |         x = self.conv1(x)
148 |         x = self.conv2(x)
149 |         x = self.conv3(x)
150 |         x = self.conv4(x)
151 |         x = self.conv5(x)
152 |         x = self.gate(x)
153 | 
154 |         return x
155 | 
156 | class EncoderResidualBlock(nn.Module):
157 | 
158 |     def __init__(self, in_channels, hidden_channels, layers, downsample_method="conv"):
159 |         super(EncoderResidualBlock, self).__init__()
160 | 
161 |         if downsample_method == "conv":
162 | 
163 |             for i in range(layers):
164 | 
165 |                 if i == 0:
166 |                     layer = EncoderResidualLayer(in_channels=in_channels, hidden_channels=hidden_channels, downsample=True)
167 |                 else:
168 |                     layer = EncoderResidualLayer(in_channels=hidden_channels, hidden_channels=hidden_channels, downsample=False)
169 | 
170 |                 self.add_module('%02d EncoderLayer' % i, layer)
171 | 
172 |         elif downsample_method == "pool":
173 | 
174 |             maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
175 | 
176 |             self.add_module('00 MaxPooling', maxpool)
177 | 
178 |             for i in 
range(layers): 179 | 180 | if i == 0: 181 | layer = EncoderResidualLayer(in_channels=in_channels, hidden_channels=hidden_channels, downsample=False) 182 | else: 183 | layer = EncoderResidualLayer(in_channels=hidden_channels, hidden_channels=hidden_channels, downsample=False) 184 | 185 | self.add_module('%02d EncoderLayer' % (i+1), layer) 186 | 187 | def forward(self, x): 188 | 189 | for name, layer in self.named_children(): 190 | 191 | x = layer(x) 192 | 193 | return x 194 | 195 | class EncoderBottleneckBlock(nn.Module): 196 | 197 | def __init__(self, in_channels, hidden_channels, up_channels, layers, downsample_method="conv"): 198 | super(EncoderBottleneckBlock, self).__init__() 199 | 200 | if downsample_method == "conv": 201 | 202 | for i in range(layers): 203 | 204 | if i == 0: 205 | layer = EncoderBottleneckLayer(in_channels=in_channels, hidden_channels=hidden_channels, up_channels=up_channels, downsample=True) 206 | else: 207 | layer = EncoderBottleneckLayer(in_channels=up_channels, hidden_channels=hidden_channels, up_channels=up_channels, downsample=False) 208 | 209 | self.add_module('%02d EncoderLayer' % i, layer) 210 | 211 | elif downsample_method == "pool": 212 | 213 | maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 214 | 215 | self.add_module('00 MaxPooling', maxpool) 216 | 217 | for i in range(layers): 218 | 219 | if i == 0: 220 | layer = EncoderBottleneckLayer(in_channels=in_channels, hidden_channels=hidden_channels, up_channels=up_channels, downsample=False) 221 | else: 222 | layer = EncoderBottleneckLayer(in_channels=up_channels, hidden_channels=hidden_channels, up_channels=up_channels, downsample=False) 223 | 224 | self.add_module('%02d EncoderLayer' % (i+1), layer) 225 | 226 | def forward(self, x): 227 | 228 | for name, layer in self.named_children(): 229 | 230 | x = layer(x) 231 | 232 | return x 233 | 234 | 235 | class DecoderResidualBlock(nn.Module): 236 | 237 | def __init__(self, hidden_channels, output_channels, layers): 238 | super(DecoderResidualBlock, self).__init__() 239 | 240 | for i in range(layers): 241 | 242 | if i == layers - 1: 243 | layer = DecoderResidualLayer(hidden_channels=hidden_channels, output_channels=output_channels, upsample=True) 244 | else: 245 | layer = DecoderResidualLayer(hidden_channels=hidden_channels, output_channels=hidden_channels, upsample=False) 246 | 247 | self.add_module('%02d EncoderLayer' % i, layer) 248 | 249 | def forward(self, x): 250 | 251 | for name, layer in self.named_children(): 252 | 253 | x = layer(x) 254 | 255 | return x 256 | 257 | class DecoderBottleneckBlock(nn.Module): 258 | 259 | def __init__(self, in_channels, hidden_channels, down_channels, layers): 260 | super(DecoderBottleneckBlock, self).__init__() 261 | 262 | for i in range(layers): 263 | 264 | if i == layers - 1: 265 | layer = DecoderBottleneckLayer(in_channels=in_channels, hidden_channels=hidden_channels, down_channels=down_channels, upsample=True) 266 | else: 267 | layer = DecoderBottleneckLayer(in_channels=in_channels, hidden_channels=hidden_channels, down_channels=in_channels, upsample=False) 268 | 269 | self.add_module('%02d EncoderLayer' % i, layer) 270 | 271 | 272 | def forward(self, x): 273 | 274 | for name, layer in self.named_children(): 275 | 276 | x = layer(x) 277 | 278 | return x 279 | 280 | 281 | class EncoderResidualLayer(nn.Module): 282 | 283 | def __init__(self, in_channels, hidden_channels, downsample): 284 | super(EncoderResidualLayer, self).__init__() 285 | 286 | if downsample: 287 | self.weight_layer1 = nn.Sequential( 288 | 
nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, stride=2, padding=1, bias=False), 289 | nn.BatchNorm2d(num_features=hidden_channels), 290 | nn.ReLU(inplace=True), 291 | ) 292 | else: 293 | self.weight_layer1 = nn.Sequential( 294 | nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), 295 | nn.BatchNorm2d(num_features=hidden_channels), 296 | nn.ReLU(inplace=True), 297 | ) 298 | 299 | self.weight_layer2 = nn.Sequential( 300 | nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), 301 | nn.BatchNorm2d(num_features=hidden_channels), 302 | ) 303 | 304 | if downsample: 305 | self.downsample = nn.Sequential( 306 | nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=2, padding=0, bias=False), 307 | nn.BatchNorm2d(num_features=hidden_channels), 308 | ) 309 | else: 310 | self.downsample = None 311 | 312 | self.relu = nn.Sequential( 313 | nn.ReLU(inplace=True) 314 | ) 315 | 316 | def forward(self, x): 317 | 318 | identity = x 319 | 320 | x = self.weight_layer1(x) 321 | x = self.weight_layer2(x) 322 | 323 | if self.downsample is not None: 324 | identity = self.downsample(identity) 325 | 326 | x = x + identity 327 | 328 | x = self.relu(x) 329 | 330 | return x 331 | 332 | class EncoderBottleneckLayer(nn.Module): 333 | 334 | def __init__(self, in_channels, hidden_channels, up_channels, downsample): 335 | super(EncoderBottleneckLayer, self).__init__() 336 | 337 | if downsample: 338 | self.weight_layer1 = nn.Sequential( 339 | nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=2, padding=0, bias=False), 340 | nn.BatchNorm2d(num_features=hidden_channels), 341 | nn.ReLU(inplace=True), 342 | ) 343 | else: 344 | self.weight_layer1 = nn.Sequential( 345 | nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0, bias=False), 346 | nn.BatchNorm2d(num_features=hidden_channels), 347 | nn.ReLU(inplace=True), 348 | ) 349 | 350 | self.weight_layer2 = nn.Sequential( 351 | nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), 352 | nn.BatchNorm2d(num_features=hidden_channels), 353 | nn.ReLU(inplace=True), 354 | ) 355 | 356 | self.weight_layer3 = nn.Sequential( 357 | nn.Conv2d(in_channels=hidden_channels, out_channels=up_channels, kernel_size=1, stride=1, padding=0, bias=False), 358 | nn.BatchNorm2d(num_features=up_channels), 359 | ) 360 | 361 | if downsample: 362 | self.downsample = nn.Sequential( 363 | nn.Conv2d(in_channels=in_channels, out_channels=up_channels, kernel_size=1, stride=2, padding=0, bias=False), 364 | nn.BatchNorm2d(num_features=up_channels), 365 | ) 366 | elif (in_channels != up_channels): 367 | self.downsample = None 368 | self.up_scale = nn.Sequential( 369 | nn.Conv2d(in_channels=in_channels, out_channels=up_channels, kernel_size=1, stride=1, padding=0, bias=False), 370 | nn.BatchNorm2d(num_features=up_channels), 371 | ) 372 | else: 373 | self.downsample = None 374 | self.up_scale = None 375 | 376 | self.relu = nn.Sequential( 377 | nn.ReLU(inplace=True) 378 | ) 379 | 380 | def forward(self, x): 381 | 382 | identity = x 383 | 384 | x = self.weight_layer1(x) 385 | x = self.weight_layer2(x) 386 | x = self.weight_layer3(x) 387 | 388 | if self.downsample is not None: 389 | identity = self.downsample(identity) 390 | elif self.up_scale is not None: 391 | identity = 
self.up_scale(identity) 392 | 393 | x = x + identity 394 | 395 | x = self.relu(x) 396 | 397 | return x 398 | 399 | class DecoderResidualLayer(nn.Module): 400 | 401 | def __init__(self, hidden_channels, output_channels, upsample): 402 | super(DecoderResidualLayer, self).__init__() 403 | 404 | self.weight_layer1 = nn.Sequential( 405 | nn.BatchNorm2d(num_features=hidden_channels), 406 | nn.ReLU(inplace=True), 407 | nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), 408 | ) 409 | 410 | if upsample: 411 | self.weight_layer2 = nn.Sequential( 412 | nn.BatchNorm2d(num_features=hidden_channels), 413 | nn.ReLU(inplace=True), 414 | nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False) 415 | ) 416 | else: 417 | self.weight_layer2 = nn.Sequential( 418 | nn.BatchNorm2d(num_features=hidden_channels), 419 | nn.ReLU(inplace=True), 420 | nn.Conv2d(in_channels=hidden_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1, bias=False), 421 | ) 422 | 423 | if upsample: 424 | self.upsample = nn.Sequential( 425 | nn.BatchNorm2d(num_features=hidden_channels), 426 | nn.ReLU(inplace=True), 427 | nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=output_channels, kernel_size=1, stride=2, output_padding=1, bias=False) 428 | ) 429 | else: 430 | self.upsample = None 431 | 432 | def forward(self, x): 433 | 434 | identity = x 435 | 436 | x = self.weight_layer1(x) 437 | x = self.weight_layer2(x) 438 | 439 | if self.upsample is not None: 440 | identity = self.upsample(identity) 441 | 442 | x = x + identity 443 | 444 | return x 445 | 446 | class DecoderBottleneckLayer(nn.Module): 447 | 448 | def __init__(self, in_channels, hidden_channels, down_channels, upsample): 449 | super(DecoderBottleneckLayer, self).__init__() 450 | 451 | self.weight_layer1 = nn.Sequential( 452 | nn.BatchNorm2d(num_features=in_channels), 453 | nn.ReLU(inplace=True), 454 | nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0, bias=False), 455 | ) 456 | 457 | self.weight_layer2 = nn.Sequential( 458 | nn.BatchNorm2d(num_features=hidden_channels), 459 | nn.ReLU(inplace=True), 460 | nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), 461 | ) 462 | 463 | if upsample: 464 | self.weight_layer3 = nn.Sequential( 465 | nn.BatchNorm2d(num_features=hidden_channels), 466 | nn.ReLU(inplace=True), 467 | nn.ConvTranspose2d(in_channels=hidden_channels, out_channels=down_channels, kernel_size=1, stride=2, output_padding=1, bias=False) 468 | ) 469 | else: 470 | self.weight_layer3 = nn.Sequential( 471 | nn.BatchNorm2d(num_features=hidden_channels), 472 | nn.ReLU(inplace=True), 473 | nn.Conv2d(in_channels=hidden_channels, out_channels=down_channels, kernel_size=1, stride=1, padding=0, bias=False) 474 | ) 475 | 476 | if upsample: 477 | self.upsample = nn.Sequential( 478 | nn.BatchNorm2d(num_features=in_channels), 479 | nn.ReLU(inplace=True), 480 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=down_channels, kernel_size=1, stride=2, output_padding=1, bias=False) 481 | ) 482 | elif (in_channels != down_channels): 483 | self.upsample = None 484 | self.down_scale = nn.Sequential( 485 | nn.BatchNorm2d(num_features=in_channels), 486 | nn.ReLU(inplace=True), 487 | nn.Conv2d(in_channels=in_channels, out_channels=down_channels, kernel_size=1, stride=1, padding=0, bias=False) 
488 |             )
489 |         else:
490 |             self.upsample = None
491 |             self.down_scale = None
492 | 
493 |     def forward(self, x):
494 | 
495 |         identity = x
496 | 
497 |         x = self.weight_layer1(x)
498 |         x = self.weight_layer2(x)
499 |         x = self.weight_layer3(x)
500 | 
501 |         if self.upsample is not None:
502 |             identity = self.upsample(identity)
503 |         elif self.down_scale is not None:
504 |             identity = self.down_scale(identity)
505 | 
506 |         x = x + identity
507 | 
508 |         return x
509 | 
510 | if __name__ == "__main__":
511 | 
512 |     configs, bottleneck = get_configs("resnet152")
513 | 
514 |     encoder = ResNetEncoder(configs, bottleneck)
515 | 
516 |     input = torch.randn((5,3,224,224))
517 | 
518 |     print(input.shape)
519 | 
520 |     output = encoder(input)
521 | 
522 |     print(output.shape)
523 | 
524 |     decoder = ResNetDecoder(configs[::-1], bottleneck)
525 | 
526 |     output = decoder(output)
527 | 
528 |     print(output.shape)
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def get_configs(arch='vgg16'):
5 | 
6 |     if arch == 'vgg11':
7 |         configs = [1, 1, 2, 2, 2]
8 |     elif arch == 'vgg13':
9 |         configs = [2, 2, 2, 2, 2]
10 |     elif arch == 'vgg16':
11 |         configs = [2, 2, 3, 3, 3]
12 |     elif arch == 'vgg19':
13 |         configs = [2, 2, 4, 4, 4]
14 |     else:
15 |         raise ValueError("Undefined model")
16 | 
17 |     return configs
18 | 
19 | class VGGAutoEncoder(nn.Module):
20 | 
21 |     def __init__(self, configs):
22 | 
23 |         super(VGGAutoEncoder, self).__init__()
24 | 
25 |         # a VGG autoencoder without BN is hard to train, so BN is enabled here
26 |         self.encoder = VGGEncoder(configs=configs, enable_bn=True)
27 |         self.decoder = VGGDecoder(configs=configs[::-1], enable_bn=True)
28 | 
29 | 
30 |     def forward(self, x):
31 | 
32 |         x = self.encoder(x)
33 |         x = self.decoder(x)
34 | 
35 |         return x
36 | 
37 | class VGG(nn.Module):
38 | 
39 |     def __init__(self, configs, num_classes=1000, img_size=224, enable_bn=False):
40 |         super(VGG, self).__init__()
41 | 
42 |         self.encoder = VGGEncoder(configs=configs, enable_bn=enable_bn)
43 | 
44 |         self.img_size = img_size / 32
45 | 
46 |         self.fc = nn.Sequential(
47 |             nn.Linear(in_features=int(self.img_size*self.img_size*512), out_features=4096),
48 |             nn.Dropout(p=0.5),
49 |             nn.ReLU(inplace=True),
50 |             nn.Linear(in_features=4096, out_features=4096),
51 |             nn.Dropout(p=0.5),
52 |             nn.ReLU(inplace=True),
53 |             nn.Linear(in_features=4096, out_features=num_classes)
54 |         )
55 | 
56 |         for m in self.modules():
57 |             if isinstance(m, nn.Conv2d):
58 |                 nn.init.normal_(m.weight, mean=0, std=0.01)
59 |                 if m.bias is not None:
60 |                     nn.init.constant_(m.bias, 0)
61 |             if isinstance(m, nn.Linear):
62 |                 nn.init.normal_(m.weight, mean=0, std=0.01)
63 |                 nn.init.constant_(m.bias, 0)
64 | 
65 |     def forward(self, x):
66 | 
67 |         x = self.encoder(x)
68 | 
69 |         x = torch.flatten(x, 1)
70 | 
71 |         x = self.fc(x)
72 | 
73 |         return x
74 | 
75 | class VGGEncoder(nn.Module):
76 | 
77 |     def __init__(self, configs, enable_bn=False):
78 | 
79 |         super(VGGEncoder, self).__init__()
80 | 
81 |         if len(configs) != 5:
82 | 
83 |             raise ValueError("There should be 5 stages in VGG")
84 | 
85 |         self.conv1 = EncoderBlock(input_dim=3, output_dim=64, hidden_dim=64, layers=configs[0], enable_bn=enable_bn)
86 |         self.conv2 = EncoderBlock(input_dim=64, output_dim=128, hidden_dim=128, layers=configs[1], enable_bn=enable_bn)
87 |         self.conv3 = EncoderBlock(input_dim=128, output_dim=256, hidden_dim=256, layers=configs[2], enable_bn=enable_bn)
88 |         self.conv4 = 
EncoderBlock(input_dim=256, output_dim=512, hidden_dim=512, layers=configs[3], enable_bn=enable_bn)
89 |         self.conv5 = EncoderBlock(input_dim=512, output_dim=512, hidden_dim=512, layers=configs[4], enable_bn=enable_bn)
90 | 
91 |     def forward(self, x):
92 | 
93 |         x = self.conv1(x)
94 |         x = self.conv2(x)
95 |         x = self.conv3(x)
96 |         x = self.conv4(x)
97 |         x = self.conv5(x)
98 | 
99 |         return x
100 | 
101 | class VGGDecoder(nn.Module):
102 | 
103 |     def __init__(self, configs, enable_bn=False):
104 | 
105 |         super(VGGDecoder, self).__init__()
106 | 
107 |         if len(configs) != 5:
108 | 
109 |             raise ValueError("There should be 5 stages in VGG")
110 | 
111 |         self.conv1 = DecoderBlock(input_dim=512, output_dim=512, hidden_dim=512, layers=configs[0], enable_bn=enable_bn)
112 |         self.conv2 = DecoderBlock(input_dim=512, output_dim=256, hidden_dim=512, layers=configs[1], enable_bn=enable_bn)
113 |         self.conv3 = DecoderBlock(input_dim=256, output_dim=128, hidden_dim=256, layers=configs[2], enable_bn=enable_bn)
114 |         self.conv4 = DecoderBlock(input_dim=128, output_dim=64, hidden_dim=128, layers=configs[3], enable_bn=enable_bn)
115 |         self.conv5 = DecoderBlock(input_dim=64, output_dim=3, hidden_dim=64, layers=configs[4], enable_bn=enable_bn)
116 |         self.gate = nn.Sigmoid()
117 | 
118 |     def forward(self, x):
119 | 
120 |         x = self.conv1(x)
121 |         x = self.conv2(x)
122 |         x = self.conv3(x)
123 |         x = self.conv4(x)
124 |         x = self.conv5(x)
125 |         x = self.gate(x)
126 | 
127 |         return x
128 | 
129 | class EncoderBlock(nn.Module):
130 | 
131 |     def __init__(self, input_dim, hidden_dim, output_dim, layers, enable_bn=False):
132 | 
133 |         super(EncoderBlock, self).__init__()
134 | 
135 |         if layers == 1:
136 | 
137 |             layer = EncoderLayer(input_dim=input_dim, output_dim=output_dim, enable_bn=enable_bn)
138 | 
139 |             self.add_module('0 EncoderLayer', layer)
140 | 
141 |         else:
142 | 
143 |             for i in range(layers):
144 | 
145 |                 if i == 0:
146 |                     layer = EncoderLayer(input_dim=input_dim, output_dim=hidden_dim, enable_bn=enable_bn)
147 |                 elif i == (layers - 1):
148 |                     layer = EncoderLayer(input_dim=hidden_dim, output_dim=output_dim, enable_bn=enable_bn)
149 |                 else:
150 |                     layer = EncoderLayer(input_dim=hidden_dim, output_dim=hidden_dim, enable_bn=enable_bn)
151 | 
152 |                 self.add_module('%d EncoderLayer' % i, layer)
153 | 
154 |         maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
155 | 
156 |         self.add_module('%d MaxPooling' % layers, maxpool)
157 | 
158 |     def forward(self, x):
159 | 
160 |         for name, layer in self.named_children():
161 | 
162 |             x = layer(x)
163 | 
164 |         return x
165 | 
166 | class DecoderBlock(nn.Module):
167 | 
168 |     def __init__(self, input_dim, hidden_dim, output_dim, layers, enable_bn=False):
169 | 
170 |         super(DecoderBlock, self).__init__()
171 | 
172 |         upsample = nn.ConvTranspose2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=2, stride=2)
173 | 
174 |         self.add_module('0 UpSampling', upsample)
175 | 
176 |         if layers == 1:
177 | 
178 |             layer = DecoderLayer(input_dim=input_dim, output_dim=output_dim, enable_bn=enable_bn)
179 | 
180 |             self.add_module('1 DecoderLayer', layer)
181 | 
182 |         else:
183 | 
184 |             for i in range(layers):
185 | 
186 |                 if i == 0:
187 |                     layer = DecoderLayer(input_dim=input_dim, output_dim=hidden_dim, enable_bn=enable_bn)
188 |                 elif i == (layers - 1):
189 |                     layer = DecoderLayer(input_dim=hidden_dim, output_dim=output_dim, enable_bn=enable_bn)
190 |                 else:
191 |                     layer = DecoderLayer(input_dim=hidden_dim, output_dim=hidden_dim, enable_bn=enable_bn)
192 | 
193 |                 self.add_module('%d DecoderLayer' % (i+1), layer)
194 | 
195 |     def 
forward(self, x): 196 | 197 | for name, layer in self.named_children(): 198 | 199 | x = layer(x) 200 | 201 | return x 202 | 203 | class EncoderLayer(nn.Module): 204 | 205 | def __init__(self, input_dim, output_dim, enable_bn): 206 | super(EncoderLayer, self).__init__() 207 | 208 | if enable_bn: 209 | self.layer = nn.Sequential( 210 | nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1), 211 | nn.BatchNorm2d(output_dim), 212 | nn.ReLU(inplace=True), 213 | ) 214 | else: 215 | self.layer = nn.Sequential( 216 | nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1), 217 | nn.ReLU(inplace=True), 218 | ) 219 | 220 | def forward(self, x): 221 | 222 | return self.layer(x) 223 | 224 | class DecoderLayer(nn.Module): 225 | 226 | def __init__(self, input_dim, output_dim, enable_bn): 227 | super(DecoderLayer, self).__init__() 228 | 229 | if enable_bn: 230 | self.layer = nn.Sequential( 231 | nn.BatchNorm2d(input_dim), 232 | nn.ReLU(inplace=True), 233 | nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1), 234 | ) 235 | else: 236 | self.layer = nn.Sequential( 237 | nn.ReLU(inplace=True), 238 | nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1), 239 | ) 240 | 241 | def forward(self, x): 242 | 243 | return self.layer(x) 244 | 245 | if __name__ == "__main__": 246 | 247 | input = torch.randn((5,3,224,224)) 248 | 249 | configs = get_configs() 250 | 251 | model = VGGAutoEncoder(configs) 252 | 253 | output = model(input) 254 | 255 | print(output.shape) 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | Pillow 3 | loguru 4 | numpy 5 | argparse 6 | matplotlib 7 | torchvision 8 | -------------------------------------------------------------------------------- /run/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3 3 | 4 | # settings 5 | MODEL_ARC=$1 6 | CKPT=$2 7 | DATASET=$3 8 | 9 | # CUDA_LAUNCH_BLOCKING=1 10 | python3 -u eval.py \ 11 | --arch $MODEL_ARC \ 12 | --val_list list/${DATASET}_list.txt \ 13 | --workers 16 \ 14 | --batch-size 128 \ 15 | --print-freq 10 \ 16 | --resume ${CKPT} -------------------------------------------------------------------------------- /run/evalall.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 3 | 4 | # settings 5 | MODEL_ARC=$1 6 | FOLDER=$2 7 | DATASET=$3 8 | 9 | # CUDA_LAUNCH_BLOCKING=1 10 | python3 -u eval.py \ 11 | --arch $MODEL_ARC \ 12 | --val_list list/${DATASET}_list.txt \ 13 | --workers 16 \ 14 | --batch-size 256 \ 15 | --print-freq 10 \ 16 | --folder ${FOLDER} \ 17 | --start_epoch 0 \ 18 | --epochs 100 19 | -------------------------------------------------------------------------------- /run/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 3 | 4 | # settings 5 | MODEL_ARC=$1 6 | DATASET=$2 7 | OUTPUT=results/${DATASET}-${MODEL_ARC}/ 8 | mkdir -p ${OUTPUT} 9 | 10 | # CUDA_LAUNCH_BLOCKING=1 11 | python3 -u train.py \ 12 | --arch $MODEL_ARC \ 13 | --train_list list/${DATASET}_list.txt \ 14 | --workers 16 \ 15 | --epochs 100 \ 
16 |     --start-epoch 0 \
17 |     --batch-size 256 \
18 |     --learning-rate 0.05 \
19 |     --momentum 0.9 \
20 |     --weight-decay 1e-4 \
21 |     --print-freq 10 \
22 |     --pth-save-fold ${OUTPUT} \
23 |     --pth-save-epoch 1 \
24 |     --parallel 1 \
25 |     --dist-url 'tcp://localhost:10001' 2>&1 | tee ${OUTPUT}/output.log
--------------------------------------------------------------------------------
/tools/decode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import argparse
4 | 
5 | import matplotlib.pyplot as plt
6 | 
7 | import torch
8 | 
9 | from torchvision.transforms import transforms
10 | 
11 | import sys
12 | sys.path.append("./")
13 | 
14 | import utils
15 | import models.builer as builder
16 | 
17 | import os
18 | 
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20 | 
21 | def get_args():
22 |     # parse the args
23 |     print('=> parse the args ...')
24 |     parser = argparse.ArgumentParser(description='Decoder for auto encoder')
25 |     parser.add_argument('--arch', default='vgg16', type=str,
26 |                         help='backbone architecture')
27 |     parser.add_argument('--resume', type=str)
28 | 
29 |     args = parser.parse_args()
30 | 
31 |     args.parallel = 0
32 |     args.batch_size = 1
33 |     args.workers = 0
34 | 
35 |     return args
36 | 
37 | def random_sample(arch):
38 | 
39 |     if arch in ["vgg11", "vgg13", "vgg16", "vgg19", "resnet18", "resnet34"]:
40 |         return torch.randn((1,512,7,7))
41 |     elif arch in ["resnet50", "resnet101", "resnet152"]:
42 |         return torch.randn((1,2048,7,7))
43 |     else:
44 |         raise NotImplementedError("No implementation for architectures other than VGG and ResNet")
45 | 
46 | def main(args):
47 |     print('=> torch version : {}'.format(torch.__version__))
48 | 
49 |     utils.init_seeds(1, cuda_deterministic=False)
50 | 
51 |     print('=> modeling the network ...')
52 |     model = builder.BuildAutoEncoder(args)
53 |     total_params = sum(p.numel() for p in model.parameters())
54 |     print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))
55 | 
56 |     print('=> loading pth from {} ...'.format(args.resume))
57 |     utils.load_dict(args.resume, model)
58 | 
59 |     trans = transforms.ToPILImage()
60 | 
61 |     plt.figure(figsize=(16, 9))
62 | 
63 |     model.eval()
64 |     print('=> Generating ...')
65 |     with torch.no_grad():
66 |         for i in range(128):
67 | 
68 |             input = random_sample(arch=args.arch).cuda()
69 | 
70 |             output = model.module.decoder(input)
71 | 
72 |             output = trans(output.squeeze().cpu())
73 | 
74 |             plt.subplot(8,16,i+1, xticks=[], yticks=[])
75 |             plt.imshow(output)
76 | 
77 |     plt.savefig('figs/generation.jpg')
78 | 
79 | if __name__ == '__main__':
80 | 
81 |     args = get_args()
82 | 
83 |     main(args)
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/tools/encode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import argparse
4 | 
5 | from PIL import Image
6 | 
7 | import torch
8 | from torchvision.transforms import transforms
9 | 
10 | import sys
11 | sys.path.append("./")
12 | 
13 | import utils
14 | import models.builer as builder
15 | 
16 | import os
17 | 
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 | 
20 | def get_args():
21 |     # parse the args
22 |     print('=> parse the args ...')
23 |     parser = argparse.ArgumentParser(description='Encoder for auto encoder')
24 |     parser.add_argument('--arch', default='vgg16', type=str,
25 |                         help='backbone architecture')
26 |     parser.add_argument('--resume', type=str)
27 |     parser.add_argument('--img_path',type=str)
28 | 
29 |     args = 
parser.parse_args()
30 | 
31 |     args.parallel = 0
32 |     args.batch_size = 1
33 |     args.workers = 0
34 | 
35 |     return args
36 | 
37 | def encode(model, img):
38 | 
39 |     with torch.no_grad():
40 | 
41 |         code = model.module.encoder(img).cpu().numpy()
42 | 
43 |     return code
44 | 
45 | def main(args):
46 |     print('=> torch version : {}'.format(torch.__version__))
47 | 
48 |     utils.init_seeds(1, cuda_deterministic=False)
49 | 
50 |     print('=> modeling the network ...')
51 |     model = builder.BuildAutoEncoder(args)
52 |     total_params = sum(p.numel() for p in model.parameters())
53 |     print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))
54 | 
55 |     print('=> loading pth from {} ...'.format(args.resume))
56 |     utils.load_dict(args.resume, model)
57 | 
58 |     trans = transforms.Compose([
59 |         transforms.Resize(256),
60 |         transforms.CenterCrop(224),
61 |         transforms.ToTensor()
62 |     ])
63 | 
64 |     img = Image.open(args.img_path).convert("RGB")
65 | 
66 |     img = trans(img).unsqueeze(0).cuda()
67 | 
68 |     model.eval()
69 | 
70 |     code = encode(model, img)
71 | 
72 |     print(code.shape)
73 | 
74 |     # To do : any other postprocessing
75 | 
76 | if __name__ == '__main__':
77 | 
78 |     args = get_args()
79 | 
80 |     main(args)
81 | 
82 | 
83 | 
--------------------------------------------------------------------------------
/tools/generate_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | 
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--name', type=str)
6 | parser.add_argument('--path', type=str)
7 | args = parser.parse_args()
8 | 
9 | folders = os.listdir(args.path)
10 | 
11 | if not os.path.exists("list"):
12 |     os.makedirs("list")
13 | 
14 | fl = open('list/' + args.name + '_list.txt', 'w')
15 | fn = open('list/' + args.name + '_name.txt', 'w')
16 | 
17 | for i, folder in enumerate(folders):
18 | 
19 |     fn.write(str(i) + ' ' + folder + '\n')
20 | 
21 |     folder_path = os.path.join(args.path, folder)
22 |     files = os.listdir(folder_path)
23 | 
24 |     for file in files:
25 | 
26 |         fl.write('{} {}\n'.format(os.path.join(folder_path, file), i))
27 | 
28 | fl.close()
29 | fn.close()
--------------------------------------------------------------------------------
/tools/reconstruct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import argparse
4 | 
5 | import matplotlib.pyplot as plt
6 | 
7 | import torch
8 | 
9 | from torchvision.transforms import transforms
10 | 
11 | import sys
12 | sys.path.append("./")
13 | 
14 | import utils
15 | import models.builer as builder
16 | import dataloader
17 | 
18 | import os
19 | 
20 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
21 | 
22 | def get_args():
23 |     # parse the args
24 |     print('=> parse the args ...')
25 |     parser = argparse.ArgumentParser(description='Reconstructor for auto encoder')
26 |     parser.add_argument('--arch', default='vgg16', type=str,
27 |                         help='backbone architecture')
28 |     parser.add_argument('--resume', type=str)
29 |     parser.add_argument('--val_list', type=str)
30 | 
31 |     args = parser.parse_args()
32 | 
33 |     args.parallel = 0
34 |     args.batch_size = 1
35 |     args.workers = 0
36 | 
37 |     return args
38 | 
39 | def main(args):
40 |     print('=> torch version : {}'.format(torch.__version__))
41 | 
42 |     utils.init_seeds(1, cuda_deterministic=False)
43 | 
44 |     print('=> modeling the network ...')
45 |     model = builder.BuildAutoEncoder(args)
46 |     total_params = sum(p.numel() for p in model.parameters())
47 |     print('=> num of params: {} 
({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))
48 | 
49 |     print('=> loading pth from {} ...'.format(args.resume))
50 |     utils.load_dict(args.resume, model)
51 | 
52 |     print('=> building the dataloader ...')
53 |     val_loader = dataloader.val_loader(args)
54 | 
55 |     plt.figure(figsize=(16, 9))
56 | 
57 |     model.eval()
58 |     print('=> reconstructing ...')
59 |     with torch.no_grad():
60 |         for i, (input, target) in enumerate(val_loader):
61 | 
62 |             input = input.cuda(non_blocking=True)
63 |             target = target.cuda(non_blocking=True)
64 | 
65 |             output = model(input)
66 | 
67 |             input = transforms.ToPILImage()(input.squeeze().cpu())
68 |             output = transforms.ToPILImage()(output.squeeze().cpu())
69 | 
70 |             plt.subplot(8,16,2*i+1, xticks=[], yticks=[])
71 |             plt.imshow(input)
72 | 
73 |             plt.subplot(8,16,2*i+2, xticks=[], yticks=[])
74 |             plt.imshow(output)
75 | 
76 |             if i == 63:
77 |                 break
78 | 
79 |     plt.savefig('figs/reconstruction.jpg')
80 | 
81 | if __name__ == '__main__':
82 | 
83 |     args = get_args()
84 | 
85 |     main(args)
86 | 
87 | 
88 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import os
4 | import time
5 | import argparse
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.multiprocessing as mp
10 | 
11 | import utils
12 | import models.builer as builder
13 | import dataloader
14 | 
15 | def get_args():
16 |     # parse the args
17 |     print('=> parse the args ...')
18 |     parser = argparse.ArgumentParser(description='Trainer for auto encoder')
19 |     parser.add_argument('--arch', default='vgg16', type=str,
20 |                         help='backbone architecture')
21 |     parser.add_argument('--train_list', type=str)
22 |     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
23 |                         help='number of data loading workers (default: 16)')
24 |     parser.add_argument('--epochs', default=20, type=int, metavar='N',
25 |                         help='number of total epochs to run')
26 |     parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
27 |                         help='manual epoch number (useful on restarts)')
28 |     parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
29 |                         help='mini-batch size (default: 256), this is the total '
30 |                              'batch size of all GPUs on the current node when '
31 |                              'using Data Parallel or Distributed Data Parallel')
32 | 
33 |     parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
34 |                         metavar='LR', help='initial learning rate', dest='lr')
35 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
36 |                         help='momentum')
37 |     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
38 |                         metavar='W', help='weight decay (default: 1e-4)',
39 |                         dest='weight_decay')
40 | 
41 |     parser.add_argument('-p', '--print-freq', default=100, type=int,
42 |                         metavar='N', help='print frequency (default: 100)')
43 | 
44 |     parser.add_argument('--pth-save-fold', default='results/tmp', type=str,
45 |                         help='The folder to save pths')
46 |     parser.add_argument('--pth-save-epoch', default=1, type=int,
47 |                         help='The epoch to save pth')
48 |     parser.add_argument('--parallel', type=int, default=1,
49 |                         help='1 for parallel, 0 for non-parallel')
50 |     parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,
51 |                         help='url used to set up distributed training')
52 | 
53 |     args = parser.parse_args()
54 | 
55 |     return args
56 | 
57 | def main(args):
58 |     print('=> torch version : 
{}'.format(torch.__version__))
59 |     ngpus_per_node = torch.cuda.device_count()
60 |     print('=> ngpus : {}'.format(ngpus_per_node))
61 | 
62 |     if args.parallel == 1: 
63 |         # single machine multi card       
64 |         args.gpus = ngpus_per_node
65 |         args.nodes = 1
66 |         args.nr = 0
67 |         args.world_size = args.gpus * args.nodes
68 | 
69 |         args.workers = int(args.workers / args.world_size)
70 |         args.batch_size = int(args.batch_size / args.world_size)
71 |         mp.spawn(main_worker, nprocs=args.gpus, args=(args,))
72 |     else:
73 |         args.world_size = 1
74 |         main_worker(ngpus_per_node, args)
75 | 
76 | def main_worker(gpu, args):
77 |     utils.init_seeds(1 + gpu, cuda_deterministic=False)
78 |     if args.parallel == 1:
79 |         args.gpu = gpu
80 |         args.rank = args.nr * args.gpus + args.gpu
81 | 
82 |         torch.cuda.set_device(gpu)
83 |         torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
84 | 
85 |     else:
86 |         # two dummy variables, not real values
87 |         args.rank = 0
88 |         args.gpus = 1
89 |     if args.rank == 0:
90 |         print('=> modeling the network {} ...'.format(args.arch))
91 |     model = builder.BuildAutoEncoder(args)
92 |     if args.rank == 0:
93 |         total_params = sum(p.numel() for p in model.parameters())
94 |         print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))
95 | 
96 |     if args.rank == 0:
97 |         print('=> building the optimizer ...')
98 |     # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), args.lr,)
99 |     optimizer = torch.optim.SGD(
100 |         filter(lambda p: p.requires_grad, model.parameters()),
101 |         args.lr,
102 |         momentum=args.momentum,
103 |         weight_decay=args.weight_decay)
104 |     if args.rank == 0:
105 |         print('=> building the dataloader ...')
106 |     train_loader = dataloader.train_loader(args)
107 | 
108 |     if args.rank == 0:
109 |         print('=> building the criterion ...')
110 |     criterion = nn.MSELoss()
111 | 
112 |     global iters
113 |     iters = 0
114 | 
115 |     model.train()
116 |     if args.rank == 0:
117 |         print('=> starting training engine ...')
118 |     for epoch in range(args.start_epoch, args.epochs):
119 | 
120 |         global current_lr
121 |         current_lr = utils.adjust_learning_rate_cosine(optimizer, epoch, args)
122 | 
123 |         if args.parallel == 1: train_loader.sampler.set_epoch(epoch)  # only the distributed sampler has set_epoch(); the default sampler in non-parallel mode does not
124 | 
125 |         # train for one epoch
126 |         do_train(train_loader, model, criterion, optimizer, epoch, args)
127 | 
128 |         # save pth
129 |         if epoch % args.pth_save_epoch == 0 and args.rank == 0:
130 |             state_dict = model.state_dict()
131 | 
132 |             torch.save(
133 |                 {
134 |                     'epoch': epoch + 1,
135 |                     'arch': args.arch,
136 |                     'state_dict': state_dict,
137 |                     'optimizer' : optimizer.state_dict(),
138 |                 },
139 |                 os.path.join(args.pth_save_fold, '{}.pth'.format(str(epoch).zfill(3)))
140 |             )
141 | 
142 |             print(' : save pth for epoch {}'.format(epoch + 1))
143 | 
144 | 
145 | def do_train(train_loader, model, criterion, optimizer, epoch, args):
146 |     batch_time = utils.AverageMeter('Time', ':6.2f')
147 |     data_time = utils.AverageMeter('Data', ':2.2f')
148 |     losses = utils.AverageMeter('Loss', ':.4f')
149 |     learning_rate = utils.AverageMeter('LR', ':.4f')
150 | 
151 |     progress = utils.ProgressMeter(
152 |         len(train_loader),
153 |         [batch_time, data_time, losses, learning_rate],
154 |         prefix="Epoch: [{}]".format(epoch+1))
155 |     end = time.time()
156 | 
157 |     # update lr
158 |     learning_rate.update(current_lr)
159 | 
160 |     for i, (input, target) in enumerate(train_loader):
161 |         # measure data loading time
162 |         data_time.update(time.time() - end)
163 |         global iters
164 |         iters += 1
165 | 
166 | 
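        # note: ImageDataset returns (img, img), so the target here is just the
        # input image itself and the MSE criterion computes a pure reconstruction loss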
input = input.cuda(non_blocking=True) 167 | target = target.cuda(non_blocking=True) 168 | 169 | output = model(input) 170 | 171 | loss = criterion(output, target) 172 | 173 | # compute gradient and do solver step 174 | optimizer.zero_grad() 175 | # backward 176 | loss.backward() 177 | # update weights 178 | optimizer.step() 179 | 180 | # syn for logging 181 | torch.cuda.synchronize() 182 | 183 | # record loss 184 | losses.update(loss.item(), input.size(0)) 185 | 186 | # measure elapsed time 187 | if args.rank == 0: 188 | batch_time.update(time.time() - end) 189 | end = time.time() 190 | 191 | if i % args.print_freq == 0 and args.rank == 0: 192 | progress.display(i) 193 | 194 | if __name__ == '__main__': 195 | 196 | args = get_args() 197 | 198 | main(args) 199 | 200 | 201 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import random 5 | import numpy as np 6 | 7 | from loguru import logger 8 | 9 | import torch 10 | from torch.backends import cudnn 11 | 12 | def init_seeds(seed=0, cuda_deterministic=True): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | os.environ['PYTHONHASHSEED'] = str(seed) 18 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 19 | if cuda_deterministic: # slower, more reproducible 20 | cudnn.deterministic = True 21 | cudnn.benchmark = False 22 | else: # faster, less reproducible 23 | cudnn.deterministic = False 24 | cudnn.benchmark = True 25 | 26 | class AverageMeter(object): 27 | """Computes and stores the average and current value""" 28 | 29 | def __init__(self, name, fmt=':f'): 30 | self.name = name 31 | self.fmt = fmt 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | self.avg = self.sum / self.count 45 | 46 | def __str__(self): 47 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 48 | return fmtstr.format(**self.__dict__) 49 | 50 | 51 | class ProgressMeter(object): 52 | def __init__(self, num_batches, meters, prefix=""): 53 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 54 | self.meters = meters 55 | self.prefix = prefix 56 | 57 | def display(self, batch): 58 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 59 | entries += [str(meter) for meter in self.meters] 60 | logger.info("\t".join(entries)) 61 | 62 | def _get_batch_fmtstr(self, num_batches): 63 | num_digits = len(str(num_batches // 1)) 64 | fmt = '{:' + str(num_digits) + 'd}' 65 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 66 | 67 | def adjust_learning_rate(optimizer, epoch, args): 68 | decay = args.lr_drop_ratio if epoch in args.lr_drop_epoch else 1.0 69 | lr = args.lr * decay 70 | global current_lr 71 | current_lr = lr 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = lr 74 | args.lr = current_lr 75 | return current_lr 76 | 77 | def adjust_learning_rate_cosine(optimizer, epoch, args): 78 | """cosine learning rate annealing without restart""" 79 | lr = args.lr * 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) 80 | global current_lr 81 | current_lr = lr 82 | for param_group in optimizer.param_groups: 83 | param_group['lr'] = lr 84 | return current_lr 85 | 86 | def load_dict(resume_path, model): 87 | if os.path.isfile(resume_path): 88 | checkpoint = torch.load(resume_path) 89 | model_dict = model.state_dict() 90 | model_dict.update(checkpoint['state_dict']) 91 | model.load_state_dict(model_dict) 92 | # delete to release more space 93 | del checkpoint 94 | else: 95 | sys.exit("=> No checkpoint found at '{}'".format(resume_path)) 96 | return model --------------------------------------------------------------------------------