├── data
│   ├── __init__.py
│   ├── data_loader.py
│   ├── base_data_loader.py
│   ├── single_dataset.py
│   ├── aligned_random_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── image_folder.py
│   ├── unaligned_dataset.py
│   ├── coco_dataset.py
│   ├── base_dataset.py
│   ├── aligned_dataset.py
│   └── cocoseg_dataset.py
├── models
│   ├── __init__.py
│   ├── models.py
│   ├── test_model.py
│   ├── base_model.py
│   ├── pix2pix_model.py
│   ├── cycle_gan_model.py
│   ├── cycle_attn_gan_model.py
│   └── networks.py
├── util
│   ├── png.py
│   ├── html.py
│   ├── image_pool.py
│   ├── inception_classify.py
│   ├── util.py
│   ├── get_data.py
│   └── visualizer.py
├── options
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── .gitignore
├── scripts
│   ├── train_attngan.sh
│   └── download_cyclegan_model.sh
├── README.md
├── datasets
│   └── download_cyclegan_dataset.sh
├── test.py
└── train.py

/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
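At a glance, the entry points below compose these pieces in only a few calls. This is a condensed, illustrative sketch of the training flow — it mirrors train.py further down; option parsing details and visualization are omitted:

```python
# Condensed sketch of how train.py wires the repository together.
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model

opt = TrainOptions().parse()                 # options/ -> parsed argparse flags
dataset = CreateDataLoader(opt).load_data()  # data/   -> a torch DataLoader
model = create_model(opt)                    # models/ -> e.g. CycleAttnGANModel

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for data in dataset:
        model.set_input(data)
        model.optimize_parameters()
    model.update_learning_rate()
```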
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/
2 | checkpoints/
3 | results/
4 | */*.pyc
5 | */**/*.pyc
6 | */**/**/*.pyc
7 | */**/**/**/*.pyc
8 | */**/**/**/**/*.pyc
9 | */*.so*
10 | */**/*.so*
11 | */**/*.dylib*
12 | test/data/legacy_serialized.pt
13 | *~
14 | 
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 | 
2 | def CreateDataLoader(opt):
3 |     from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 |     data_loader = CustomDatasetDataLoader()
5 |     print(data_loader.name())
6 |     data_loader.initialize(opt)
7 |     return data_loader
8 | 
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | 
2 | class BaseDataLoader():
3 |     def __init__(self):
4 |         pass
5 | 
6 |     def initialize(self, opt):
7 |         self.opt = opt
8 |         pass
9 | 
10 |     def load_data(self):
11 |         return None
12 | 
13 | 
14 | 
15 | 
--------------------------------------------------------------------------------
/scripts/train_attngan.sh:
--------------------------------------------------------------------------------
1 | python train.py \
2 |     --dataroot datasets/horse2zebra \
3 |     --name zebra_attngan \
4 |     --model cycle_attn_gan \
5 |     --niter 110 \
6 |     --lr_policy step \
7 |     --pool_size 50 \
8 |     --display_freq 2000 \
9 |     --no_dropout \
10 |     --display_id 0 \
11 |     --batchSize 4 \
12 |     --gpu_ids 0
13 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention-GAN
2 | This repository provides the PyTorch code for our paper “Attention-GAN for Object Transfiguration in Wild Images” ([ECCV 2018](https://eccv2018.org/openaccess/content_ECCV_2018/papers/Xinyuan_Chen_Attention-GAN_for_Object_ECCV_2018_paper.pdf)). This code is based on the PyTorch (0.4.1) implementation of [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). You may need to train several times, as the quality of the results is sensitive to the initialization.
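Because runs can diverge, pinning every RNG before training makes restarts easier to compare. Note that train.py only seeds Python's `random` module; seeding torch and numpy as below is a suggested extension rather than something this repository already does:

```python
# Suggested (not part of this repo): pin all RNG seeds before building the model.
import random
import numpy as np
import torch

def set_seed(seed=10):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
```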
3 | ## Data Preparation
4 | `bash datasets/download_cyclegan_dataset.sh horse2zebra`
5 | ## Train
6 | `bash scripts/train_attngan.sh`
--------------------------------------------------------------------------------
/scripts/download_cyclegan_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | 
3 | echo "Note: available models are apple2orange, orange2apple, summer2winter_yosemite, winter2summer_yosemite, horse2zebra, zebra2horse, monet2photo, style_monet, style_cezanne, style_ukiyoe, style_vangogh, sat2map, map2sat, cityscapes_photo2label, cityscapes_label2photo, facades_photo2label, facades_label2photo, iphone2dslr_flower"
4 | 
5 | echo "Specified [$FILE]"
6 | 
7 | mkdir -p ./checkpoints/${FILE}_pretrained
8 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth
9 | URL=http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/$FILE.pth
10 | 
11 | wget -N $URL -O $MODEL_FILE
12 | 
13 | 
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | def create_model(opt):
2 |     model = None
3 |     print(opt.model)
4 |     if opt.model == 'cycle_gan':
5 |         # assert(opt.dataset_mode == 'unaligned')
6 |         from .cycle_gan_model import CycleGANModel
7 |         model = CycleGANModel()
8 |     elif opt.model == 'cycle_attn_gan':
9 |         from .cycle_attn_gan_model import CycleAttnGANModel
10 |         model = CycleAttnGANModel()
11 |     elif opt.model == 'test':
12 |         assert(opt.dataset_mode == 'single')
13 |         from .test_model import TestModel
14 |         model = TestModel()
15 |     else:
16 |         raise ValueError("Model [%s] not recognized." % opt.model)
17 |     model.initialize(opt)
18 |     print("model [%s] was created" % (model.name()))
19 |     return model
20 | 
--------------------------------------------------------------------------------
/datasets/download_cyclegan_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | 
3 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
4 |     echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
5 |     exit 1
6 | fi
7 | 
8 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
9 | ZIP_FILE=./datasets/$FILE.zip
10 | TARGET_DIR=./datasets/$FILE/
11 | wget -N $URL -O $ZIP_FILE
12 | mkdir -p $TARGET_DIR
13 | unzip $ZIP_FILE -d ./datasets/
14 | rm $ZIP_FILE
15 | 
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | class TestOptions(BaseOptions):
4 |     def initialize(self):
5 |         BaseOptions.initialize(self)
6 |         self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
7 |         self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
8 |         self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
9 |
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 10 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 11 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 12 | self.isTrain = False 13 | -------------------------------------------------------------------------------- /util/png.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import zlib 3 | 4 | def encode(buf, width, height): 5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """ 6 | assert (width * height * 3 == len(buf)) 7 | bpp = 3 8 | 9 | def raw_data(): 10 | # reverse the vertical line order and add null bytes at the start 11 | row_bytes = width * bpp 12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes): 13 | yield b'\x00' 14 | yield buf[row_start:row_start + row_bytes] 15 | 16 | def chunk(tag, data): 17 | return [ 18 | struct.pack("!I", len(data)), 19 | tag, 20 | data, 21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) 22 | ] 23 | 24 | SIGNATURE = b'\x89PNG\r\n\x1a\n' 25 | COLOR_TYPE_RGB = 2 26 | COLOR_TYPE_RGBA = 6 27 | bit_depth = 8 28 | return b''.join( 29 | [ SIGNATURE ] + 30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + 31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + 32 | chunk(b'IEND', b'') 33 | ) 34 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | 7 | 8 | class SingleDataset(BaseDataset): 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.root = opt.dataroot 12 | self.dir_A = os.path.join(opt.dataroot) 13 | 14 | self.A_paths = make_dataset(self.dir_A) 15 | 16 | self.A_paths = sorted(self.A_paths) 17 | 18 | self.transform = get_transform(opt) 19 | 20 | def __getitem__(self, index): 21 | A_path = self.A_paths[index] 22 | A_img = Image.open(A_path).convert('RGB') 23 | A = self.transform(A_img) 24 | if self.opt.which_direction == 'BtoA': 25 | input_nc = self.opt.output_nc 26 | else: 27 | input_nc = self.opt.input_nc 28 | 29 | if input_nc == 1: # RGB to gray 30 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114
31 |             A = tmp.unsqueeze(0)
32 | 
33 |         return {'A': A, 'A_paths': A_path}
34 | 
35 |     def __len__(self):
36 |         return len(self.A_paths)
37 | 
38 |     def name(self):
39 |         return 'SingleImageDataset'
40 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from options.test_options import TestOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | from util.visualizer import Visualizer
7 | from util import html
8 | opt = TestOptions().parse()
9 | opt.nThreads = 1   # test code only supports nThreads = 1
10 | opt.batchSize = 1  # test code only supports batchSize = 1
11 | opt.serial_batches = True  # no shuffle
12 | opt.no_flip = True  # no flip
13 | 
14 | data_loader = CreateDataLoader(opt)
15 | dataset = data_loader.load_data()
16 | model = create_model(opt)
17 | visualizer = Visualizer(opt)
18 | # create website
19 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
20 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
21 | print('name:', opt.name)
22 | print('web dir:', web_dir)
23 | 
24 | for i, data in enumerate(dataset):
25 |     if i >= opt.how_many:
26 |         break
27 |     model.set_input(data)
28 |     model.test()
29 |     visuals = model.get_current_visuals()
30 |     img_path = model.get_image_paths()
31 |     print('%d: processing image... %s' % (i, img_path))
32 |     visualizer.save_images(webpage, visuals, img_path)
33 | webpage.save()
34 | 
--------------------------------------------------------------------------------
/data/aligned_random_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | from data.base_dataset import BaseDataset, get_transform
4 | import torchvision.transforms as transforms
5 | import torch
6 | from data.base_dataset import BaseDataset
7 | from data.image_folder import make_dataset
8 | from PIL import Image
9 | 
10 | 
11 | class AlignedRandomDataset(BaseDataset):
12 |     def initialize(self, opt):
13 |         self.opt = opt
14 |         self.root = opt.dataroot
15 |         self.dir_AB = os.path.join(opt.dataroot, opt.phase)
16 | 
17 |         self.AB_paths = sorted(make_dataset(self.dir_AB))
18 |         n_images = len(self.AB_paths)
19 |         self.A_paths = self.AB_paths[:n_images // 2]  # integer division: slice indices must be ints
20 |         self.B_paths = self.AB_paths[n_images // 2:]
21 |         self.A_size = len(self.A_paths)
22 |         self.B_size = len(self.B_paths)
23 |         self.transform = get_transform(opt)
24 | 
25 |     def __getitem__(self, index):
26 |         A_path = self.A_paths[index % self.A_size]
27 |         index_A = index % self.A_size
28 |         index_B = random.randint(0, self.B_size - 1)
29 |         B_path = self.B_paths[index_B]
30 | 
31 |         A_img = Image.open(A_path).convert('RGB')
32 |         B_img = Image.open(B_path).convert('RGB')
33 |         width = A_img.size[0]
34 |         height = A_img.size[1]
35 |         A = A_img.crop((0, 0, width // 2, height))      # left half of the paired image
36 |         B = B_img.crop((width // 2, 0, width, height))  # right half
37 | 
38 |         A = self.transform(A)
39 |         B = self.transform(B)
40 | 
41 |         return {'A': A, 'B': B,
42 |                 'A_paths': A_path, 'B_paths': B_path}
43 |     def __len__(self):
44 |         return max(self.A_size, self.B_size)
45 | 
46 |     def name(self):
47 |         return 'AlignedRandomDataset'
48 | 
--------------------------------------------------------------------------------
/models/test_model.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import
Variable 2 | from collections import OrderedDict 3 | import util.util as util 4 | from .base_model import BaseModel 5 | from . import networks 6 | 7 | 8 | class TestModel(BaseModel): 9 | def name(self): 10 | return 'TestModel' 11 | 12 | def initialize(self, opt): 13 | assert(not opt.isTrain) 14 | BaseModel.initialize(self, opt) 15 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 16 | 17 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, 18 | opt.ngf, opt.which_model_netG, 19 | opt.norm, not opt.no_dropout, 20 | opt.init_type, 21 | self.gpu_ids) 22 | which_epoch = opt.which_epoch 23 | self.load_network(self.netG, 'G', which_epoch) 24 | 25 | print('---------- Networks initialized -------------') 26 | networks.print_network(self.netG) 27 | print('-----------------------------------------------') 28 | 29 | def set_input(self, input): 30 | # we need to use single_dataset mode 31 | input_A = input['A'] 32 | self.input_A.resize_(input_A.size()).copy_(input_A) 33 | self.image_paths = input['A_paths'] 34 | 35 | def test(self): 36 | self.real_A = Variable(self.input_A) 37 | self.fake_B = self.netG.forward(self.real_A) 38 | 39 | # get image paths 40 | def get_image_paths(self): 41 | return self.image_paths 42 | 43 | def get_current_visuals(self): 44 | real_A = util.tensor2im(self.real_A.data) 45 | fake_B = util.tensor2im(self.fake_B.data) 46 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) 47 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | if opt.dataset_mode == 'aligned': 8 | from data.aligned_dataset import AlignedDataset 9 | dataset = AlignedDataset() 10 | elif opt.dataset_mode == 'unaligned': 11 | from data.unaligned_dataset import UnalignedDataset 12 | dataset = UnalignedDataset() 13 | elif opt.dataset_mode == 'single': 14 | from data.single_dataset import SingleDataset 15 | dataset = SingleDataset() 16 | elif opt.dataset_mode == 'alignedrandom': 17 | from data.aligned_random_dataset import AlignedRandomDataset 18 | dataset = AlignedRandomDataset() 19 | elif opt.dataset_mode == 'Coco': 20 | from data.coco_dataset import UnalignedCocoDataset 21 | dataset = UnalignedCocoDataset() 22 | elif opt.dataset_mode == 'CocoSeg': 23 | from data.cocoseg_dataset import CocoSegDataset 24 | dataset = CocoSegDataset() 25 | else: 26 | raise ValueError("Dataset [%s] not recognized." 
% opt.dataset_mode)
27 | 
28 |     print("dataset [%s] was created" % (dataset.name()))
29 |     dataset.initialize(opt)
30 |     return dataset
31 | 
32 | 
33 | class CustomDatasetDataLoader(BaseDataLoader):
34 |     def name(self):
35 |         return 'CustomDatasetDataLoader'
36 | 
37 |     def initialize(self, opt):
38 |         BaseDataLoader.initialize(self, opt)
39 |         self.dataset = CreateDataset(opt)
40 |         self.dataloader = torch.utils.data.DataLoader(
41 |             self.dataset,
42 |             batch_size=opt.batchSize,
43 |             shuffle=not opt.serial_batches,
44 |             num_workers=int(opt.nThreads))
45 | 
46 |     def load_data(self):
47 |         return self.dataloader
48 | 
49 |     def __len__(self):
50 |         return min(len(self.dataset), self.opt.max_dataset_size)
51 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import CreateDataLoader
4 | from models.models import create_model
5 | from util.visualizer import Visualizer
6 | import random
7 | from pdb import set_trace as st
8 | random.seed(10)
9 | 
10 | opt = TrainOptions().parse()
11 | data_loader = CreateDataLoader(opt)
12 | dataset = data_loader.load_data()
13 | dataset_size = len(data_loader)
14 | print('#training images = %d' % dataset_size)
15 | 
16 | model = create_model(opt)
17 | visualizer = Visualizer(opt)
18 | total_steps = 0
19 | 
20 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
21 |     epoch_start_time = time.time()
22 |     epoch_iter = 0
23 | 
24 |     for i, data in enumerate(dataset):
25 |         iter_start_time = time.time()
26 |         total_steps += opt.batchSize
27 |         epoch_iter += opt.batchSize
28 |         model.set_input(data)
29 |         model.optimize_parameters()
30 |         if total_steps % opt.display_freq == 0:
31 |             visualizer.display_current_results(model.get_current_visuals(), epoch, total_steps, dataset_size)
32 | 
33 |         if total_steps % opt.print_freq == 0:
34 |             errors = model.get_current_errors()
35 |             t = (time.time() - iter_start_time) / opt.batchSize
36 |             visualizer.print_current_errors(epoch, epoch_iter, errors, t)
37 |             if opt.display_id > 0:
38 |                 visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
39 |         if total_steps % opt.save_latest_freq == 0:
40 |             print('saving the latest model (epoch %d, total_steps %d)' %
41 |                   (epoch, total_steps))
42 |             model.save('latest')
43 |     if epoch % opt.save_epoch_freq == 0:
44 |         print('saving the model at the end of epoch %d, iters %d' %
45 |               (epoch, total_steps))
46 |         model.save('latest')
47 |         model.save(epoch)
48 |     print('End of epoch %d / %d \t Time Taken: %d sec' %
49 |           (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
50 |     model.update_learning_rate()
51 | 
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     def __init__(self, web_dir, title, refresh=0):
8 |         self.title = title
9 |         self.web_dir = web_dir
10 |         self.img_dir = os.path.join(self.web_dir, 'images')
11 |         if not os.path.exists(self.web_dir):
12 |             os.makedirs(self.web_dir)
13 |         if not os.path.exists(self.img_dir):
14 |             os.makedirs(self.img_dir)
15 |         # print(self.img_dir)
16 | 
17 |         self.doc = dominate.document(title=title)
18 |         if refresh > 0:
19 |             with self.doc.head:
20 |                 meta(http_equiv="refresh", content=str(refresh))
21 | 
22 |     def
get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=link): 41 | img(style="width:%dpx" % width, src=im) 42 | br() 43 | p(txt) 44 | 45 | 46 | def save(self): 47 | html_file = '%s/index.html' % self.web_dir 48 | f = open(html_file, 'wt') 49 | f.write(self.doc.render()) 50 | f.close() 51 | 52 | 53 | if __name__ == '__main__': 54 | html = HTML('web/', 'test_html') 55 | html.add_header('hello world') 56 | 57 | ims = [] 58 | txts = [] 59 | links = [] 60 | for n in range(4): 61 | ims.append('image_%d.png' % n) 62 | txts.append('text_%d' % n) 63 | links.append('image_%d.png' % n) 64 | html.add_images(ims, txts, links) 65 | html.save() 66 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | def make_dataset(dir): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | 33 | return images 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + 48 | ",".join(IMG_EXTENSIONS))) 49 | 50 | self.root = root 51 | self.imgs = imgs 52 | self.transform = transform 53 | self.return_paths = return_paths 54 | self.loader = loader 55 | 56 | def __getitem__(self, index): 57 | path = self.imgs[index] 58 | img = self.loader(path) 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | if self.return_paths: 62 | return img, path 63 | else: 64 | return img 65 | 66 | def __len__(self): 67 | return len(self.imgs) 68 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pdb import set_trace as st 4 | 5 | class BaseModel(): 6 | def name(self): 7 | return 'BaseModel' 8 | 
9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.gpu_ids = opt.gpu_ids 12 | self.isTrain = opt.isTrain 13 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 15 | 16 | def set_input(self, input): 17 | self.input = input 18 | 19 | def forward(self): 20 | pass 21 | 22 | # used in test time, no backprop 23 | def test(self): 24 | pass 25 | 26 | def get_image_paths(self): 27 | pass 28 | 29 | def optimize_parameters(self): 30 | pass 31 | 32 | def get_current_visuals(self): 33 | return self.input 34 | 35 | def get_current_errors(self): 36 | return {} 37 | 38 | def save(self, label): 39 | pass 40 | 41 | # helper saving function that can be used by subclasses 42 | def save_network(self, network, network_label, epoch_label, gpu_ids): 43 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 44 | save_path = os.path.join(self.save_dir, save_filename) 45 | torch.save(network.cpu().state_dict(), save_path) 46 | if len(gpu_ids) and torch.cuda.is_available(): 47 | network.cuda(gpu_ids[0]) 48 | def isexist_network(self, network_label, epoch_label): 49 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 50 | save_path = os.path.join(self.save_dir, save_filename) 51 | return os.path.exists(save_path) 52 | 53 | # helper loading function that can be used by subclasses 54 | def load_network(self, network, network_label, epoch_label): 55 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 56 | save_path = os.path.join(self.save_dir, save_filename) 57 | # print('path: '+ save_path) 58 | network.load_state_dict(torch.load(save_path)) 59 | # update learning rate (called once every epoch) 60 | def update_learning_rate(self): 61 | for scheduler in self.schedulers: 62 | scheduler.step() 63 | lr = self.optimizers[0].param_groups[0]['lr'] 64 | print('learning rate = %.7f' % lr) 65 | -------------------------------------------------------------------------------- /data/unaligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import PIL 7 | import random 8 | from pdb import set_trace as st 9 | 10 | class UnalignedDataset(BaseDataset): 11 | def initialize(self, opt): 12 | self.opt = opt 13 | self.root = opt.dataroot 14 | self.isTrain = opt.isTrain 15 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') 16 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 17 | 18 | self.A_paths = make_dataset(self.dir_A) 19 | self.B_paths = make_dataset(self.dir_B) 20 | 21 | self.A_paths = sorted(self.A_paths) 22 | self.B_paths = sorted(self.B_paths) 23 | self.A_size = len(self.A_paths) 24 | self.B_size = len(self.B_paths) 25 | self.transform = get_transform(opt) 26 | 27 | def __getitem__(self, index): 28 | A_path = self.A_paths[index % self.A_size] 29 | index_A = index % self.A_size 30 | if self.isTrain: 31 | index_B = random.randint(0, self.B_size - 1) 32 | else: 33 | index_B = index % self.B_size 34 | B_path = self.B_paths[index_B] 35 | # print('load B_path:',B_path) 36 | A_img = Image.open(A_path).convert('RGB') 37 | B_img = Image.open(B_path).convert('RGB') 38 | 39 | A = self.transform(A_img) 40 | B = self.transform(B_img) 41 | if self.opt.which_direction == 'BtoA': 42 | input_nc = self.opt.output_nc 43 | output_nc = self.opt.input_nc 44 | 
else: 45 | input_nc = self.opt.input_nc 46 | output_nc = self.opt.output_nc 47 | 48 | if input_nc == 1: # RGB to gray 49 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 50 | A = tmp.unsqueeze(0) 51 | 52 | if output_nc == 1: # RGB to gray 53 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 54 | B = tmp.unsqueeze(0) 55 | return {'A': A, 'B': B, 56 | 'A_paths': A_path, 'B_paths': B_path} 57 | 58 | def __len__(self): 59 | return max(self.A_size, self.B_size) 60 | 61 | def name(self): 62 | return 'UnalignedDataset' 63 | -------------------------------------------------------------------------------- /data/coco_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import PIL 7 | from pdb import set_trace as st 8 | import random 9 | class UnalignedCocoDataset(BaseDataset): 10 | def initialize(self, opt): 11 | from pycocotools.coco import COCO 12 | self.opt = opt 13 | self.root = opt.dataroot 14 | self.isTrain = opt.isTrain 15 | self.dataType = opt.dataType 16 | annFile='{}/annotations/instances_{}.json'.format(self.root,self.dataType) 17 | self.coco = COCO(annFile) 18 | catIds_A = self.coco.getCatIds(catNms=[opt.A_cats]) 19 | catIds_B = self.coco.getCatIds(catNms=[opt.B_cats]) 20 | self.imgIds_A = self.coco.getImgIds(catIds=catIds_A ) 21 | self.imgIds_B = self.coco.getImgIds(catIds=catIds_B ) 22 | self.A_size = len(self.imgIds_A) 23 | self.B_size = len(self.imgIds_B) 24 | self.transform = get_transform(opt) 25 | def __getitem__(self, index): 26 | coco = self.coco 27 | index_A = index % self.A_size 28 | if self.isTrain: 29 | index_B = random.randint(0, self.B_size - 1) 30 | else: 31 | index_B = index % self.B_size 32 | A_img_id = self.imgIds_A[index_A] 33 | B_img_id = self.imgIds_B[index_B] 34 | 35 | # print('(A, B) = (%d, %d)' % (index_A, index_B)) 36 | A_ann_ids = coco.getAnnIds(imgIds=A_img_id) 37 | B_ann_ids = coco.getAnnIds(imgIds=B_img_id) 38 | A_anns = coco.loadAnns(A_ann_ids) 39 | B_anns = coco.loadAnns(B_ann_ids) 40 | 41 | A_path = coco.loadImgs(A_img_id)[0]['file_name'] 42 | B_path = coco.loadImgs(B_img_id)[0]['file_name'] 43 | A_path = os.path.join(self.root, self.dataType, A_path) 44 | B_path = os.path.join(self.root, self.dataType, B_path) 45 | A_img = Image.open(A_path).convert('RGB') 46 | B_img = Image.open(B_path).convert('RGB') 47 | 48 | 49 | A_img = self.transform(A_img) 50 | B_img = self.transform(B_img) 51 | 52 | return {'A': A_img, 'B': B_img, 53 | 'A_paths': A_path, 'B_paths': B_path} 54 | 55 | def __len__(self): 56 | return max(self.A_size, self.B_size) 57 | 58 | def name(self): 59 | return 'UnalignedCocoDataset' 60 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | class BaseDataset(data.Dataset): 6 | def __init__(self): 7 | super(BaseDataset, self).__init__() 8 | 9 | def name(self): 10 | return 'BaseDataset' 11 | 12 | def initialize(self, opt): 13 | pass 14 | 15 | def get_transform(opt): 16 | transform_list = [] 17 | if opt.resize_or_crop == 'resize_and_crop': 18 | osize = [opt.loadSize, opt.loadSize] 19 | # 
transform_list.append(transforms.Scale(osize, Image.BICUBIC)) 20 | transform_list.append(transforms.Lambda( 21 | lambda img: __scale(img, osize))) 22 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 23 | elif opt.resize_or_crop == 'crop': 24 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 25 | elif opt.resize_or_crop == 'scale_width': 26 | transform_list.append(transforms.Lambda( 27 | lambda img: __scale_width(img, opt.fineSize))) 28 | elif opt.resize_or_crop == 'scale_width_and_crop': 29 | transform_list.append(transforms.Lambda( 30 | lambda img: __scale_width(img, opt.loadSize))) 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | 33 | if opt.isTrain and not opt.no_flip: 34 | transform_list.append(transforms.RandomHorizontalFlip()) 35 | 36 | transform_list += [transforms.ToTensor(), 37 | transforms.Normalize((0.5, 0.5, 0.5), 38 | (0.5, 0.5, 0.5))] 39 | return transforms.Compose(transform_list) 40 | 41 | def __scale_width(img, target_width): 42 | ow, oh = img.size 43 | if (ow == target_width): 44 | return img 45 | w = target_width 46 | h = int(target_width * oh / ow) 47 | size=(w,h) 48 | interpolation=Image.BICUBIC 49 | if len(img.getbands())>3: 50 | img_rgb=Image.merge('RGB',img.split()[:3]) 51 | img_seg=img.split()[3] 52 | img_rgb=img_rgb.resize(size, interpolation) 53 | img_seg=img_seg.resize(size, interpolation) 54 | img_rgb.putalpha(img_seg) 55 | return img_rgb 56 | else: 57 | return img.resize(size, interpolation) 58 | 59 | def __scale(img, size): 60 | interpolation=Image.BICUBIC 61 | if len(img.getbands())>3: 62 | img_rgb=Image.merge('RGB',img.split()[:3]) 63 | img_seg=img.split()[3] 64 | img_rgb=img_rgb.resize(size, interpolation) 65 | img_seg=img_seg.resize(size, interpolation) 66 | img_rgb.putalpha(img_seg) 67 | return img_rgb 68 | else: 69 | return img.resize(size, interpolation) -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from data.base_dataset import BaseDataset 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | 9 | 10 | class AlignedDataset(BaseDataset): 11 | def initialize(self, opt): 12 | self.opt = opt 13 | self.root = opt.dataroot 14 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 15 | 16 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 17 | 18 | assert(opt.resize_or_crop == 'resize_and_crop') 19 | 20 | transform_list = [transforms.ToTensor(), 21 | transforms.Normalize((0.5, 0.5, 0.5), 22 | (0.5, 0.5, 0.5))] 23 | 24 | self.transform = transforms.Compose(transform_list) 25 | 26 | def __getitem__(self, index): 27 | AB_path = self.AB_paths[index] 28 | AB = Image.open(AB_path).convert('RGB') 29 | AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC) 30 | AB = self.transform(AB) 31 | 32 | w_total = AB.size(2) 33 | w = int(w_total / 2) 34 | h = AB.size(1) 35 | w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) 36 | h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) 37 | 38 | A = AB[:, h_offset:h_offset + self.opt.fineSize, 39 | w_offset:w_offset + self.opt.fineSize] 40 | B = AB[:, h_offset:h_offset + self.opt.fineSize, 41 | w + w_offset:w + w_offset + self.opt.fineSize] 42 | 43 | if self.opt.which_direction == 'BtoA': 44 | input_nc = self.opt.output_nc 45 | output_nc = self.opt.input_nc 46 | else: 47 | input_nc 
= self.opt.input_nc
48 |             output_nc = self.opt.output_nc
49 | 
50 |         if (not self.opt.no_flip) and random.random() < 0.5:
51 |             idx = [i for i in range(A.size(2) - 1, -1, -1)]
52 |             idx = torch.LongTensor(idx)
53 |             A = A.index_select(2, idx)
54 |             B = B.index_select(2, idx)
55 | 
56 |         if input_nc == 1:  # RGB to gray
57 |             tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
58 |             A = tmp.unsqueeze(0)
59 | 
60 |         if output_nc == 1:  # RGB to gray
61 |             tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
62 |             B = tmp.unsqueeze(0)
63 | 
64 |         return {'A': A, 'B': B,
65 |                 'A_paths': AB_path, 'B_paths': AB_path}
66 | 
67 |     def __len__(self):
68 |         return len(self.AB_paths)
69 | 
70 |     def name(self):
71 |         return 'AlignedDataset'
72 | 
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | class ImagePool():
6 |     def __init__(self, pool_size):
7 |         self.pool_size = pool_size
8 |         if self.pool_size > 0:
9 |             self.num_imgs = 0
10 |             self.images = []
11 |             self.attn_maps = []
12 | 
13 |     def query(self, images):
14 |         if self.pool_size == 0:
15 |             return images
16 |         return_images = []
17 |         for image in images.data:
18 |             image = torch.unsqueeze(image, 0)
19 |             if self.num_imgs < self.pool_size:
20 |                 self.num_imgs = self.num_imgs + 1
21 |                 self.images.append(image)
22 |                 return_images.append(image)
23 |             else:
24 |                 p = random.uniform(0, 1)
25 |                 if p > 0.5:
26 |                     random_id = random.randint(0, self.pool_size - 1)
27 |                     tmp = self.images[random_id].clone()
28 |                     self.images[random_id] = image
29 |                     return_images.append(tmp)
30 |                 else:
31 |                     return_images.append(image)
32 |         return_images = Variable(torch.cat(return_images, 0))
33 |         return return_images
34 | 
35 |     def query_attn(self, images, attns):
36 |         if self.pool_size == 0:
37 |             return images, attns
38 |         return_images = []
39 |         return_attns = []
40 |         i = 0
41 |         for image in images.data:
42 |             attn = attns.data[i]
43 |             image = torch.unsqueeze(image, 0)
44 |             attn = torch.unsqueeze(attn, 0)
45 |             if self.num_imgs < self.pool_size:
46 |                 self.num_imgs = self.num_imgs + 1
47 |                 self.images.append(image)
48 |                 self.attn_maps.append(attn)
49 |                 return_images.append(image)
50 |                 return_attns.append(attn)
51 |             else:
52 |                 p = random.uniform(0, 1)
53 |                 if p > 0.5:
54 |                     random_id = random.randint(0, self.pool_size - 1)
55 |                     tmp = self.images[random_id].clone()
56 |                     tmp_attn = self.attn_maps[random_id].clone()
57 |                     self.images[random_id] = image
58 |                     self.attn_maps[random_id] = attn
59 |                     return_images.append(tmp)
60 |                     return_attns.append(tmp_attn)
61 |                 else:
62 |                     return_images.append(image)
63 |                     return_attns.append(attn)
64 |             i += 1
65 |         return_images = Variable(torch.cat(return_images, 0))
66 |         return_attns = Variable(torch.cat(return_attns, 0))
67 |         return return_images, return_attns
--------------------------------------------------------------------------------
/util/inception_classify.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 | import torch.utils.data
6 | 
7 | from torchvision.models.inception import inception_v3
8 | 
9 | import numpy as np
10 | from scipy.stats import entropy
11 | 
12 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
13 |     """Computes the inception score of the
generated images imgs 14 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] 15 | cuda -- whether or not to run on GPU 16 | batch_size -- batch size for feeding into Inception v3 17 | splits -- number of splits 18 | """ 19 | N = len(imgs) 20 | 21 | assert batch_size > 0 22 | assert N > batch_size 23 | 24 | # Set up dtype 25 | if cuda: 26 | dtype = torch.cuda.FloatTensor 27 | else: 28 | if torch.cuda.is_available(): 29 | print("WARNING: You have a CUDA device, so you should probably set cuda=True") 30 | dtype = torch.FloatTensor 31 | 32 | # Set up dataloader 33 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 34 | 35 | # Load inception model 36 | inception_model = inception_v3(pretrained=True, transform_input=True).type(dtype) 37 | inception_model.eval(); 38 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) 39 | def get_pred(x): 40 | if resize: 41 | x = up(x) 42 | x = inception_model(x) 43 | return F.softmax(x).data.cpu().numpy() 44 | 45 | # Get predictions 46 | preds = np.zeros((N, 1000)) 47 | 48 | for i, batch in enumerate(dataloader, 0): 49 | batch = batch.type(dtype) 50 | batchv = Variable(batch) 51 | batch_size_i = batch.size()[0] 52 | 53 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 54 | 55 | # Now compute the mean kl-div 56 | split_scores = [] 57 | 58 | for k in range(splits): 59 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 60 | py = np.mean(part, axis=0) 61 | scores = [] 62 | for i in range(part.shape[0]): 63 | pyx = part[i, :] 64 | scores.append(entropy(pyx, py)) 65 | split_scores.append(np.exp(np.mean(scores))) 66 | 67 | return np.mean(split_scores), np.std(split_scores) 68 | 69 | if __name__ == '__main__': 70 | class IgnoreLabelDataset(torch.utils.data.Dataset): 71 | def __init__(self, orig): 72 | self.orig = orig 73 | 74 | def __getitem__(self, index): 75 | return self.orig[index][0] 76 | 77 | def __len__(self): 78 | return len(self.orig) 79 | 80 | import torchvision.datasets as dset 81 | import torchvision.transforms as transforms 82 | 83 | cifar = dset.CIFAR10(root='data/', download=True, 84 | transform=transforms.Compose([ 85 | transforms.Scale(32), 86 | transforms.ToTensor(), 87 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 88 | ]) 89 | ) 90 | 91 | IgnoreLabelDataset(cifar) 92 | 93 | print ("Calculating Inception Score...") 94 | print (inception_score(IgnoreLabelDataset(cifar), cuda=True, batch_size=32, resize=True, splits=10)) -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 8 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 9 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 10 | self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 11 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 12 | self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch 
count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
13 |         self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
14 |         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
15 |         self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
16 |         self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
17 |         self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
18 |         self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
19 |         self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
20 |         self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
21 |         self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
22 |         self.parser.add_argument('--loss_attn_A', type=float, default=1.0, help='weight for attention sparse loss for A')
23 |         self.parser.add_argument('--loss_attn_B', type=float, default=1.0, help='weight for attention sparse loss for B')
24 |         self.parser.add_argument('--attn_cycle_weight', type=float, default=1.0, help='weight for attention cycle-consistent loss')
25 |         self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
26 |         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
27 |         self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
28 |         self.parser.add_argument('--lr_decay_iters', type=int, default=100, help='multiply by a gamma every lr_decay_iters iterations')
29 |         self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set opt.identity = 0.1')
30 |         self.isTrain = True
31 | 
--------------------------------------------------------------------------------
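The --niter, --niter_decay, and --epoch_count flags above jointly define the default 'lambda' learning-rate schedule: the rate is held constant for niter epochs, then decays linearly to zero over niter_decay epochs. networks.get_scheduler is not included in this listing, so the sketch below assumes it follows the upstream CycleGAN rule:

```python
# Sketch of the default 'lambda' policy implied by the flags above.
# Assumption: networks.get_scheduler (not shown in this listing) matches
# the upstream CycleGAN implementation.
from torch.optim import lr_scheduler

def make_lambda_scheduler(optimizer, opt):
    def lambda_rule(epoch):
        # 1.0 for the first `niter` epochs, then linear decay to 0.0
        return 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
```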
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import inspect, re
6 | import matplotlib.cm as cm
7 | import os
8 | import collections
9 | from scipy.misc import imresize
10 | from pdb import set_trace as st
11 | 
12 | # Converts a Tensor into a Numpy array
13 | # |imtype|: the desired type of the converted numpy array
14 | def tensor2im(image_tensor, imtype=np.uint8):
15 |     image_numpy = image_tensor[0].cpu().float().numpy()
16 |     if image_numpy.shape[0] == 1:
17 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
18 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
19 |     return image_numpy.astype(imtype)
20 | def tensor2mask(mask_tensor, imtype=np.uint8):
21 |     image_numpy = mask_tensor[0].cpu().float().numpy()
22 |     image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
23 |     return image_numpy.astype(imtype)
24 | def mask2im(img):
25 |     return np.stack((img[:, :, 0], img[:, :, 0], img[:, :, 0]), axis=2)
26 | def mask2heatmap(mask_tensor, imtype=np.uint8, size=None):
27 |     # img should be [0,1], 1*W*H
28 |     image_numpy = mask_tensor[0, 0].cpu().float().numpy()
29 |     if size:
30 |         image_numpy = resize(image_numpy, size)
31 |     heatmap_marked = cm.jet(image_numpy)[..., :3] * 255.0
32 |     return heatmap_marked.astype(imtype)
33 | def overlay(seed_img, heatmap_marked, alpha=0.5, imtype=np.uint8):
34 |     img = seed_img * alpha + heatmap_marked * (1. - alpha)
35 |     return img.astype(imtype)
36 | def resize(img, size):
37 |     return imresize(img, size, interp='bilinear', mode=None)
38 | def diagnose_network(net, name='network'):
39 |     mean = 0.0
40 |     count = 0
41 |     for param in net.parameters():
42 |         if param.grad is not None:
43 |             mean += torch.mean(torch.abs(param.grad.data))
44 |             count += 1
45 |     if count > 0:
46 |         mean = mean / count
47 |     print(name)
48 |     print(mean)
49 | 
50 | 
51 | def save_image(image_numpy, image_path):
52 |     image_pil = Image.fromarray(image_numpy)
53 |     image_pil.save(image_path)
54 | 
55 | def info(object, spacing=10, collapse=1):
56 |     """Print methods and doc strings.
57 | Takes module, class, list, dictionary, or string.""" 58 | methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] 59 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) 60 | print( "\n".join(["%s %s" % 61 | (method.ljust(spacing), 62 | processFunc(str(getattr(object, method).__doc__))) 63 | for method in methodList]) ) 64 | 65 | def varname(p): 66 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: 67 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) 68 | if m: 69 | return m.group(1) 70 | 71 | def print_numpy(x, val=True, shp=False): 72 | x = x.astype(np.float64) 73 | if shp: 74 | print('shape,', x.shape) 75 | if val: 76 | x = x.flatten() 77 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 78 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 79 | 80 | 81 | def mkdirs(paths): 82 | if isinstance(paths, list) and not isinstance(paths, str): 83 | for path in paths: 84 | mkdir(path) 85 | else: 86 | mkdir(paths) 87 | 88 | 89 | def mkdir(path): 90 | if not os.path.exists(path): 91 | os.makedirs(path) 92 | -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """ 13 | 14 | Download CycleGAN or Pix2Pix Data. 15 | 16 | Args: 17 | technique : str 18 | One of: 'cyclegan' or 'pix2pix'. 19 | verbose : bool 20 | If True, print additional information. 21 | 22 | Examples: 23 | >>> from util.get_data import GetData 24 | >>> gd = GetData(technique='cyclegan') 25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 
26 | 27 | """ 28 | 29 | def __init__(self, technique='cyclegan', verbose=True): 30 | url_dict = { 31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, 'lxml') 44 | options = [h.text for h in soup.find_all('a', href=True) 45 | if h.text.endswith(('.zip', 'tar.gz'))] 46 | return options 47 | 48 | def _present_options(self): 49 | r = requests.get(self.url) 50 | options = self._get_options(r) 51 | print('Options:\n') 52 | for i, o in enumerate(options): 53 | print("{0}: {1}".format(i, o)) 54 | choice = input("\nPlease enter the number of the " 55 | "dataset above you wish to download:") 56 | return options[int(choice)] 57 | 58 | def _download_data(self, dataset_url, save_path): 59 | if not isdir(save_path): 60 | os.makedirs(save_path) 61 | 62 | base = basename(dataset_url) 63 | temp_save_path = join(save_path, base) 64 | 65 | with open(temp_save_path, "wb") as f: 66 | r = requests.get(dataset_url) 67 | f.write(r.content) 68 | 69 | if base.endswith('.tar.gz'): 70 | obj = tarfile.open(temp_save_path) 71 | elif base.endswith('.zip'): 72 | obj = ZipFile(temp_save_path, 'r') 73 | else: 74 | raise ValueError("Unknown File Type: {0}.".format(base)) 75 | 76 | self._print("Unpacking Data...") 77 | obj.extractall(save_path) 78 | obj.close() 79 | os.remove(temp_save_path) 80 | 81 | def get(self, save_path, dataset=None): 82 | """ 83 | 84 | Download a dataset. 85 | 86 | Args: 87 | save_path : str 88 | A directory to save the data to. 89 | dataset : str, optional 90 | A specific dataset to download. 91 | Note: this must include the file extension. 92 | If None, options will be presented for you 93 | to choose from. 94 | 95 | Returns: 96 | save_path_full : str 97 | The absolute path to the downloaded data. 98 | 99 | """ 100 | if dataset is None: 101 | selected_dataset = self._present_options() 102 | else: 103 | selected_dataset = dataset 104 | 105 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 106 | 107 | if isdir(save_path_full): 108 | warn("\n'{0}' already exists. 
Voiding Download.".format(
109 |                 save_path_full))
110 |         else:
111 |             self._print('Downloading Data...')
112 |             url = "{0}/{1}".format(self.url, selected_dataset)
113 |             self._download_data(url, save_path=save_path)
114 | 
115 |         return abspath(save_path_full)
116 | 
--------------------------------------------------------------------------------
/data/cocoseg_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.base_dataset import BaseDataset, get_transform
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import PIL
7 | from pdb import set_trace as st
8 | import numpy as np
9 | import torch
10 | import random
11 | 
12 | class CocoSegDataset(BaseDataset):
13 |     def initialize(self, opt):
14 |         from pycocotools.coco import COCO
15 |         self.opt = opt
16 |         self.root = opt.dataroot
17 |         self.dataType = opt.dataType
18 |         self.isTrain = opt.isTrain
19 |         if self.dataType == 'test2017':
20 |             annFile = '{}/annotations/image_info_{}.json'.format(self.root, self.dataType)
21 |         else:
22 |             annFile = '{}/annotations/instances_{}.json'.format(self.root, self.dataType)
23 |         self.coco = COCO(annFile)
24 |         self.catIds_A = self.coco.getCatIds(catNms=[opt.A_cats])
25 |         self.catIds_B = self.coco.getCatIds(catNms=[opt.B_cats])
26 |         self.imgIds_A = self.coco.getImgIds(catIds=self.catIds_A)
27 |         self.imgIds_B = self.coco.getImgIds(catIds=self.catIds_B)
28 |         self.A_size = len(self.imgIds_A) - 1
29 |         self.B_size = len(self.imgIds_B)
30 |         self.transform = get_transform(opt)
31 |     def __getitem__(self, index):
32 |         coco = self.coco
33 |         index_A = index % self.A_size
34 |         # index_A=1238
35 |         # index_A = 10
36 |         if self.isTrain:
37 |             index_B = random.randint(0, self.B_size - 1)
38 |         else:
39 |             index_B = index % self.B_size
40 |         # index_B=1366
41 |         # index_B = 14
42 |         A_img_id = self.imgIds_A[index_A]
43 |         B_img_id = self.imgIds_B[index_B]
44 |         # A_img_id = 427523
45 |         # B_img_id = 403916
46 |         # print('(A, B) = (%d, %d)' % (index_A, index_B))
47 |         A_ann_ids = coco.getAnnIds(imgIds=A_img_id, catIds=self.catIds_A, iscrowd=None)
48 |         B_ann_ids = coco.getAnnIds(imgIds=B_img_id, catIds=self.catIds_B, iscrowd=None)
49 |         A_anns = coco.loadAnns(A_ann_ids)
50 |         B_anns = coco.loadAnns(B_ann_ids)
51 | 
52 |         A_path = coco.loadImgs(A_img_id)[0]['file_name']
53 |         A_path = os.path.join(self.root, self.dataType, A_path)
54 |         B_path = coco.loadImgs(B_img_id)[0]['file_name']
55 |         B_path = os.path.join(self.root, self.dataType, B_path)
56 |         A_img = Image.open(A_path).convert('RGB')
57 |         B_img = Image.open(B_path).convert('RGB')
58 | 
59 |         A_seg = self.getSegMask(A_anns)
60 |         B_seg = self.getSegMask(B_anns)
61 |         A_img.putalpha(A_seg)
62 |         B_img.putalpha(B_seg)
63 | 
64 |         i = 0
65 |         A_input = self.transform(A_img)
66 |         while (torch.sum(A_input[3]) < 5):
67 |             # re-crop until the target region keeps at least 5 foreground pixels
68 |             A_input = self.transform(A_img)
69 |             i += 1
70 |             if i == 100:
71 |                 print('warning: target_A still has fewer than 5 foreground pixels after 100 crops')
72 |                 print('id:', A_img_id)
73 |                 break
74 | 
75 |         i = 0
76 |         B_input = self.transform(B_img)
77 |         while (torch.sum(B_input[3]) < 5):
78 |             # re-crop until the target region keeps at least 5 foreground pixels
79 |             B_input = self.transform(B_img)
80 |             i += 1
81 |             if i == 100:
82 |                 print('warning: target_B still has fewer than 5 foreground pixels after 100 crops')
83 |                 print('id:', B_img_id)
84 |                 break
85 |         return {'A': A_input, 'B': B_input,
86 |                 # 'A_img_id': A_img_id, 'B_img_id': B_img_id}
87 |                 'A_paths': A_path, 'B_paths': B_path}
88 | 
89 |     def __len__(self):
90 |         return max(self.A_size, self.B_size)
91 | 
92 |     def name(self):
93 |         return 'UnalignedCocoSegDataset'
94 |     def getSegMask(self, anns):
95 |         coco = self.coco
96 |         for i, anns_i in enumerate(anns):
97 |             if i == 0:
98 |                 mask = np.asarray(coco.annToMask(anns_i))
99 |             else:
100 |                 mask += np.asarray(coco.annToMask(anns_i))
101 |         mask[mask > 1] = 1
102 |         mask = torch.from_numpy(mask[np.newaxis, :, :] * 255)
103 |         ToPil = transforms.ToPILImage()
104 |         seg = ToPil(mask)
105 |         return seg
106 | 
107 | 
--------------------------------------------------------------------------------
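The alpha-channel trick in CocoSegDataset above is what keeps image and mask aligned: get_transform applies one random crop/flip to the whole PIL image, so packing the binary mask in as a fourth channel means it undergoes exactly the same augmentation as the RGB data. A minimal, self-contained illustration (hypothetical 8x8 image):

```python
# Minimal illustration of the putalpha trick used by CocoSegDataset above.
from PIL import Image
import numpy as np

img = Image.new('RGB', (8, 8))                                    # stand-in photo
mask = Image.fromarray((np.ones((8, 8)) * 255).astype(np.uint8))  # binary mask, mode 'L'
img.putalpha(mask)          # now RGBA: RGB pixels + segmentation in the alpha channel
r, g, b, seg = img.split()  # still recoverable after joint crops/flips
```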
/models/pix2pix_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | import util.util as util
7 | from util.image_pool import ImagePool
8 | from .base_model import BaseModel
9 | from . import networks
10 | 
11 | 
12 | class Pix2PixModel(BaseModel):
13 |     def name(self):
14 |         return 'Pix2PixModel'
15 | 
16 |     def initialize(self, opt):
17 |         BaseModel.initialize(self, opt)
18 |         self.isTrain = opt.isTrain
19 |         # define tensors
20 |         self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
21 |                                    opt.fineSize, opt.fineSize)
22 |         self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
23 |                                    opt.fineSize, opt.fineSize)
24 | 
25 |         # load/define networks
26 |         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
27 |                                       opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
28 |         if self.isTrain:
29 |             use_sigmoid = opt.no_lsgan
30 |             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
31 |                                           opt.which_model_netD,
32 |                                           opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
33 |         if not self.isTrain or opt.continue_train:
34 |             self.load_network(self.netG, 'G', opt.which_epoch)
35 |             if self.isTrain:
36 |                 self.load_network(self.netD, 'D', opt.which_epoch)
37 | 
38 |         if self.isTrain:
39 |             self.fake_AB_pool = ImagePool(opt.pool_size)
40 |             self.old_lr = opt.lr
41 |             # define loss functions
42 |             self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
43 |             self.criterionL1 = torch.nn.L1Loss()
44 | 
45 |             # initialize optimizers
46 |             self.schedulers = []
47 |             self.optimizers = []
48 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
49 |                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
50 |             self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
51 |                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
52 |             self.optimizers.append(self.optimizer_G)
53 |             self.optimizers.append(self.optimizer_D)
54 |             for optimizer in self.optimizers:
55 |                 self.schedulers.append(networks.get_scheduler(optimizer, opt))
56 | 
57 |         print('---------- Networks initialized -------------')
58 |         networks.print_network(self.netG)
59 |         if self.isTrain:
60 |             networks.print_network(self.netD)
61 |         print('-----------------------------------------------')
62 | 
63 |     def set_input(self, input):
64 |         AtoB = self.opt.which_direction == 'AtoB'
65 |         input_A = input['A' if AtoB else 'B']
66 |         input_B = input['B' if AtoB else 'A']
67 |         self.input_A.resize_(input_A.size()).copy_(input_A)
68 |         self.input_B.resize_(input_B.size()).copy_(input_B)
69 |         self.image_paths = input['A_paths' if AtoB else 'B_paths']
70 | 
71 |     def forward(self):
72 |         self.real_A = Variable(self.input_A)
73 |         self.fake_B = self.netG.forward(self.real_A)
74 |         self.real_B = Variable(self.input_B)
75 | 
76 |     # no backprop gradients
77 |     def
test(self): 78 | self.real_A = Variable(self.input_A, volatile=True) 79 | self.fake_B = self.netG.forward(self.real_A) 80 | self.real_B = Variable(self.input_B, volatile=True) 81 | 82 | # get image paths 83 | def get_image_paths(self): 84 | return self.image_paths 85 | 86 | def backward_D(self): 87 | # Fake 88 | # stop backprop to the generator by detaching fake_B 89 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) 90 | self.pred_fake = self.netD.forward(fake_AB.detach()) 91 | self.loss_D_fake = self.criterionGAN(self.pred_fake, False) 92 | 93 | # Real 94 | real_AB = torch.cat((self.real_A, self.real_B), 1) 95 | self.pred_real = self.netD.forward(real_AB) 96 | self.loss_D_real = self.criterionGAN(self.pred_real, True) 97 | 98 | # Combined loss 99 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 100 | 101 | self.loss_D.backward() 102 | 103 | def backward_G(self): 104 | # First, G(A) should fake the discriminator 105 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 106 | pred_fake = self.netD.forward(fake_AB) 107 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 108 | 109 | # Second, G(A) = B 110 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A 111 | 112 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 113 | 114 | self.loss_G.backward() 115 | 116 | def optimize_parameters(self): 117 | self.forward() 118 | 119 | self.optimizer_D.zero_grad() 120 | self.backward_D() 121 | self.optimizer_D.step() 122 | 123 | self.optimizer_G.zero_grad() 124 | self.backward_G() 125 | self.optimizer_G.step() 126 | 127 | def get_current_errors(self): 128 | return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]), 129 | ('G_L1', self.loss_G_L1.data[0]), 130 | ('D_real', self.loss_D_real.data[0]), 131 | ('D_fake', self.loss_D_fake.data[0]) 132 | ]) 133 | 134 | def get_current_visuals(self): 135 | real_A = util.tensor2im(self.real_A.data) 136 | fake_B = util.tensor2im(self.fake_B.data) 137 | real_B = util.tensor2im(self.real_B.data) 138 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) 139 | 140 | def save(self, label): 141 | self.save_network(self.netG, 'G', label, self.gpu_ids) 142 | self.save_network(self.netD, 'D', label, self.gpu_ids) 143 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . import html 7 | from pdb import set_trace as st 8 | class Visualizer(): 9 | def __init__(self, opt): 10 | # self.opt = op 11 | self.display_id = opt.display_id 12 | if opt.isTrain == True: 13 | self.display_freq = opt.display_freq 14 | self.epoch_count = opt.epoch_count 15 | self.use_html = opt.isTrain and not opt.no_html 16 | self.win_size = opt.display_winsize 17 | self.name = opt.name 18 | if self.display_id > 0: 19 | import visdom 20 | self.vis = visdom.Visdom(port = opt.display_port) 21 | self.display_single_pane_ncols = opt.display_single_pane_ncols 22 | 23 | if self.use_html: 24 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 25 | self.img_dir = os.path.join(self.web_dir, 'images') 26 | print('create web directory %s...' 
% self.web_dir)
27 | util.mkdirs([self.web_dir, self.img_dir])
28 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
29 | with open(self.log_name, "a") as log_file:
30 | now = time.strftime("%c")
31 | log_file.write('================ Training Loss (%s) ================\n' % now)
32 | 
33 | # |visuals|: dictionary of images to display or save
34 | def display_current_results(self, visuals, epoch, iter, niter):
35 | if self.display_id > 0: # show images in the browser
36 | if self.display_single_pane_ncols > 0:
37 | h, w = list(visuals.values())[0].shape[:2] # note: the 'iter' argument shadows the builtin iter()
38 | table_css = """<style>
39 | table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
40 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
41 | </style>""" % (w, h)
42 | ncols = self.display_single_pane_ncols
43 | title = self.name
44 | label_html = ''
45 | label_html_row = ''
46 | nrows = int(np.ceil(len(visuals.items()) / ncols))
47 | images = []
48 | idx = 0
49 | for label, image_numpy in visuals.items():
50 | label_html_row += '<td>%s</td>' % label
51 | images.append(image_numpy.transpose([2, 0, 1]))
52 | idx += 1
53 | if idx % ncols == 0:
54 | label_html += '<tr>%s</tr>' % label_html_row
55 | label_html_row = ''
56 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
57 | while idx % ncols != 0:
58 | images.append(white_image)
59 | label_html_row += '<td></td>'
60 | idx += 1
61 | if label_html_row != '':
62 | label_html += '<tr>%s</tr>' % label_html_row
63 | # pane col = image row
64 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
65 | padding=2, opts=dict(title=title + ' images'))
66 | label_html = '<table>%s</table>' % label_html
67 | self.vis.text(table_css + label_html, win=self.display_id + 2,
68 | opts=dict(title=title + ' labels'))
69 | else:
70 | idx = 1
71 | for label, image_numpy in visuals.items():
72 | self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
73 | win=self.display_id + idx)
74 | idx += 1
75 | if self.use_html: # save images to a html file
76 | for label, image_numpy in visuals.items():
77 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.4d_%s.png' % (epoch, iter, label))
78 | util.save_image(image_numpy, img_path)
79 | # update website
80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) # 'reflesh' is the parameter name defined in util/html.py
81 | for n in range(iter, 0, -self.display_freq):
82 | epoch = n // niter + self.epoch_count
83 | webpage.add_header('epoch [%d] iter [%d]' % (epoch, n))
84 | ims = []
85 | txts = []
86 | links = []
87 | 
88 | for label, image_numpy in visuals.items():
89 | img_path = 'epoch%.3d_iter%.4d_%s.png' % (epoch, n, label)
90 | ims.append(img_path)
91 | txts.append(label)
92 | links.append(img_path)
93 | webpage.add_images(ims, txts, links, width=self.win_size)
94 | webpage.save()
95 | 
96 | 
97 | # errors: dictionary of error labels and values
98 | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
99 | if not hasattr(self, 'plot_data'):
100 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
101 | self.plot_data['X'].append(epoch + counter_ratio)
102 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
103 | self.vis.line(
104 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
105 | Y=np.array(self.plot_data['Y']),
106 | opts={
107 | 'title': self.name + ' loss over time',
108 | 'legend': self.plot_data['legend'],
109 | 'xlabel': 'epoch',
110 | 'ylabel': 'loss'},
111 | win=self.display_id)
112 | 
113 | # errors: same format as |errors| of plotCurrentErrors
114 | def print_current_errors(self, epoch, i, errors, t):
115 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
116 | for k, v in errors.items():
117 | message += '%s: %.3f ' % (k, v.cpu().numpy())
118 | 
119 | print(message)
120 | with open(self.log_name, "a") as log_file:
121 | log_file.write('%s\n' % message)
122 | 
123 | # save image to the disk
124 | def save_images(self, webpage, visuals, image_path):
125 | image_dir = webpage.get_image_dir()
126 | short_path = ntpath.basename(image_path[0])
127 | name = os.path.splitext(short_path)[0]
128 | 
129 | webpage.add_header(name)
130 | ims = []
131 | txts = []
132 | links = []
133 | for label, image_numpy in visuals.items():
134 | image_name = '%s_%s.png' % (name, label)
135 | save_path = os.path.join(image_dir, image_name)
136 | util.save_image(image_numpy, save_path)
137 | ims.append(os.path.join('images', image_name))
138 | txts.append(label)
139 | links.append(os.path.join('images', image_name))
140 | webpage.add_images(ims, txts, links, width=self.win_size)
141 | 
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | from pdb import set_trace as st
6 | 
7 | class BaseOptions():
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser()
10 | self.initialized = False
11 | 
12 | def initialize(self):
13 | self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
14 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 15 | self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') 16 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 17 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 18 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 19 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 20 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 21 | self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD:[multiscale_layers]') 22 | self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG') 23 | self.parser.add_argument('--n_layers_D', type=str, default='3', help='only used if which_model_netD==n_layers') 24 | self.parser.add_argument('--loss_weight', type=str, default=None, help='weight in the multiscale D (corresponding with n_layers_D)') 25 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 27 | self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | CocoSeg | Coco]') 28 | self.parser.add_argument('--model', type=str, default='cycle_gan', 29 | help='chooses which model to use. cycle_gan, pix2pix, test') 30 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') # change the datasets' A and B 31 | self.parser.add_argument('--which_direction_model', type=str, default='AtoB', help='AtoB or BtoA') #change the model's A and B 32 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 33 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 34 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 35 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 36 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 37 | self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 38 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 39 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 40 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 41 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
42 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
43 | self.parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
44 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
45 | self.parser.add_argument('--dataType', type=str, default='val2017', help='data type of coco datasets: train2017 | val2017 | test2017')
46 | self.parser.add_argument('--A_cats', type=str, default='horse', help='category in coco')
47 | self.parser.add_argument('--B_cats', type=str, default='zebra', help='category in coco')
48 | self.parser.add_argument('--which_model_netA', type=str, default='resnet_6blocks', help='selects model to use for netA')
49 | self.initialized = True
50 | def parse(self):
51 | if not self.initialized:
52 | self.initialize()
53 | self.opt = self.parser.parse_args()
54 | self.opt.isTrain = self.isTrain # train or test
55 | 
56 | str_ids = self.opt.gpu_ids.split(',')
57 | self.opt.gpu_ids = []
58 | for str_id in str_ids:
59 | gpu_id = int(str_id)
60 | if gpu_id >= 0:
61 | self.opt.gpu_ids.append(gpu_id)
62 | 
63 | # set gpu ids
64 | if len(self.opt.gpu_ids) > 0:
65 | torch.cuda.set_device(self.opt.gpu_ids[0])
66 | 
67 | # set n_layers_D:
68 | n_layers_D = self.opt.n_layers_D.split(',')
69 | if len(n_layers_D) == 1:
70 | self.opt.n_layers_D = int(n_layers_D[0])
71 | else:
72 | self.opt.n_layers_D = []
73 | for i in range(len(n_layers_D)):
74 | self.opt.n_layers_D += [int(n_layers_D[i])]
75 | args = vars(self.opt)
76 | 
77 | # set loss_weight:
78 | if self.opt.loss_weight:
79 | loss_weight = self.opt.loss_weight.split(',')
80 | if len(loss_weight) == 1:
81 | self.opt.loss_weight = float(loss_weight[0])
82 | else:
83 | self.opt.loss_weight = []
84 | for i in range(len(loss_weight)):
85 | self.opt.loss_weight += [float(loss_weight[i])]
86 | args = vars(self.opt)
87 | 
88 | print('------------ Options -------------')
89 | for k, v in sorted(args.items()):
90 | print('%s: %s' % (str(k), str(v)))
91 | print('-------------- End ----------------')
92 | 
93 | # save to the disk
94 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
95 | util.mkdirs(expr_dir)
96 | if self.opt.isTrain:
97 | file_name = os.path.join(expr_dir, 'opt_train.txt')
98 | else:
99 | file_name = os.path.join(expr_dir, 'opt_test.txt')
100 | with open(file_name, 'wt') as opt_file:
101 | opt_file.write('------------ Options -------------\n')
102 | for k, v in sorted(args.items()):
103 | opt_file.write('%s: %s\n' % (str(k), str(v)))
104 | opt_file.write('-------------- End ----------------\n')
105 | return self.opt
106 | 
--------------------------------------------------------------------------------
/models/cycle_gan_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | import itertools
7 | import util.util as util
8 | from util.image_pool import ImagePool
9 | from .base_model import BaseModel
10 | from . import networks
11 | import sys
12 | import pdb
13 | 
14 | class CycleGANModel(BaseModel):
15 | def name(self):
16 | return 'CycleGANModel'
17 | 
18 | def initialize(self, opt):
19 | BaseModel.initialize(self, opt)
20 | 
21 | nb = opt.batchSize
22 | size = opt.fineSize
23 | self.input_A = self.Tensor(nb, opt.input_nc, size, size)
24 | self.input_B = self.Tensor(nb, opt.output_nc, size, size)
25 | 
26 | # load/define networks
27 | # The naming convention is different from that used in the paper
28 | # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
29 | 
30 | self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
31 | opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
32 | self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
33 | opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
34 | 
35 | if self.isTrain:
36 | use_sigmoid = opt.no_lsgan
37 | self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
38 | opt.which_model_netD,
39 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
40 | self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
41 | opt.which_model_netD,
42 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
43 | if not self.isTrain or opt.continue_train:
44 | which_epoch = opt.which_epoch
45 | self.load_network(self.netG_A, 'G_A', which_epoch)
46 | self.load_network(self.netG_B, 'G_B', which_epoch)
47 | if self.isTrain and not opt.only_load_G:
48 | self.load_network(self.netD_A, 'D_A', which_epoch)
49 | self.load_network(self.netD_B, 'D_B', which_epoch)
50 | 
51 | if self.isTrain:
52 | self.old_lr = opt.lr
53 | self.fake_A_pool = ImagePool(opt.pool_size)
54 | self.fake_B_pool = ImagePool(opt.pool_size)
55 | # define loss functions
56 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, loss_weight=opt.loss_weight)
57 | self.criterionCycle = torch.nn.L1Loss()
58 | self.criterionIdt = torch.nn.L1Loss()
59 | # initialize optimizers
60 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
61 | lr=opt.lr, betas=(opt.beta1, 0.999))
62 | self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
63 | self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
64 | self.optimizers = []
65 | self.schedulers = []
66 | self.optimizers.append(self.optimizer_G)
67 | self.optimizers.append(self.optimizer_D_A)
68 | self.optimizers.append(self.optimizer_D_B)
69 | for optimizer in self.optimizers:
70 | self.schedulers.append(networks.get_scheduler(optimizer, opt))
71 | 
72 | print('---------- Networks initialized -------------')
73 | networks.print_network(self.netG_A)
74 | networks.print_network(self.netG_B)
75 | if self.isTrain:
76 | networks.print_network(self.netD_A)
77 | networks.print_network(self.netD_B)
78 | print('-----------------------------------------------')
79 | 
80 | def set_input(self, input):
81 | AtoB = self.opt.which_direction == 'AtoB'
82 | input_A = input['A' if AtoB else 'B']
83 | input_B = input['B' if AtoB else 'A']
84 | self.input_A.resize_(input_A.size()).copy_(input_A)
85 | self.input_B.resize_(input_B.size()).copy_(input_B)
86 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
87 | 
88 | def forward(self):
89 | self.real_A = Variable(self.input_A)
90 | self.real_B = Variable(self.input_B)
91 | 
92 | def test(self):
93 | self.real_A = Variable(self.input_A,
volatile=True) 94 | self.fake_B = self.netG_A.forward(self.real_A) 95 | self.rec_A = self.netG_B.forward(self.fake_B) 96 | 97 | self.real_B = Variable(self.input_B, volatile=True) 98 | self.fake_A = self.netG_B.forward(self.real_B) 99 | self.rec_B = self.netG_A.forward(self.fake_A) 100 | 101 | # get image paths 102 | def get_image_paths(self): 103 | return self.image_paths 104 | 105 | def backward_D_basic(self, netD, real, fake): 106 | # Real 107 | pred_real = netD.forward(real) 108 | loss_D_real = self.criterionGAN(pred_real, True) 109 | # Fake 110 | pred_fake = netD.forward(fake.detach()) 111 | loss_D_fake = self.criterionGAN(pred_fake, False) 112 | # Combined loss 113 | loss_D = (loss_D_real + loss_D_fake) * 0.5 114 | # backward 115 | loss_D.backward() 116 | return loss_D 117 | 118 | def backward_D_A(self): 119 | fake_B = self.fake_B_pool.query(self.fake_B) 120 | self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) 121 | 122 | def backward_D_B(self): 123 | fake_A = self.fake_A_pool.query(self.fake_A) 124 | self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) 125 | 126 | def backward_G(self): 127 | lambda_idt = self.opt.identity 128 | lambda_A = self.opt.lambda_A 129 | lambda_B = self.opt.lambda_B 130 | # Identity loss 131 | if lambda_idt > 0: 132 | # G_A should be identity if real_B is fed. 133 | self.idt_A = self.netG_A.forward(self.real_B) 134 | self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt 135 | # G_B should be identity if real_A is fed. 136 | self.idt_B = self.netG_B.forward(self.real_A) 137 | self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt 138 | else: 139 | self.loss_idt_A = 0 140 | self.loss_idt_B = 0 141 | 142 | # GAN loss 143 | # D_A(G_A(A)) 144 | self.fake_B = self.netG_A.forward(self.real_A) 145 | pred_fake = self.netD_A.forward(self.fake_B) 146 | self.loss_G_A = self.criterionGAN(pred_fake, True) 147 | # D_B(G_B(B)) 148 | self.fake_A = self.netG_B.forward(self.real_B) 149 | pred_fake = self.netD_B.forward(self.fake_A) 150 | self.loss_G_B = self.criterionGAN(pred_fake, True) 151 | # Forward cycle loss 152 | self.rec_A = self.netG_B.forward(self.fake_B) 153 | self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A 154 | # Backward cycle loss 155 | self.rec_B = self.netG_A.forward(self.fake_A) 156 | self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B 157 | # combined loss 158 | self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B 159 | self.loss_G.backward() 160 | 161 | def optimize_parameters(self): 162 | # forward 163 | self.forward() 164 | # G_A and G_B 165 | self.optimizer_G.zero_grad() 166 | self.backward_G() 167 | self.optimizer_G.step() 168 | # D_A 169 | self.optimizer_D_A.zero_grad() 170 | self.backward_D_A() 171 | self.optimizer_D_A.step() 172 | # D_B 173 | self.optimizer_D_B.zero_grad() 174 | self.backward_D_B() 175 | self.optimizer_D_B.step() 176 | 177 | def get_current_errors(self): 178 | D_A = self.loss_D_A.data[0] 179 | G_A = self.loss_G_A.data[0] 180 | Cyc_A = self.loss_cycle_A.data[0] 181 | D_B = self.loss_D_B.data[0] 182 | G_B = self.loss_G_B.data[0] 183 | Cyc_B = self.loss_cycle_B.data[0] 184 | if self.opt.identity > 0.0: 185 | idt_A = self.loss_idt_A.data[0] 186 | idt_B = self.loss_idt_B.data[0] 187 | return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), 188 | ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), 
('idt_B', idt_B)])
189 | else:
190 | return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
191 | ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
192 | 
193 | def get_current_visuals(self):
194 | real_A = util.tensor2im(self.real_A.data)
195 | fake_B = util.tensor2im(self.fake_B.data)
196 | rec_A = util.tensor2im(self.rec_A.data)
197 | real_B = util.tensor2im(self.real_B.data)
198 | fake_A = util.tensor2im(self.fake_A.data)
199 | rec_B = util.tensor2im(self.rec_B.data)
200 | if self.opt.isTrain and self.opt.identity > 0.0:
201 | idt_A = util.tensor2im(self.idt_A.data)
202 | idt_B = util.tensor2im(self.idt_B.data)
203 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
204 | ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
205 | else:
206 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
207 | ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
208 | 
209 | def save(self, label):
210 | self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
211 | self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
212 | self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
213 | self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
214 | 
--------------------------------------------------------------------------------
/models/cycle_attn_gan_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | import itertools
7 | import util.util as util
8 | from util.image_pool import ImagePool
9 | from .base_model import BaseModel
10 | from . import networks
11 | import sys
12 | from pdb import set_trace as st
13 | 
14 | class CycleAttnGANModel(BaseModel):
15 | def name(self):
16 | return 'CycleAttnGANModel'
17 | 
18 | def initialize(self, opt):
19 | BaseModel.initialize(self, opt)
20 | 
21 | nb = opt.batchSize
22 | size = opt.fineSize
23 | self.which_direction_model = opt.which_direction_model
24 | self.input_A = self.Tensor(nb, opt.input_nc, size, size)
25 | self.input_B = self.Tensor(nb, opt.input_nc, size, size)
26 | self.zeros = self.Tensor(nb, 1, size, size)
27 | self.ones = self.Tensor(nb, 1, size, size)
28 | # load/define networks
29 | # The naming convention is different from that used in the paper
30 | # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
31 | self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
32 | opt.ngf, opt.which_model_netG, opt.norm,
33 | not opt.no_dropout, opt.init_type, self.gpu_ids)
34 | self.netG_B = networks.define_G(opt.input_nc, opt.input_nc,
35 | opt.ngf, opt.which_model_netG, opt.norm,
36 | not opt.no_dropout, opt.init_type, self.gpu_ids)
37 | self.netA_A = networks.define_A(opt.input_nc, 1,
38 | opt.ngf, opt.which_model_netA, opt.norm,
39 | not opt.no_dropout, opt.init_type, self.gpu_ids)
40 | self.netA_B = networks.define_A(opt.input_nc, 1,
41 | opt.ngf, opt.which_model_netA, opt.norm,
42 | not opt.no_dropout, opt.init_type, self.gpu_ids)
43 | if self.isTrain:
44 | use_sigmoid = opt.no_lsgan
45 | else:
46 | use_sigmoid = False
47 | self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
48 | opt.which_model_netD,
49 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
50 | self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
51 | opt.which_model_netD,
52 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
53 | if not
self.isTrain or opt.continue_train: 54 | which_epoch = opt.which_epoch 55 | if self.which_direction_model=='AtoB': 56 | self.load_network(self.netG_A, 'G_A', which_epoch) 57 | self.load_network(self.netG_B, 'G_B', which_epoch) 58 | self.load_network(self.netA_A, 'A_A', which_epoch) 59 | self.load_network(self.netA_B, 'A_B', which_epoch) 60 | else: 61 | self.load_network(self.netG_A, 'G_B', which_epoch) 62 | self.load_network(self.netG_B, 'G_A', which_epoch) 63 | self.load_network(self.netA_A, 'A_B', which_epoch) 64 | self.load_network(self.netA_B, 'A_A', which_epoch) 65 | if self.isTrain: 66 | self.load_network(self.netD_A, 'D_A', which_epoch) 67 | self.load_network(self.netD_B, 'D_B', which_epoch) 68 | 69 | if self.isTrain: 70 | self.old_lr = opt.lr 71 | self.fake_A_pool = ImagePool(opt.pool_size) 72 | self.fake_B_pool = ImagePool(opt.pool_size) 73 | # define loss functions 74 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) 75 | self.criterionCycle = torch.nn.L1Loss() 76 | self.criterionIdt = torch.nn.L1Loss() 77 | # initialize optimizers 78 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),lr=opt.lr, betas=(opt.beta1, 0.999)) 79 | self.optimizer_A = torch.optim.Adam(itertools.chain(self.netA_A.parameters(), self.netA_B.parameters()),lr=opt.lr, betas=(opt.beta1, 0.999)) 80 | 81 | self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 82 | self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 83 | self.optimizers = [] 84 | self.schedulers = [] 85 | self.optimizers.append(self.optimizer_G) 86 | self.optimizers.append(self.optimizer_A) 87 | self.optimizers.append(self.optimizer_D_A) 88 | self.optimizers.append(self.optimizer_D_B) 89 | for optimizer in self.optimizers: 90 | self.schedulers.append(networks.get_scheduler(optimizer, opt)) 91 | 92 | print('---------- Networks initialized -------------') 93 | networks.print_network(self.netG_A) 94 | networks.print_network(self.netG_B) 95 | if self.isTrain: 96 | networks.print_network(self.netD_A) 97 | networks.print_network(self.netD_B) 98 | print('-----------------------------------------------') 99 | 100 | def set_input(self, input): 101 | AtoB = self.opt.which_direction == 'AtoB' 102 | input_A = input['A' if AtoB else 'B'] 103 | input_B = input['B' if AtoB else 'A'] 104 | self.input_A.resize_(input_A.size()).copy_(input_A) 105 | self.input_B.resize_(input_B.size()).copy_(input_B) 106 | bz,c,h,w=input_A.size() 107 | self.zeros.resize_((bz,1,h,w)).fill_(0.0) 108 | self.ones.resize_((bz,1,h,w)).fill_(1.0) 109 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 110 | 111 | def mask_layer(self, foreground, background, mask): 112 | img = foreground * mask + background * (1 - mask) 113 | return img 114 | 115 | def forward(self): 116 | self.real_A = Variable(self.input_A) 117 | self.real_B = Variable(self.input_B) 118 | self.zeros_attn = Variable(self.zeros, requires_grad=False) 119 | self.ones_attn = Variable(self.ones, requires_grad=False) 120 | 121 | def test(self): 122 | self.real_A = Variable(self.input_A, volatile=True) 123 | self.real_B = Variable(self.input_B, volatile=True) 124 | 125 | fake_B = self.netG_A.forward(self.real_A) 126 | self.attn_real_A = self.netA_A.forward(self.real_A) 127 | self.fake_B = self.mask_layer(fake_B, self.real_A, self.attn_real_A) 128 | rec_A = self.netG_B.forward(self.fake_B) 129 | self.attn_fake_B = 
self.netA_B.forward(self.fake_B) 130 | self.rec_A = self.mask_layer(rec_A, self.fake_B, self.attn_fake_B) 131 | 132 | fake_A = self.netG_B.forward(self.real_B) 133 | self.attn_real_B = self.netA_B.forward(self.real_B) 134 | self.fake_A = self.mask_layer(fake_A, self.real_B, self.attn_real_B) 135 | 136 | rec_B = self.netG_A.forward(self.fake_A) 137 | self.attn_fake_A = self.netA_A.forward(self.fake_A) 138 | self.rec_B = self.mask_layer(rec_B, self.fake_A, self.attn_fake_A) 139 | # get image paths 140 | def get_image_paths(self): 141 | return self.image_paths 142 | 143 | def backward_D_basic(self, netD, real, fake): 144 | # Real 145 | pred_real = netD.forward(real) 146 | loss_D_real = self.criterionGAN(pred_real, True) 147 | # Fake 148 | pred_fake = netD.forward(fake.detach()) 149 | loss_D_fake = self.criterionGAN(pred_fake, False) 150 | # Combined loss 151 | loss_D = (loss_D_real + loss_D_fake) * 0.5 152 | # backward 153 | loss_D.backward() 154 | return loss_D 155 | 156 | def backward_D_A(self): 157 | fake_B = self.fake_B_pool.query(self.fake_B) 158 | self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) 159 | 160 | def backward_D_B(self): 161 | fake_A = self.fake_A_pool.query(self.fake_A) 162 | self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) 163 | 164 | def backward_G(self): 165 | lambda_idt = self.opt.identity 166 | lambda_A = self.opt.lambda_A 167 | lambda_B = self.opt.lambda_B 168 | # Identity loss 169 | if lambda_idt > 0: 170 | # G_A should be identity if real_B is fed. 171 | self.idt_A = self.netG_A.forward(self.real_B) 172 | self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt 173 | # G_B should be identity if real_A is fed. 174 | self.idt_B = self.netG_B.forward(self.real_A) 175 | self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt 176 | else: 177 | self.loss_idt_A = 0 178 | self.loss_idt_B = 0 179 | 180 | # GAN loss 181 | # D_A(G_A(A)) 182 | fake_B = self.netG_A.forward(self.real_A) 183 | self.attn_real_A = self.netA_A.forward(self.real_A) 184 | self.fake_B = self.mask_layer(fake_B, self.real_A, self.attn_real_A) 185 | pred_fake = self.netD_A.forward(self.fake_B) 186 | self.loss_G_A = self.criterionGAN(pred_fake, True) 187 | # D_B(G_B(B)) 188 | fake_A = self.netG_B.forward(self.real_B) 189 | self.attn_real_B = self.netA_B.forward(self.real_B) 190 | self.fake_A = self.mask_layer(fake_A, self.real_B, self.attn_real_B) 191 | pred_fake = self.netD_B.forward(self.fake_A) 192 | self.loss_G_B = self.criterionGAN(pred_fake, True) 193 | # Forward cycle loss 194 | rec_A = self.netG_B.forward(self.fake_B) 195 | self.attn_fake_B = self.netA_B.forward(self.fake_B) 196 | self.rec_A = self.mask_layer(rec_A, self.fake_B, self.attn_fake_B) 197 | self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A 198 | # Backward cycle loss 199 | rec_B = self.netG_A.forward(self.fake_A) 200 | self.attn_fake_A = self.netA_A.forward(self.fake_A) 201 | self.rec_B = self.mask_layer(rec_B, self.fake_A, self.attn_fake_A) 202 | self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B 203 | # attn constrain 204 | self.loss_attnsparse_A = self.criterionIdt(self.attn_real_A, self.zeros_attn) * self.opt.loss_attn_A 205 | self.loss_attnsparse_B = self.criterionIdt(self.attn_real_B, self.zeros_attn) * self.opt.loss_attn_B 206 | self.loss_attnconst_A = self.criterionIdt(self.attn_fake_A, self.attn_real_B.detach()) * self.opt.attn_cycle_weight 207 | self.loss_attnconst_B = 
self.criterionIdt(self.attn_fake_B, self.attn_real_A.detach()) * self.opt.attn_cycle_weight 208 | # combined loss 209 | self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B + self.loss_attnsparse_A + self.loss_attnsparse_B + self.loss_attnconst_A +self.loss_attnconst_B 210 | self.loss_G.backward() 211 | 212 | def optimize_parameters(self): 213 | # forward 214 | self.forward() 215 | self.optimizer_G.zero_grad() 216 | self.optimizer_A.zero_grad() 217 | self.backward_G() 218 | self.optimizer_G.step() 219 | self.optimizer_A.step() 220 | # D_A 221 | self.optimizer_D_A.zero_grad() 222 | self.backward_D_A() 223 | self.optimizer_D_A.step() 224 | # D_B 225 | self.optimizer_D_B.zero_grad() 226 | self.backward_D_B() 227 | self.optimizer_D_B.step() 228 | def optimize_parameterD(self): 229 | self.forward() 230 | self.optimizer_G.zero_grad() 231 | self.optimizer_A.zero_grad() 232 | self.backward_G() 233 | # D_A 234 | self.optimizer_D_A.zero_grad() 235 | self.backward_D_A() 236 | self.optimizer_D_A.step() 237 | # D_B 238 | self.optimizer_D_B.zero_grad() 239 | self.backward_D_B() 240 | self.optimizer_D_B.step() 241 | def get_current_errors(self): 242 | D_A = self.loss_D_A.data 243 | G_A = self.loss_G_A.data 244 | Cyc_A = self.loss_cycle_A.data 245 | D_B = self.loss_D_B.data 246 | G_B = self.loss_G_B.data 247 | Cyc_B = self.loss_cycle_B.data 248 | if self.opt.identity > 0.0: 249 | idt_A = self.loss_idt_A.data 250 | idt_B = self.loss_idt_B.data 251 | return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), 252 | ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)]) 253 | else: 254 | return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)]) 255 | 256 | def get_current_visuals(self): 257 | real_A = util.tensor2im(self.real_A.data) 258 | fake_B = util.tensor2im(self.fake_B.data) 259 | rec_A = util.tensor2im(self.rec_A.data) 260 | real_B = util.tensor2im(self.real_B.data) 261 | fake_A = util.tensor2im(self.fake_A.data) 262 | rec_B = util.tensor2im(self.rec_B.data) 263 | 264 | # mask_attn_A = util.mask2im(util.tensor2mask(self.attn_real_A.data)) 265 | # mask_attn_B = util.mask2im(util.tensor2mask(self.attn_real_B.data)) 266 | 267 | attn_real_A = util.mask2heatmap(self.attn_real_A.data) 268 | attn_real_B = util.mask2heatmap(self.attn_real_B.data) 269 | attn_fake_A = util.mask2heatmap(self.attn_fake_A.data) 270 | attn_fake_B = util.mask2heatmap(self.attn_fake_B.data) 271 | attn_real_A = util.overlay(real_A, attn_real_A) 272 | attn_real_B = util.overlay(real_B, attn_real_B) 273 | attn_fake_A = util.overlay(fake_A, attn_fake_A) 274 | attn_fake_B = util.overlay(fake_B, attn_fake_B) 275 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B), 276 | ('rec_A', rec_A), ('attn_real_A:', attn_real_A), 277 | ('attn_fake_B:', attn_fake_B), 278 | ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), 279 | ('attn_real_B:', attn_real_B), ('attn_fake_A:', attn_fake_A)#,('foreground_mask_B', mask_attn_B) 280 | ]) 281 | 282 | def save(self, label): 283 | self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) 284 | self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) 285 | self.save_network(self.netA_A, 'A_A', label, self.gpu_ids) 286 | self.save_network(self.netA_B, 'A_B', label, self.gpu_ids) 287 | self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) 288 | self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) 289 | 
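# A minimal illustrative sketch of the attention-masked composition that
# mask_layer() above implements: the attention net predicts a soft mask in
# [0, 1], the generator output replaces only the attended region, and the
# background is copied unchanged from the input. `net_g` and `net_a` are
# hypothetical stand-ins for netG_A/netA_A with matching output resolutions.
def compose_with_attention(net_g, net_a, x):
    # x: (N, C, H, W) image batch; net_a(x): (N, 1, H, W) soft mask, broadcast over C
    foreground = net_g(x)
    attn = net_a(x)
    return foreground * attn + x * (1 - attn)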
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | import functools
6 | from torch.autograd import Variable
7 | from torch.optim import lr_scheduler
8 | import numpy as np
9 | import torchvision.models as models
10 | import torchvision.transforms as transforms
11 | from pdb import set_trace as st
12 | ###############################################################################
13 | # Functions
14 | ###############################################################################
15 | 
16 | 
17 | def weights_init_normal(m):
18 | classname = m.__class__.__name__
19 | # print(classname)
20 | if classname.find('Conv') != -1:
21 | init.normal(m.weight.data, 0.0, 0.02) # N(0.0, 0.02), matching the upstream CycleGAN initializer
22 | elif classname.find('Linear') != -1:
23 | init.normal(m.weight.data, 0.0, 0.02)
24 | elif classname.find('BatchNorm2d') != -1:
25 | init.normal(m.weight.data, 1.0, 0.02) # BatchNorm weights drawn from N(1.0, 0.02)
26 | init.constant(m.bias.data, 0.0)
27 | 
28 | 
29 | def weights_init_xavier(m):
30 | classname = m.__class__.__name__
31 | # print(classname)
32 | if classname.find('Conv') != -1:
33 | init.xavier_normal(m.weight.data, gain=1)
34 | elif classname.find('Linear') != -1:
35 | init.xavier_normal(m.weight.data, gain=1)
36 | elif classname.find('BatchNorm2d') != -1:
37 | init.normal(m.weight.data, 1.0, 0.02)
38 | init.constant(m.bias.data, 0.0)
39 | 
40 | 
41 | def weights_init_kaiming(m):
42 | classname = m.__class__.__name__
43 | # print(classname)
44 | if classname.find('Conv') != -1:
45 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
46 | elif classname.find('Linear') != -1:
47 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
48 | elif classname.find('BatchNorm2d') != -1:
49 | init.normal(m.weight.data, 1.0, 0.02)
50 | init.constant(m.bias.data, 0.0)
51 | 
52 | 
53 | def weights_init_orthogonal(m):
54 | classname = m.__class__.__name__
55 | # print(classname)
56 | if classname.find('Conv') != -1:
57 | init.orthogonal(m.weight.data, gain=1)
58 | elif classname.find('Linear') != -1:
59 | init.orthogonal(m.weight.data, gain=1)
60 | elif classname.find('BatchNorm2d') != -1:
61 | init.normal(m.weight.data, 1.0, 0.02)
62 | init.constant(m.bias.data, 0.0)
63 | 
64 | 
65 | def init_weights(net, init_type='normal'):
66 | print('initialization method [%s]' % init_type)
67 | if init_type == 'normal':
68 | net.apply(weights_init_normal)
69 | elif init_type == 'xavier':
70 | net.apply(weights_init_xavier)
71 | elif init_type == 'kaiming':
72 | net.apply(weights_init_kaiming)
73 | elif init_type == 'orthogonal':
74 | net.apply(weights_init_orthogonal)
75 | else:
76 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
77 | 
78 | def get_norm_layer(norm_type='instance'):
79 | if norm_type == 'batch':
80 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
81 | elif norm_type == 'instance':
82 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
83 | elif norm_type == 'none':
84 | norm_layer = None
85 | else:
86 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
87 | return norm_layer
88 | 
89 | 
90 | def get_scheduler(optimizer, opt):
91 | if opt.lr_policy == 'lambda':
92 | def lambda_rule(epoch):
93 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
94 | return lr_l
95 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
96 | elif opt.lr_policy == 'step':
97 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
98 | elif opt.lr_policy == 'plateau':
99 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
100 | else:
101 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
102 | return scheduler
103 | 
104 | 
105 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
106 | netG = None
107 | use_gpu = len(gpu_ids) > 0
108 | norm_layer = get_norm_layer(norm_type=norm)
109 | 
110 | if use_gpu:
111 | assert(torch.cuda.is_available())
112 | 
113 | if which_model_netG == 'resnet_9blocks':
114 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
115 | elif which_model_netG == 'resnet_6blocks':
116 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
117 | elif which_model_netG == 'encoder':
118 | netG = EncoderGenerator(input_nc, ngf=64, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
119 | elif which_model_netG == 'decoder':
120 | netG = DecoderGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
121 | elif which_model_netG == 'gated_9blocks':
122 | netG = GatednetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
123 | elif which_model_netG == 'inserted_9blocks':
124 | netG = MaskDecoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
125 | elif which_model_netG == 'attn_in_9blocks':
126 | netG = AttnInNetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
127 | elif which_model_netG == 'attn_vgg_9blocks':
128 | netG = AttnVGGNetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids) # note: no isTrain flag is available in this scope
129 | elif which_model_netG == 'attn_gated_9blocks':
130 | netG = AttnGatedGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
131 | elif which_model_netG == 'unet_128':
132 | netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
133 | elif which_model_netG == 'unet_256':
134 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
135 | else:
136 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
137 | if len(gpu_ids) > 0:
138 | # netG.cuda(device_id=gpu_ids[0])
139 | netG.cuda(gpu_ids[0])
140 | init_weights(netG, init_type=init_type)
141 | return netG
142 | 
143 | 
144 | def define_A(input_nc, output_nc, ngf, which_model_netA, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
145 | netA = None
146 | use_gpu = len(gpu_ids) > 0
147 | norm_layer = get_norm_layer(norm_type=norm)
148 | 
149 | if use_gpu:
150 | assert(torch.cuda.is_available())
151 | 
152 | if which_model_netA == 'resnet_9blocks':
153 | netA = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
154 | elif which_model_netA == 'resnet_6blocks':
155 | netA = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
156 | else:
157 | raise NotImplementedError('Attention model name [%s] is not recognized' % which_model_netA)
158 | netA = nn.Sequential(netA, Norm()) # Norm (defined below) clamps the output into [0, 1] to form a soft mask
159 | if len(gpu_ids) > 0:
160 | # netA.cuda(device_id=gpu_ids[0])
161 | netA.cuda(gpu_ids[0])
162 | init_weights(netA, init_type=init_type)
163 | return netA
164 | 
165 | 
166 | def define_D(input_nc, ndf, which_model_netD,
167 | n_layers_D, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[], Norm=False):
168 | netD = None
169 | use_gpu = len(gpu_ids) > 0
170 | norm_layer = get_norm_layer(norm_type=norm)
171 | 
172 | if use_gpu:
173 | assert(torch.cuda.is_available())
174 | if which_model_netD == 'basic':
175 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, use_Norm=Norm)
176 | elif which_model_netD == 'n_layers':
177 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, use_Norm=Norm)
178 | elif which_model_netD == 'attn_implicit':
179 | netD = AttnImDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
180 | elif which_model_netD == 'attn_groundtruth':
181 | netD = AttnGtDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
182 | elif which_model_netD == 'attn_infeat':
183 | netD = AttnfeatDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
184 | elif which_model_netD == 'multiscale_layers':
185 | netD = MultiscaleNLayerDiscriminator(input_nc, ndf, n_layers=[5], norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
186 | elif which_model_netD == 'multiscale':
187 | netD = MultiscaleDiscriminator(input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
188 | else:
189 | raise NotImplementedError('Discriminator model name [%s] is not recognized' %
190 | which_model_netD)
191 | if use_gpu:
192 | # netD.cuda(device_id=gpu_ids[0])
193 | netD.cuda(gpu_ids[0])
194 | init_weights(netD, init_type=init_type)
195 | return netD
196 | 
197 | 
198 | def print_network(net):
199 | num_params = 0
200 | for param in net.parameters():
201 | num_params += param.numel()
202 | print(net)
203 | print('Total number of parameters: %d' % num_params)
204 | 
205 | def Attn1_G(encoder, transformer, decoder):
206 | netG = AttnNet1(encoder, transformer, decoder)
207 | return netG
208 | def Attn2_G(encoder, decoder):
209 | netG = AttnNet2(encoder, decoder)
210 | return netG
211 | def Onesided_G(encoder, transformer, decoder):
212 | netG = AutoNet(encoder, transformer, decoder)
213 | return netG
214 | ##############################################################################
215 | # Classes
216 | ##############################################################################
217 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
218 | # When LSGAN is used, it is basically the same as MSELoss,
219 | # but it abstracts away the need to create the target label tensor
220 | # that has the same size as the input
221 | class GANLoss(nn.Module):
222 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
223 | tensor=torch.FloatTensor, loss_weight=None):
224 | super(GANLoss, self).__init__()
225 | self.real_label = target_real_label
226 | self.fake_label = target_fake_label
227 | self.real_label_var = None
228 | self.fake_label_var = None
229 | self.Tensor = tensor; self.loss_weight = loss_weight # per-scale weights for multiscale discriminators
230 | if use_lsgan:
231 | self.loss = nn.MSELoss()
232 | # if use_hingeloss:
233 | # self.loss=nn.MSELoss()
234 | # if use_wgan:
235 | # self.loss = nn.MSELoss()
236 | else:
237 | self.loss = nn.BCELoss()
238 | 
239 | def get_target_tensor(self, input, target_is_real):
240 | target_tensor = None
241 | if target_is_real:
242 | create_label = ((self.real_label_var is None) or
243 | (self.real_label_var.numel() != input.numel()))
244 | if create_label:
245 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
246 | self.real_label_var = Variable(real_tensor, requires_grad=False)
247 | target_tensor = self.real_label_var
248 | else:
249 | create_label = ((self.fake_label_var is None) or
250 | (self.fake_label_var.numel() != input.numel()))
251 | if create_label:
252 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
253 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
254 | target_tensor = self.fake_label_var
255 | return target_tensor
256 | 
257 | def __call__(self, input, target_is_real, loss_weight=None):
258 | if loss_weight is None:
259 | loss_weight = self.loss_weight
260 | if type(input) == list:
261 | if loss_weight is None:
262 | loss_weight = [1. / len(input)] * len(input)
263 | target_tensor = self.get_target_tensor(input[0], target_is_real)
264 | loss = self.loss(input[0], target_tensor)
265 | if len(input) > 1:
266 | loss = loss * loss_weight[0]
267 | for i in range(1, len(input)):
268 | target_tensor = self.get_target_tensor(input[i], target_is_real)
269 | loss += self.loss(input[i], target_tensor) * loss_weight[i]
270 | return loss
271 | else:
272 | target_tensor = self.get_target_tensor(input, target_is_real)
273 | return self.loss(input, target_tensor)
274 | def attnLoss(self, input, target_tensor, attn_map):
275 | target_tensor = self.get_target_tensor(input, target_tensor) * attn_map.detach()
276 | return self.loss(input * attn_map, target_tensor)
277 | 
278 | class vgg_normalize:
279 | def __init__(self, bz):
280 | self.size = (224, 224)
281 | self.mean_data = 0.5
282 | self.std_data = 0.5
283 | self.mean_vgg = Variable(torch.Tensor(bz, 3, 1, 1)).cuda()
284 | self.mean_vgg[:,0] = 0.485; self.mean_vgg[:,1] = 0.456; self.mean_vgg[:,2] = 0.406
285 | self.std_vgg = Variable(torch.Tensor(bz, 3, 1, 1)).cuda()
286 | self.std_vgg[:,0] = 0.229; self.std_vgg[:,1] = 0.224; self.std_vgg[:,2] = 0.225
287 | def __call__(self, input):
288 | input = F.interpolate(input, size=self.size, mode='bilinear')
289 | vgg_input = input.mul(self.std_data).add(self.mean_data).sub(self.mean_vgg[:input.size(0)]).div(self.std_vgg[:input.size(0)])
290 | return vgg_input
291 | 
292 | class Maskloss(nn.Module):
293 | def __init__(self):
294 | super(Maskloss, self).__init__()
295 | def forward(self, input, target, mask, size_average=True):
296 | loss = (1 - mask) * torch.abs(input - target)
297 | if size_average:
298 | return torch.mean(loss)
299 | else:
300 | return torch.sum(loss)
301 | 
302 | # Defines the generator that consists of Resnet blocks between a few
303 | # downsampling/upsampling operations.
304 | # Code and idea originally from Justin Johnson's architecture.
305 | # https://github.com/jcjohnson/fast-neural-style/
306 | class Norm(nn.Module):
307 | def __init__(self):
308 | super(Norm, self).__init__()
309 | def forward(self, input):
310 | output = torch.clamp(torch.abs(input), 0, 1)
311 | return output
312 | class Normalization(nn.Module):
313 | """Divides each map by its spatial sum so the entries sum to 1."""
314 | def __init__(self):
315 | super(Normalization, self).__init__()
316 | def forward(self, input):
317 | output = torch.div(input, torch.sum(torch.sum(input, 2, keepdim=True), 3, keepdim=True))
318 | return output
319 | 
320 | class ResnetGenerator(nn.Module):
321 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
322 | assert(n_blocks >= 0)
323 | super(ResnetGenerator, self).__init__()
324 | self.input_nc = input_nc
325 | self.output_nc = output_nc
326 | self.ngf = ngf
327 | self.gpu_ids = gpu_ids
328 | if type(norm_layer) == functools.partial:
329 | use_bias = norm_layer.func == nn.InstanceNorm2d
330 | else:
331 | use_bias = norm_layer == nn.InstanceNorm2d
332 | 
333 | model = [nn.ReflectionPad2d(3),
334 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
335 | bias=use_bias),
336 | norm_layer(ngf),
337 | nn.ReLU(True)]
338 | 
339 | n_downsampling = 2
340 | for i in range(n_downsampling):
341 | mult = 2**i
342 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
343 | stride=2, padding=1, bias=use_bias),
344 | norm_layer(ngf * mult * 2),
345 | nn.ReLU(True)]
346 | 
347 | mult = 2**n_downsampling
348 | for i in range(n_blocks):
349 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
350 | 
351 | for i in range(n_downsampling):
352 | mult = 2**(n_downsampling - i)
353 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
354 | kernel_size=3, stride=2,
355 | padding=1, output_padding=1,
356 | bias=use_bias),
357 | norm_layer(int(ngf * mult / 2)),
358 | nn.ReLU(True)]
359 | model += [nn.ReflectionPad2d(3)]
360 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
361 | self.model = nn.Sequential(*model)
362 | def forward(self, input):
363 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
364 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
365 | else:
366 | return self.model(input)
367 | 
368 | 
369 | 
370 | # Define a resnet block
371 | class ResnetBlock(nn.Module):
372 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
373 | super(ResnetBlock, self).__init__()
374 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
375 | 
376 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
377 | conv_block = []
378 | p = 0
379 | if padding_type == 'reflect':
380 | conv_block += [nn.ReflectionPad2d(1)]
381 | elif padding_type == 'replicate':
382 | conv_block += [nn.ReplicationPad2d(1)]
383 | elif padding_type == 'zero':
384 | p = 1
385 | else:
386 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
387 | 
388 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
389 | norm_layer(dim),
390 | nn.ReLU(True)]
391 | if use_dropout:
392 | conv_block += [nn.Dropout(0.5)]
393 | 
394 | p = 0
395 | if padding_type == 'reflect':
396 | conv_block += [nn.ReflectionPad2d(1)]
397 | elif padding_type == 'replicate':
398 | conv_block
+= [nn.ReplicationPad2d(1)] 399 | elif padding_type == 'zero': 400 | p = 1 401 | else: 402 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 403 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 404 | norm_layer(dim)] 405 | 406 | return nn.Sequential(*conv_block) 407 | 408 | def forward(self, x): 409 | out = x + self.conv_block(x) 410 | return out 411 | 412 | 413 | # Defines the Unet generator. 414 | # |num_downs|: number of downsamplings in UNet. For example, 415 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 416 | # at the bottleneck 417 | class UnetGenerator(nn.Module): 418 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 419 | norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]): 420 | super(UnetGenerator, self).__init__() 421 | self.gpu_ids = gpu_ids 422 | 423 | # construct unet structure 424 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 425 | for i in range(num_downs - 5): 426 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 427 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 428 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 429 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 430 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 431 | 432 | self.model = unet_block 433 | 434 | def forward(self, input): 435 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): 436 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 437 | else: 438 | return self.model(input) 439 | 440 | 441 | # Defines the submodule with skip connection. 
442 | # X -------------------identity---------------------- X
443 | # |-- downsampling -- |submodule| -- upsampling --|
444 | class UnetSkipConnectionBlock(nn.Module):
445 | def __init__(self, outer_nc, inner_nc, input_nc=None,
446 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
447 | super(UnetSkipConnectionBlock, self).__init__()
448 | self.outermost = outermost
449 | if type(norm_layer) == functools.partial:
450 | use_bias = norm_layer.func == nn.InstanceNorm2d
451 | else: use_bias = norm_layer == nn.InstanceNorm2d
452 | if input_nc is None:
453 | input_nc = outer_nc
454 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
455 | stride=2, padding=1, bias=use_bias)
456 | downrelu = nn.LeakyReLU(0.2, True)
457 | downnorm = norm_layer(inner_nc)
458 | uprelu = nn.ReLU(True)
459 | upnorm = norm_layer(outer_nc)
460 | 
461 | if outermost:
462 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
463 | kernel_size=4, stride=2,
464 | padding=1)
465 | down = [downconv]
466 | up = [uprelu, upconv, nn.Tanh()]
467 | model = down + [submodule] + up
468 | elif innermost:
469 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
470 | kernel_size=4, stride=2,
471 | padding=1, bias=use_bias)
472 | down = [downrelu, downconv]
473 | up = [uprelu, upconv, upnorm]
474 | model = down + up
475 | else:
476 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
477 | kernel_size=4, stride=2,
478 | padding=1, bias=use_bias)
479 | down = [downrelu, downconv, downnorm]
480 | up = [uprelu, upconv, upnorm]
481 | 
482 | if use_dropout:
483 | model = down + [submodule] + up + [nn.Dropout(0.5)]
484 | else:
485 | model = down + [submodule] + up
486 | 
487 | self.model = nn.Sequential(*model)
488 | 
489 | def forward(self, x):
490 | if self.outermost:
491 | return self.model(x)
492 | else:
493 | return torch.cat([x, self.model(x)], 1)
494 | 
495 | 
496 | # Defines the PatchGAN discriminator with the specified arguments.
497 | class NLayerDiscriminator(nn.Module):
498 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[], use_Norm=False):
499 | super(NLayerDiscriminator, self).__init__()
500 | self.gpu_ids = gpu_ids
501 | if type(norm_layer) == functools.partial:
502 | use_bias = norm_layer.func == nn.InstanceNorm2d
503 | else:
504 | use_bias = norm_layer == nn.InstanceNorm2d
505 | 
506 | kw = 4
507 | padw = 1
508 | sequence = [
509 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
510 | nn.LeakyReLU(0.2, True)
511 | ] # 3*256*256 -> 64*128*128
512 | 
513 | nf_mult = 1
514 | nf_mult_prev = 1
515 | for n in range(1, n_layers): # n_layers=3, [1,2]
516 | nf_mult_prev = nf_mult
517 | nf_mult = min(2**n, 8)
518 | sequence += [
519 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
520 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
521 | norm_layer(ndf * nf_mult),
522 | nn.LeakyReLU(0.2, True)
523 | ] # 64*128*128 -> 128*64*64 -> 256*32*32
524 | 
525 | nf_mult_prev = nf_mult
526 | nf_mult = min(2**n_layers, 8)
527 | sequence += [
528 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
529 | kernel_size=kw, stride=1, padding=padw, bias=use_bias),
530 | norm_layer(ndf * nf_mult),
531 | nn.LeakyReLU(0.2, True)
532 | ] # 256*32*32 -> 512*31*31 (stride-1, k=4, p=1 shrinks each side by 1)
533 | 
534 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # 512*31*31 -> 1*30*30
535 | 
536 | if use_sigmoid:
537 | sequence += [nn.Sigmoid()]
538 | 
539 | if use_Norm:
540 | sequence += [Norm(), Normalization()]
541 | self.model = nn.Sequential(*sequence)
542 | 
543 | def forward(self, input):
544 | if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
545 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
546 | else:
547 | return self.model(input)
548 | def attn(self, input):
549 | feat_net = nn.Sequential(*list(self.model.children())[:-1])
550 | attn = feat_net.forward(input)
551 | attn = attn[:,:,2:29,2:29]
552 | upsampler = nn.UpsamplingBilinear2d(size=256)
553 | attn = upsampler(torch.sum(torch.abs(attn), 1, keepdim=True))
554 | attn = attn / torch.max(attn)
555 | return attn
556 | def pred(self, input):
557 | pred = self.model(input)
558 | pred[pred>1] = 1
559 | pred[pred<0] = 0
560 | upsampler = nn.UpsamplingBilinear2d(size=256)
561 | pred = upsampler(pred)
562 | return pred
563 | 
--------------------------------------------------------------------------------
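A rough usage sketch of the building blocks defined in models/networks.py above, assuming the PyTorch 0.4-era API this repository targets and that the module is importable as models.networks: with the defaults (ndf=64, n_layers=3), the 'basic' PatchGAN maps a 3x256x256 input through three stride-2 and two stride-1 convolutions (k=4, p=1) to a 1x30x30 grid of patch scores, and GANLoss compares that grid against a broadcast real/fake label tensor.

import torch
from models.networks import NLayerDiscriminator, GANLoss

netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)  # the 'basic' PatchGAN
criterion = GANLoss(use_lsgan=True, tensor=torch.FloatTensor)

x = torch.randn(2, 3, 256, 256)      # two 256x256 RGB images
pred = netD(x)                       # patch scores: shape (2, 1, 30, 30)
loss_real = criterion(pred, True)    # MSE against an all-ones target grid
loss_fake = criterion(pred, False)   # MSE against an all-zeros target grid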