├── models
│   ├── __init__.py
│   ├── p2cGen.py
│   ├── networks.py
│   ├── c2pGen.py
│   ├── c2pDis.py
│   └── basic_layer.py
├── reference.png
├── README.md
├── server.py
└── pixelization.py

/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/reference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arenasys/pixelization_inference/HEAD/reference.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Inference script for [Pixelization](https://github.com/WuZongWei6/Pixelization)
2 | All credit to those guys, I just stripped down their code to make it simple to use.
3 | 
4 | ## Usage
5 | ```
6 | git clone https://github.com/arenatemp/pixelization_inference
7 | pip install pillow torch torchvision numpy
8 | ```
9 | Download the pretrained models into the pixelization_inference folder:
10 | [pixelart_vgg19.pth](https://drive.google.com/file/d/1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM/view?usp=sharing)
11 | [alias_net.pth](https://drive.google.com/file/d/17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_/view?usp=sharing)
12 | [160_net_G_A.pth](https://drive.google.com/file/d/1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az/view?usp=sharing)
13 | ```
14 | python pixelization.py --input input_file.png
15 | ```
16 | `--input` may also be a directory of images; `--output` sets the destination and `--cpu` runs inference without a GPU.
--------------------------------------------------------------------------------
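The same pipeline can also be driven from Python instead of the CLI; a minimal sketch, assuming the three `.pth` files and `reference.png` sit in the working directory (file names here are examples):

```python
# Programmatic use of pixelization.py (sketch).
from pixelization import Model

m = Model(device="cuda")  # or "cpu"
m.load()                  # loads 160_net_G_A.pth, alias_net.pth, pixelart_vgg19.pth, reference.png
m.pixelize("input_file.png", "input_file_pixelized.png")
```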
/server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import http.server
3 | import uploadserver
4 | import shutil
5 | 
6 | working_dir = 'working'
7 | inputs_dir = os.path.join(working_dir, 'inputs')
8 | outputs_dir = os.path.join(working_dir, 'outputs')
9 | 
10 | os.makedirs(working_dir, exist_ok=True)
11 | os.makedirs(inputs_dir, exist_ok=True)
12 | os.makedirs(outputs_dir, exist_ok=True)
13 | 
14 | def handle_uploads():
15 |     import subprocess
16 |     result = subprocess.run(['python', 'pixelization.py', '--input', working_dir, '--output', outputs_dir])
17 |     if result.returncode != 0:
18 |         print('Error processing images')
19 |         return
20 |     else:
21 |         print('Images processed')
22 | 
23 |     # debug
24 |     # for file in os.listdir(working_dir):
25 |     #     old_file = os.path.join(working_dir, file)
26 |     #     new_file = os.path.join(outputs_dir, file)
27 |     #     shutil.copy(old_file, new_file)
28 | 
29 |     # move the processed originals out of working/ into working/inputs
30 |     for file in os.listdir(working_dir):
31 |         old_file = os.path.join(working_dir, file)
32 |         if not os.path.isfile(old_file):
33 |             continue
34 |         new_file = os.path.join(inputs_dir, file)
35 |         os.rename(old_file, new_file)
36 | 
37 | class Args:
38 |     port = 8000
39 |     cgi = False
40 |     allow_replace = False
41 |     bind = None
42 |     directory = working_dir
43 |     theme = 'dark'
44 |     server_certificate = None
45 |     client_certificate = None
46 |     basic_auth = None
47 |     basic_auth_upload = None
48 | uploadserver.args = Args()
49 | 
50 | old_receive_upload = uploadserver.receive_upload
51 | def new_receive_upload(handler: http.server.BaseHTTPRequestHandler):
52 |     status, message = old_receive_upload(handler)
53 | 
54 |     if status != http.HTTPStatus.BAD_REQUEST:
55 |         handle_uploads()
56 | 
57 |     return status, message
58 | uploadserver.receive_upload = new_receive_upload
59 | 
60 | uploadserver.serve_forever()
--------------------------------------------------------------------------------
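server.py wraps the `uploadserver` package: each successful upload into `working/` triggers `handle_uploads`, which runs pixelization.py over the directory and moves the originals into `working/inputs`. A quick smoke test from Python (a sketch; it assumes the `requests` package is installed and uploadserver's default `/upload` endpoint with a `files` form field):

```python
# Upload an image to the running server; results should land in working/outputs.
import requests

with open("photo.png", "rb") as f:  # photo.png is an example path
    r = requests.post("http://127.0.0.1:8000/upload", files={"files": f})
print(r.status_code)
```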
/models/p2cGen.py:
--------------------------------------------------------------------------------
1 | from .basic_layer import *
2 | 
3 | 
4 | class P2CGen(nn.Module):
5 |     def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
6 |         super(P2CGen, self).__init__()
7 |         self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
8 |         self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
9 |                                  activ=activ, pad_type=pad_type)
10 | 
11 |     def forward(self, x):
12 |         x = self.RGBEnc(x)
13 |         # print("encoder->>", x.shape)
14 |         x = self.RGBDec(x)
15 |         # print(x_small.shape)
16 |         # print(x_middle.shape)
17 |         # print(x_big.shape)
18 |         #return y_small, y_middle, y_big
19 |         return x
20 | 
21 | 
22 | class RGBEncoder(nn.Module):
23 |     def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
24 |         super(RGBEncoder, self).__init__()
25 |         self.model = []
26 |         self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
27 |         # downsampling blocks
28 |         for i in range(n_downsample):
29 |             self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
30 |             dim *= 2
31 |         # residual blocks
32 |         self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
33 |         self.model = nn.Sequential(*self.model)
34 |         self.output_dim = dim
35 | 
36 |     def forward(self, x):
37 |         return self.model(x)
38 | 
39 | 
40 | class RGBDecoder(nn.Module):
41 |     def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
42 |         super(RGBDecoder, self).__init__()
43 |         # self.model = []
44 |         # # AdaIN residual blocks
45 |         # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
46 |         # # upsampling blocks
47 |         # for i in range(n_upsample):
48 |         #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
49 |         #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
50 |         #     dim //= 2
51 |         # # use reflection padding in the last conv layer
52 |         # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
53 |         # self.model = nn.Sequential(*self.model)
54 |         self.Res_Blocks = ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
55 |         self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
56 |         self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
57 |         dim //= 2
58 |         self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
59 |         self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
60 |         dim //= 2
61 |         self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
62 | 
63 |     def forward(self, x):
64 |         x = self.Res_Blocks(x)
65 |         # print(x.shape)
66 |         x = self.upsample_block1(x)
67 |         # print(x.shape)
68 |         x = self.conv_1(x)
69 |         # print(x_small.shape)
70 |         x = self.upsample_block2(x)
71 |         # print(x.shape)
72 |         x = self.conv_2(x)
73 |         # print(x_middle.shape)
74 |         x = self.conv_3(x)
75 |         # print(x_big.shape)
76 |         return x
77 | 
--------------------------------------------------------------------------------
/pixelization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import warnings
3 | warnings.simplefilter(action='ignore', category=FutureWarning)
4 | 
5 | import torch
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 | import numpy as np
9 | from models.networks import define_G
10 | import glob
11 | 
12 | class Model():
13 |     def __init__(self, device="cpu"):
14 |         self.device = torch.device(device)
15 |         self.G_A_net = None
16 |         self.alias_net = None
17 |         self.ref_t = None
18 | 
19 |     def load(self):
20 |         with torch.no_grad():
21 |             # define_G wraps the nets in DataParallel when gpu_ids is non-empty,
22 |             # which is why the checkpoint keys get a "module." prefix below
23 |             gpu_ids = [] if self.device.type == "cpu" else [0]
24 |             self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, gpu_ids)
25 |             self.alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, gpu_ids)
26 | 
27 |             G_A_state = torch.load("160_net_G_A.pth", map_location=str(self.device))
28 |             if gpu_ids:
29 |                 for p in list(G_A_state.keys()):
30 |                     G_A_state["module."+str(p)] = G_A_state.pop(p)
31 |             self.G_A_net.load_state_dict(G_A_state)
32 | 
33 |             alias_state = torch.load("alias_net.pth", map_location=str(self.device))
34 |             if gpu_ids:
35 |                 for p in list(alias_state.keys()):
36 |                     alias_state["module."+str(p)] = alias_state.pop(p)
37 |             self.alias_net.load_state_dict(alias_state)
38 | 
39 |             ref_img = Image.open("reference.png").convert('L')
40 |             self.ref_t = process(greyscale(ref_img)).to(self.device)
41 | 
42 |     def pixelize(self, in_img, out_img):
43 |         with torch.no_grad():
44 |             in_img = Image.open(in_img).convert('RGB')
45 |             in_t = process(in_img).to(self.device)
46 | 
47 |             out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))
48 | 
49 |             save(out_t, out_img)
50 | 
51 | def greyscale(img):
52 |     gray = np.array(img.convert('L'))
53 |     tmp = np.expand_dims(gray, axis=2)
54 |     tmp = np.concatenate((tmp, tmp, tmp), axis=-1)
55 |     return Image.fromarray(tmp)
56 | 
57 | def process(img):
58 |     ow, oh = img.size
59 | 
60 |     nw = int(round(ow / 4) * 4)
61 |     nh = int(round(oh / 4) * 4)
62 | 
63 |     left = (ow - nw)//2
64 |     top = (oh - nh)//2
65 |     right = left + nw
66 |     bottom = top + nh
67 | 
68 |     img = img.crop((left, top, right, bottom))
69 | 
70 |     trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
71 | 
72 |     return trans(img)[None, :, :, :]
73 | 
74 | def save(tensor, file):
75 |     img = tensor.data[0].cpu().float().numpy()
76 |     img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
77 |     img = img.astype(np.uint8)
78 |     img = Image.fromarray(img)
79 |     img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST)
80 |     img = img.resize((img.size[0]*4, img.size[1]*4), resample=Image.Resampling.NEAREST)
81 |     img.save(file)
82 | 
83 | 
84 | def pixelize_cli():
85 |     import argparse
86 |     import os
87 |     parser = argparse.ArgumentParser(description='Pixelization')
88 |     parser.add_argument('--input', type=str, default=None, required=True, help='path to image or directory')
89 |     parser.add_argument('--output', type=str, default=None, required=False, help='path to save image/images')
90 |     parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU')
91 | 
92 |     args = parser.parse_args()
93 |     in_path = args.input
94 |     out_path = args.output
95 |     use_cpu = args.cpu
96 | 
97 |     if not os.path.exists("alias_net.pth"):
98 |         print("missing models")
99 |         return
100 | 
101 |     pairs = []
102 | 
103 |     if os.path.isdir(in_path):
104 |         in_images = glob.glob(in_path + "/*.png") + glob.glob(in_path + "/*.jpg")
105 |         if not out_path:
106 |             out_path = os.path.join(in_path, "outputs")
107 |         if not os.path.exists(out_path):
108 |             os.makedirs(out_path)
109 |         elif os.path.isfile(out_path):
110 |             print("output can't be a file if input is a directory")
111 |             return
112 |         for i in in_images:
113 |             pairs += [(i, i.replace(in_path, out_path))]
114 |     elif os.path.isfile(in_path):
115 |         if not out_path:
116 |             base, ext = os.path.splitext(in_path)
117 |             out_path = base+"_pixelized"+ext
118 |         else:
119 |             if os.path.isdir(out_path):
120 |                 _, file = os.path.split(in_path)
121 |                 out_path = os.path.join(out_path, file)
122 |         pairs = [(in_path, out_path)]
123 | 
124 |     m = Model(device = "cpu" if use_cpu else "cuda")
125 |     m.load()
126 | 
127 |     for in_file, out_file in pairs:
128 |         print("PROCESSING", in_file, "TO", out_file)
129 |         m.pixelize(in_file, out_file)
130 | 
131 | if __name__ == "__main__":
132 |     pixelize_cli()
--------------------------------------------------------------------------------
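`process()` center-crops each image so both sides are multiples of 4 (the generator downsamples twice, so sizes must divide cleanly by 2²), and `save()` snaps the output to a 4-pixel grid with nearest-neighbour resizes. The crop arithmetic in isolation, for a hypothetical 513x300 input:

```python
# Worked example of the dimension alignment done by process().
ow, oh = 513, 300
nw, nh = int(round(ow / 4) * 4), int(round(oh / 4) * 4)  # 512, 300
left, top = (ow - nw) // 2, (oh - nh) // 2               # 0, 0
print((left, top, left + nw, top + nh))                  # crop box: (0, 0, 512, 300)
```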
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | from .c2pGen import *
7 | from .p2cGen import *
8 | from .c2pDis import *
9 | 
10 | class Identity(nn.Module):
11 |     def forward(self, x):
12 |         return x
13 | 
14 | def get_norm_layer(norm_type='instance'):
15 |     """Return a normalization layer
16 | 
17 |     Parameters:
18 |         norm_type (str) -- the name of the normalization layer: batch | instance | none
19 | 
20 |     For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
21 |     For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
22 |     """
23 |     if norm_type == 'batch':
24 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
25 |     elif norm_type == 'instance':
26 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
27 |     elif norm_type == 'none':
28 |         def norm_layer(x): return Identity()
29 |     else:
30 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
31 |     return norm_layer
32 | 
33 | 
34 | def get_scheduler(optimizer, opt):
35 |     """Return a learning rate scheduler
36 | 
37 |     Parameters:
38 |         optimizer          -- the optimizer of the network
39 |         opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
40 |                               opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
41 | 
42 |     For 'linear', we keep the same learning rate for the first epochs
43 |     and linearly decay the rate to zero over the next epochs.
44 |     For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
45 |     See https://pytorch.org/docs/stable/optim.html for more details.
46 |     """
47 |     if opt.lr_policy == 'linear':
48 |         def lambda_rule(epoch):
49 |             lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
50 |             return lr_l
51 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
52 |     elif opt.lr_policy == 'step':
53 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
54 |     elif opt.lr_policy == 'plateau':
55 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
56 |     elif opt.lr_policy == 'cosine':
57 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
58 |     else:
59 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
60 |     return scheduler
61 | 
62 | 
63 | def init_weights(net, init_type='normal', init_gain=0.02):
64 |     """Initialize network weights.
65 | 
66 |     Parameters:
67 |         net (network)     -- network to be initialized
68 |         init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
69 |         init_gain (float) -- scaling factor for normal, xavier and orthogonal.
70 | 
71 |     """
72 |     def init_func(m):  # define the initialization function
73 |         classname = m.__class__.__name__
74 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
75 |             if init_type == 'normal':
76 |                 init.normal_(m.weight.data, 0.0, init_gain)
77 |             elif init_type == 'xavier':
78 |                 init.xavier_normal_(m.weight.data, gain=init_gain)
79 |             elif init_type == 'kaiming':
80 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
81 |             elif init_type == 'orthogonal':
82 |                 init.orthogonal_(m.weight.data, gain=init_gain)
83 |             else:
84 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
85 |             if hasattr(m, 'bias') and m.bias is not None:
86 |                 init.constant_(m.bias.data, 0.0)
87 |         elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
88 |             init.normal_(m.weight.data, 1.0, init_gain)
89 |             init.constant_(m.bias.data, 0.0)
90 | 
91 |     #print('initialize network with %s' % init_type)
92 |     net.apply(init_func)  # apply the initialization function
93 | 
94 | 
95 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
96 |     """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
97 |     Parameters:
98 |         net (network)      -- the network to be initialized
99 |         init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
100 |         gain (float)       -- scaling factor for normal, xavier and orthogonal.
101 |         gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
102 | 
103 |     Return an initialized network.
104 |     """
105 |     if len(gpu_ids) > 0:
106 |         assert(torch.cuda.is_available())
107 |         net.to(gpu_ids[0])
108 |         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
109 |     init_weights(net, init_type, init_gain=init_gain)
110 |     return net
111 | 
112 | 
113 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
114 |     """Create a generator
115 | 
116 |     Parameters:
117 |         input_nc (int)     -- the number of channels in input images
118 |         output_nc (int)    -- the number of channels in output images
119 |         ngf (int)          -- the base number of generator filters
120 |         netG (str)         -- the architecture's name: c2pGen | p2cGen | antialias
121 |         norm (str)         -- the name of normalization layers used in the network: batch | instance | none
122 |         use_dropout (bool) -- if use dropout layers.
123 |         init_type (str)    -- the name of our initialization method.
124 |         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
125 |         gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
126 | 
127 |     Returns a generator
128 |     """
129 |     net = None
130 |     norm_layer = get_norm_layer(norm_type=norm)
131 | 
132 |     if netG == 'c2pGen':  # style_dim mlp_dim
133 |         net = C2PGen(input_nc, output_nc, ngf, 2, 4, 256, 256, activ='relu', pad_type='reflect')
134 |         #print('c2pgen resblock is 8')
135 |     elif netG == 'p2cGen':
136 |         net = P2CGen(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
137 |     elif netG == 'antialias':
138 |         net = AliasNet(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
139 |     else:
140 |         raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
141 |     return init_net(net, init_type, init_gain, gpu_ids)
142 | 
143 | 
144 | 
145 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
146 |     """Create a discriminator
147 | 
148 |     Parameters:
149 |         input_nc (int)     -- the number of channels in input images
150 |         ndf (int)          -- the number of filters in the first conv layer
151 |         netD (str)         -- the architecture's name: CPDis | CPDis_cls
152 |         n_layers_D (int)   -- the number of conv layers in the discriminator; unused by CPDis/CPDis_cls
153 |         norm (str)         -- the type of normalization layers used in the network.
154 |         init_type (str)    -- the name of the initialization method.
155 |         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
156 |         gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
157 | 
158 |     Returns a discriminator
159 |     """
160 |     net = None
161 |     norm_layer = get_norm_layer(norm_type=norm)
162 | 
163 | 
164 |     if netD == 'CPDis':
165 |         net = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
166 |     elif netD == 'CPDis_cls':
167 |         net = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
168 |     else:
169 |         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
170 |     return init_net(net, init_type, init_gain, gpu_ids)
171 | 
172 | 
173 | class GANLoss(nn.Module):
174 |     """Define different GAN objectives.
175 | 
176 |     The GANLoss class abstracts away the need to create the target label tensor
177 |     that has the same size as the input.
178 |     """
179 | 
180 |     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
181 |         """ Initialize the GANLoss class.
182 | 
183 |         Parameters:
184 |             gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
185 |             target_real_label (float) - - label for a real image
186 |             target_fake_label (float) - - label of a fake image
187 | 
188 |         Note: Do not use sigmoid as the last layer of Discriminator.
189 |         LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
190 |         """
191 |         super(GANLoss, self).__init__()
192 |         self.register_buffer('real_label', torch.tensor(target_real_label))
193 |         self.register_buffer('fake_label', torch.tensor(target_fake_label))
194 |         self.gan_mode = gan_mode
195 |         if gan_mode == 'lsgan':
196 |             self.loss = nn.MSELoss()
197 |         elif gan_mode == 'vanilla':
198 |             self.loss = nn.BCEWithLogitsLoss()
199 |         elif gan_mode in ['wgangp']:
200 |             self.loss = None
201 |         else:
202 |             raise NotImplementedError('gan mode %s not implemented' % gan_mode)
203 | 
204 |     def get_target_tensor(self, prediction, target_is_real):
205 |         """Create label tensors with the same size as the input.
206 | 
207 |         Parameters:
208 |             prediction (tensor) - - typically the prediction from a discriminator
209 |             target_is_real (bool) - - if the ground truth label is for real images or fake images
210 | 
211 |         Returns:
212 |             A label tensor filled with ground truth label, and with the size of the input
213 |         """
214 | 
215 |         if target_is_real:
216 |             target_tensor = self.real_label
217 |         else:
218 |             target_tensor = self.fake_label
219 |         return target_tensor.expand_as(prediction)
220 | 
221 |     def __call__(self, prediction, target_is_real):
222 |         """Calculate loss given Discriminator's output and ground truth labels.
223 | 
224 |         Parameters:
225 |             prediction (tensor) - - typically the prediction output from a discriminator
226 |             target_is_real (bool) - - if the ground truth label is for real images or fake images
227 | 
228 |         Returns:
229 |             the calculated loss.
230 |         """
231 |         if self.gan_mode in ['lsgan', 'vanilla']:
232 |             target_tensor = self.get_target_tensor(prediction, target_is_real)
233 |             loss = self.loss(prediction, target_tensor)
234 |         elif self.gan_mode == 'wgangp':
235 |             if target_is_real:
236 |                 loss = -prediction.mean()
237 |             else:
238 |                 loss = prediction.mean()
239 |         return loss
240 | 
241 | 
242 | 
243 | 
244 | 
--------------------------------------------------------------------------------
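Only `define_G` is exercised at inference time; `GANLoss` and `define_D` are leftovers from training. `GANLoss`'s contract is easy to see in isolation; a small sketch using the lsgan mode:

```python
# GANLoss expands a scalar target label to the prediction's shape (sketch).
import torch
from models.networks import GANLoss

criterion = GANLoss('lsgan')      # MSE against all-ones / all-zeros targets
pred = torch.randn(1, 1, 30, 30)  # e.g. a PatchGAN output map
print(criterion(pred, True).item(), criterion(pred, False).item())
```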
/models/c2pGen.py:
--------------------------------------------------------------------------------
1 | from .basic_layer import *
2 | import torchvision.models as models
3 | 
4 | 
5 | 
6 | class AliasNet(nn.Module):
7 |     def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
8 |         super(AliasNet, self).__init__()
9 |         self.RGBEnc = AliasRGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
10 |         self.RGBDec = AliasRGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
11 |                                       activ=activ, pad_type=pad_type)
12 | 
13 |     def forward(self, x):
14 |         x = self.RGBEnc(x)
15 |         x = self.RGBDec(x)
16 |         return x
17 | 
18 | 
19 | class AliasRGBEncoder(nn.Module):
20 |     def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
21 |         super(AliasRGBEncoder, self).__init__()
22 |         self.model = []
23 |         self.model += [AliasConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
24 |         # downsampling blocks
25 |         for i in range(n_downsample):
26 |             self.model += [AliasConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
27 |             dim *= 2
28 |         # residual blocks
29 |         self.model += [AliasResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
30 |         self.model = nn.Sequential(*self.model)
31 |         self.output_dim = dim
32 | 
33 |     def forward(self, x):
34 |         return self.model(x)
35 | 
36 | 
37 | class AliasRGBDecoder(nn.Module):
38 |     def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
39 |         super(AliasRGBDecoder, self).__init__()
40 |         # self.model = []
41 |         # # AdaIN residual blocks
42 |         # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
43 |         # # upsampling blocks
44 |         # for i in range(n_upsample):
45 |         #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
46 |         #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
47 |         #     dim //= 2
48 |         # # use reflection padding in the last conv layer
49 |         # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
50 |         # self.model = nn.Sequential(*self.model)
51 |         self.Res_Blocks = AliasResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
52 |         self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
53 |         self.conv_1 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
54 |         dim //= 2
55 |         self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
56 |         self.conv_2 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
57 |         dim //= 2
58 |         self.conv_3 = AliasConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
59 | 
60 |     def forward(self, x):
61 |         x = self.Res_Blocks(x)
62 |         # print(x.shape)
63 |         x = self.upsample_block1(x)
64 |         # print(x.shape)
65 |         x = self.conv_1(x)
66 |         # print(x_small.shape)
67 |         x = self.upsample_block2(x)
68 |         # print(x.shape)
69 |         x = self.conv_2(x)
70 |         # print(x_middle.shape)
71 |         x = self.conv_3(x)
72 |         # print(x_big.shape)
73 |         return x
74 | 
75 | 
76 | class C2PGen(nn.Module):
77 |     def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, style_dim, mlp_dim, activ='relu', pad_type='reflect'):
78 |         super(C2PGen, self).__init__()
79 |         self.PBEnc = PixelBlockEncoder(input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
80 |         self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
81 |         self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='adain',
82 |                                  activ=activ, pad_type=pad_type)
83 |         self.MLP = MLP(style_dim, 2048, mlp_dim, 3, norm='none', activ=activ)
84 | 
85 |     def forward(self, clipart, pixelart, s=1):
86 |         feature = self.RGBEnc(clipart)
87 |         code = self.PBEnc(pixelart)
88 |         result, cellcode = self.fuse(feature, code, s)
89 |         return result  # , cellcode  # return cellcode when visualizing the cell size code
90 | 
91 |     def fuse(self, content, style_code, s=1):
92 |         #print("MLP input:code's shape:", style_code.shape)
93 |         adain_params = self.MLP(style_code) * s  # [batch,2048]
94 |         #print("MLP output:adain_params's shape", adain_params.shape)
95 |         #self.assign_adain_params(adain_params, self.RGBDec)
96 |         images = self.RGBDec(content, adain_params)
97 |         return images, adain_params
98 | 
99 |     def assign_adain_params(self, adain_params, model):
100 |         # assign the adain_params to the AdaIN layers in model
101 |         for m in model.modules():
102 |             if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
103 |                 mean = adain_params[:, :m.num_features]
104 |                 std = adain_params[:, m.num_features:2 * m.num_features]
105 |                 m.bias = mean.contiguous().view(-1)
106 |                 m.weight = std.contiguous().view(-1)
107 |                 if adain_params.size(1) > 2 * m.num_features:
108 |                     adain_params = adain_params[:, 2 * m.num_features:]
109 | 
110 |     def get_num_adain_params(self, model):
111 |         # return the number of AdaIN parameters needed by the model
112 |         num_adain_params = 0
113 |         for m in model.modules():
114 |             if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
115 |                 num_adain_params += 2 * m.num_features
116 |         return num_adain_params
117 | 
118 | 
119 | class PixelBlockEncoder(nn.Module):
120 |     def __init__(self, input_dim, dim, style_dim, norm, activ, pad_type):
121 |         super(PixelBlockEncoder, self).__init__()
122 |         vgg19 = models.vgg.vgg19()
123 |         vgg19.classifier._modules['6'] = nn.Linear(4096, 7, bias=True)
124 |         vgg19.load_state_dict(torch.load('./pixelart_vgg19.pth'))
125 |         self.vgg = vgg19.features
126 |         for p in self.vgg.parameters():
127 |             p.requires_grad = False
128 |         # vgg19 = models.vgg.vgg19(pretrained=False)
129 |         # vgg19.load_state_dict(torch.load('./vgg.pth'))
130 |         # self.vgg = vgg19.features
131 |         # for p in self.vgg.parameters():
132 |         #     p.requires_grad = False
133 | 
134 | 
135 |         self.conv1 = ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)  # 3->64,concat
136 |         dim = dim * 2
137 |         self.conv2 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 128->128
138 |         dim = dim * 2
139 |         self.conv3 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 256->256
140 |         dim = dim * 2
141 |         self.conv4 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 512->512
142 |         dim = dim * 2
143 | 
144 |         self.model = []
145 |         self.model += [nn.AdaptiveAvgPool2d(1)]  # global average pooling
146 |         self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
147 |         self.model = nn.Sequential(*self.model)
148 |         self.output_dim = dim
149 | 
150 |     def get_features(self, image, model, layers=None):
151 |         if layers is None:
152 |             layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'}
153 |         features = {}
154 |         x = image
155 |         # model._modules is a dictionary holding each module in the model
156 |         for name, layer in model._modules.items():
157 |             x = layer(x)
158 |             if name in layers:
159 |                 features[layers[name]] = x
160 |         return features
161 | 
162 |     def componet_enc(self, x):
163 |         # x [16,3,256,256]
164 |         # factor_img [16,7,256,256]
165 |         vgg_aux = self.get_features(x, self.vgg)  # x is a 3-channel greyscale image
166 |         #x = torch.cat([x, factor_img], dim=1)  # [16,3+7,256,256]
167 |         x = self.conv1(x)  # 64 256 256
168 |         x = torch.cat([x, vgg_aux['conv1_1']], dim=1)  # 128 256 256
169 |         x = self.conv2(x)  # 128 128 128
170 |         x = torch.cat([x, vgg_aux['conv2_1']], dim=1)  # 256 128 128
171 |         x = self.conv3(x)  # 256 64 64
172 |         x = torch.cat([x, vgg_aux['conv3_1']], dim=1)  # 512 64 64
173 |         x = self.conv4(x)  # 512 32 32
174 |         x = torch.cat([x, vgg_aux['conv4_1']], dim=1)  # 1024 32 32
175 |         x = self.model(x)
176 |         return x
177 | 
178 |     def forward(self, x):
179 |         code = self.componet_enc(x)
180 |         return code
181 | 
182 | class RGBEncoder(nn.Module):
183 |     def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
184 |         super(RGBEncoder, self).__init__()
185 |         self.model = []
186 |         self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
187 |         # downsampling blocks
188 |         for i in range(n_downsample):
189 |             self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
190 |             dim *= 2
191 |         # residual blocks
192 |         self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
193 |         self.model = nn.Sequential(*self.model)
194 |         self.output_dim = dim
195 | 
196 |     def forward(self, x):
197 |         return self.model(x)
198 | 
199 | 
200 | class RGBDecoder(nn.Module):
201 |     def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
202 |         super(RGBDecoder, self).__init__()
203 |         # self.model = []
204 |         # # AdaIN residual blocks
205 |         # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
206 |         # # upsampling blocks
207 |         # for i in range(n_upsample):
208 |         #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
209 |         #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
210 |         #     dim //= 2
211 |         # # use reflection padding in the last conv layer
212 |         # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
213 |         # self.model = nn.Sequential(*self.model)
214 |         #self.Res_Blocks = ModulationResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
215 |         self.mod_conv_1 = ModulationConvBlock(256, 256, 3)
216 |         self.mod_conv_2 = ModulationConvBlock(256, 256, 3)
217 |         self.mod_conv_3 = ModulationConvBlock(256, 256, 3)
218 |         self.mod_conv_4 = ModulationConvBlock(256, 256, 3)
219 |         self.mod_conv_5 = ModulationConvBlock(256, 256, 3)
220 |         self.mod_conv_6 = ModulationConvBlock(256, 256, 3)
221 |         self.mod_conv_7 = ModulationConvBlock(256, 256, 3)
222 |         self.mod_conv_8 = ModulationConvBlock(256, 256, 3)
223 |         self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
224 |         self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
225 |         dim //= 2
226 |         self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
227 |         self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
228 |         dim //= 2
229 |         self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
230 | 
231 |     # def forward(self, x):
232 |     #     residual = x
233 |     #     out = self.model(x)
234 |     #     out += residual
235 |     #     return out
236 |     def forward(self, x, code):
237 |         # NB: the upstream code reuses mod_conv_2 for every later modulated conv,
238 |         # so mod_conv_3..mod_conv_8 are defined but never called; kept as-is to
239 |         # stay compatible with the pretrained weights.
240 |         residual = x
241 |         x = self.mod_conv_1(x, code[:, :256])
242 |         x = self.mod_conv_2(x, code[:, 256 * 1:256 * 2])
243 |         x += residual
244 |         residual = x
245 |         x = self.mod_conv_2(x, code[:, 256 * 2:256 * 3])
246 |         x = self.mod_conv_2(x, code[:, 256 * 3:256 * 4])
247 |         x += residual
248 |         residual = x
249 |         x = self.mod_conv_2(x, code[:, 256 * 4:256 * 5])
250 |         x = self.mod_conv_2(x, code[:, 256 * 5:256 * 6])
251 |         x += residual
252 |         residual = x
253 |         x = self.mod_conv_2(x, code[:, 256 * 6:256 * 7])
254 |         x = self.mod_conv_2(x, code[:, 256 * 7:256 * 8])
255 |         x += residual
256 |         # print(x.shape)
257 |         x = self.upsample_block1(x)
258 |         # print(x.shape)
259 |         x = self.conv_1(x)
260 |         # print(x_small.shape)
261 |         x = self.upsample_block2(x)
262 |         # print(x.shape)
263 |         x = self.conv_2(x)
264 |         # print(x_middle.shape)
265 |         x = self.conv_3(x)
266 |         # print(x_big.shape)
267 |         return x
268 | 
269 | 
--------------------------------------------------------------------------------
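C2PGen's `MLP` turns the 256-d style code from `PixelBlockEncoder` into 2048 modulation values, which `RGBDecoder.forward` consumes as eight 256-channel slices, one per modulated conv; schematically:

```python
# How the 2048-d code is sliced across the modulated convs (sketch).
import torch

code = torch.randn(1, 2048)  # stand-in for MLP(style_code)
chunks = [code[:, i * 256:(i + 1) * 256] for i in range(8)]
print([tuple(c.shape) for c in chunks])  # eight (1, 256) slices
```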
/models/c2pDis.py:
--------------------------------------------------------------------------------
1 | from .basic_layer import *
2 | import math
3 | from torch.nn import Parameter
4 | #from pytorch_metric_learning import losses
5 | 
6 | '''
7 | Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch and https://github.com/wujiyang/Face_Pytorch.
8 | '''
9 | def cosine_sim(x1, x2, dim=1, eps=1e-8):
10 |     ip = torch.mm(x1, x2.t())  # w 7*512
11 |     w1 = torch.norm(x1, 2, dim)
12 |     w2 = torch.norm(x2, 2, dim)
13 |     return ip / torch.ger(w1, w2).clamp(min=eps)
14 | 
15 | class MarginCosineProduct(nn.Module):
16 |     r"""Implementation of large margin cosine distance:
17 |     Args:
18 |         in_features: size of each input sample
19 |         out_features: size of each output sample
20 |         s: norm of input feature
21 |         m: margin
22 |     """
23 | 
24 |     def __init__(self, in_features, out_features, s=30.0, m=0.40):
25 |         super(MarginCosineProduct, self).__init__()
26 |         self.in_features = in_features
27 |         self.out_features = out_features
28 |         self.s = s
29 |         self.m = m
30 |         self.weight = Parameter(torch.Tensor(out_features, in_features))  # 7 512
31 |         nn.init.xavier_uniform_(self.weight)
32 |         #stdv = 1. / math.sqrt(self.weight.size(1))
33 |         #self.weight.data.uniform_(-stdv, stdv)
34 | 
35 |     def forward(self, input, label):
36 |         cosine = cosine_sim(input, self.weight)  # 1*512 7*512
37 |         # cosine = F.linear(F.normalize(input), F.normalize(self.weight))
38 |         # --------------------------- convert label to one-hot ---------------------------
39 |         # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
40 |         one_hot = torch.zeros_like(cosine)
41 |         one_hot.scatter_(1, label.view(-1, 1), 1.0)
42 |         # ------------- torch.where(out_i = x_i if condition_i else y_i) -------------
43 |         output = self.s * (cosine - one_hot * self.m)
44 | 
45 |         return output
46 | 
47 |     def __repr__(self):
48 |         return self.__class__.__name__ + '(' \
49 |                + 'in_features=' + str(self.in_features) \
50 |                + ', out_features=' + str(self.out_features) \
51 |                + ', s=' + str(self.s) \
52 |                + ', m=' + str(self.m) + ')'
53 | 
54 | class ArcMarginProduct(nn.Module):
55 |     def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
56 |         super(ArcMarginProduct, self).__init__()
57 |         self.in_feature = in_feature
58 |         self.out_feature = out_feature
59 |         self.s = s
60 |         self.m = m
61 |         self.weight = Parameter(torch.Tensor(out_feature, in_feature))
62 |         nn.init.xavier_uniform_(self.weight)
63 | 
64 |         self.easy_margin = easy_margin
65 |         self.cos_m = math.cos(m)
66 |         self.sin_m = math.sin(m)
67 | 
68 |         # make cos(theta + m) monotonically decreasing for theta in [0°, 180°]
69 |         self.th = math.cos(math.pi - m)
70 |         self.mm = math.sin(math.pi - m) * m
71 | 
72 |     def forward(self, x, label):
73 |         # cos(theta)
74 |         cosine = F.linear(F.normalize(x), F.normalize(self.weight))
75 |         # cos(theta + m)
76 |         sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
77 |         phi = cosine * self.cos_m - sine * self.sin_m
78 | 
79 |         if self.easy_margin:
80 |             phi = torch.where(cosine > 0, phi, cosine)
81 |         else:
82 |             phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
83 | 
84 |         #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
85 |         one_hot = torch.zeros_like(cosine)
86 |         one_hot.scatter_(1, label.view(-1, 1), 1)
87 |         output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
88 |         output = output * self.s
89 | 
90 |         return output
91 | 
92 | 
93 | class MultiMarginProduct(nn.Module):
94 |     def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
95 |         super(MultiMarginProduct, self).__init__()
96 |         self.in_feature = in_feature
97 |         self.out_feature = out_feature
98 |         self.s = s
99 |         self.m1 = m1
100 |         self.m2 = m2
101 |         self.weight = Parameter(torch.Tensor(out_feature, in_feature))
102 |         nn.init.xavier_uniform_(self.weight)
103 | 
104 |         self.easy_margin = easy_margin
105 |         self.cos_m1 = math.cos(m1)
106 |         self.sin_m1 = math.sin(m1)
107 | 
108 |         # make cos(theta + m) monotonically decreasing for theta in [0°, 180°]
109 |         self.th = math.cos(math.pi - m1)
110 |         self.mm = math.sin(math.pi - m1) * m1
111 | 
112 |     def forward(self, x, label):
113 |         # cos(theta)
114 |         cosine = F.linear(F.normalize(x), F.normalize(self.weight))
115 |         # cos(theta + m1)
116 |         sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
117 |         phi = cosine * self.cos_m1 - sine * self.sin_m1
118 | 
119 |         if self.easy_margin:
120 |             phi = torch.where(cosine > 0, phi, cosine)
121 |         else:
122 |             phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
123 | 
124 | 
125 |         one_hot = torch.zeros_like(cosine)
126 |         one_hot.scatter_(1, label.view(-1, 1), 1)
127 |         output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # additive angular margin
128 |         output = output - one_hot * self.m2  # additive cosine margin
129 |         output = output * self.s
130 | 
131 |         return output
132 | 
133 | 
134 | class CPDis(nn.Module):
135 |     """PatchGAN."""
136 |     def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
137 |         super(CPDis, self).__init__()
138 | 
139 |         layers = []
140 |         if norm == 'SN':
141 |             layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
142 |         else:
143 |             layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
144 |         layers.append(nn.LeakyReLU(0.01, inplace=True))
145 | 
146 |         curr_dim = conv_dim
147 |         for i in range(1, repeat_num):
148 |             if norm == 'SN':
149 |                 layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
150 |             else:
151 |                 layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
152 |             layers.append(nn.LeakyReLU(0.01, inplace=True))
153 |             curr_dim = curr_dim * 2
154 | 
155 |         # k_size = int(image_size / np.power(2, repeat_num))
156 |         if norm == 'SN':
157 |             layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
158 |         else:
159 |             layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
160 |         layers.append(nn.LeakyReLU(0.01, inplace=True))
161 |         curr_dim = curr_dim * 2
162 | 
163 |         self.main = nn.Sequential(*layers)
164 |         if norm == 'SN':
165 |             self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
166 |         else:
167 |             self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
168 | 
169 |     def forward(self, x):
170 |         if x.ndim == 5:
171 |             x = x.squeeze(0)
172 |         assert x.ndim == 4, x.ndim
173 |         h = self.main(x)
174 |         # out_real = self.conv1(h)
175 |         out_makeup = self.conv1(h)
176 |         # return out_real.squeeze(), out_makeup.squeeze()
177 |         return out_makeup
178 | 
179 | 
180 | class CPDis_cls(nn.Module):
181 |     """PatchGAN."""
182 |     def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
183 |         super(CPDis_cls, self).__init__()
184 | 
185 |         layers = []
186 |         if norm == 'SN':
187 |             layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
188 |         else:
189 |             layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
190 |         layers.append(nn.LeakyReLU(0.01, inplace=True))
191 | 
192 |         curr_dim = conv_dim
193 |         for i in range(1, repeat_num):
194 |             if norm == 'SN':
195 |                 layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
196 |             else:
197 |                 layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
198 |             layers.append(nn.LeakyReLU(0.01, inplace=True))
199 |             curr_dim = curr_dim * 2
200 | 
201 |         # k_size = int(image_size / np.power(2, repeat_num))
202 |         if norm == 'SN':
203 |             layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
204 |         else:
205 |             layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
206 |         layers.append(nn.LeakyReLU(0.01, inplace=True))
207 |         curr_dim = curr_dim * 2
208 | 
209 |         self.main = nn.Sequential(*layers)
210 |         if norm == 'SN':
211 |             self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
212 |             self.classifier_pool = nn.AdaptiveAvgPool2d(1)
213 |             self.classifier_conv = nn.Conv2d(512, 512, 1, 1, 0)
214 |             self.classifier = MarginCosineProduct(512, 7)  # ArcMarginProduct(512, 7)
215 |             print("Using Large Margin Cosine Loss.")
216 | 
217 |         else:
218 |             self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
219 | 
220 |     def forward(self, x, label):
221 |         if x.ndim == 5:
222 |             x = x.squeeze(0)
223 |         assert x.ndim == 4, x.ndim
224 |         h = self.main(x)  # ([1, 512, 31, 31])
225 |         #print(out_cls.shape)
226 |         out_cls = self.classifier_pool(h)
227 |         #print(out_cls.shape)
228 |         out_cls = self.classifier_conv(out_cls)
229 |         #print(out_cls.shape)
230 |         out_cls = torch.squeeze(out_cls, -1)
231 |         out_cls = torch.squeeze(out_cls, -1)
232 |         out_cls = self.classifier(out_cls, label)
233 |         out_makeup = self.conv1(h)  # torch.Size([1, 1, 30, 30])
234 |         # return out_real.squeeze(), out_makeup.squeeze()
235 |         return out_makeup, out_cls
236 | 
237 | class SpectralNorm(object):
238 |     def __init__(self):
239 |         self.name = "weight"
240 |         # print(self.name)
241 |         self.power_iterations = 1
242 | 
243 |     def compute_weight(self, module):
244 |         u = getattr(module, self.name + "_u")
245 |         v = getattr(module, self.name + "_v")
246 |         w = getattr(module, self.name + "_bar")
247 | 
248 |         height = w.data.shape[0]
249 |         for _ in range(self.power_iterations):
250 |             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
251 |             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
252 |         # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
253 |         sigma = u.dot(w.view(height, -1).mv(v))
254 |         return w / sigma.expand_as(w)
255 | 
256 |     @staticmethod
257 |     def apply(module):
258 |         name = "weight"
259 |         fn = SpectralNorm()
260 | 
261 |         try:
262 |             u = getattr(module, name + "_u")
263 |             v = getattr(module, name + "_v")
264 |             w = getattr(module, name + "_bar")
265 |         except AttributeError:
266 |             w = getattr(module, name)
267 |             height = w.data.shape[0]
268 |             width = w.view(height, -1).data.shape[1]
269 |             u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
270 |             v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
271 |             w_bar = Parameter(w.data)
272 | 
273 |             # del module._parameters[name]
274 | 
275 |             module.register_parameter(name + "_u", u)
276 |             module.register_parameter(name + "_v", v)
277 |             module.register_parameter(name + "_bar", w_bar)
278 | 
279 |         # remove w from parameter list
280 |         del module._parameters[name]
281 | 
282 |         setattr(module, name, fn.compute_weight(module))
283 | 
284 |         # recompute weight before every forward()
285 |         module.register_forward_pre_hook(fn)
286 | 
287 |         return fn
288 | 
289 |     def remove(self, module):
290 |         weight = self.compute_weight(module)
291 |         delattr(module, self.name)
292 |         del module._parameters[self.name + '_u']
293 |         del module._parameters[self.name + '_v']
294 |         del module._parameters[self.name + '_bar']
295 |         module.register_parameter(self.name, Parameter(weight.data))
296 | 
297 |     def __call__(self, module, inputs):
298 |         setattr(module, self.name, self.compute_weight(module))
299 | 
300 | def spectral_norm(module):
301 |     SpectralNorm.apply(module)
302 |     return module
303 | 
304 | def remove_spectral_norm(module):
305 |     name = 'weight'
306 |     for k, hook in module._forward_pre_hooks.items():
307 |         if isinstance(hook, SpectralNorm) and hook.name == name:
308 |             hook.remove(module)
309 |             del module._forward_pre_hooks[k]
310 |             return module
311 | 
312 |     raise ValueError("spectral_norm of '{}' not found in {}"
313 |                      .format(name, module))
314 | 
--------------------------------------------------------------------------------
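The margin heads are training-only as well; a toy forward pass through `MarginCosineProduct` (the variant `CPDis_cls` instantiates) shows the shape contract:

```python
# MarginCosineProduct: scaled cosine similarities, minus a margin on the true class.
import torch
from models.c2pDis import MarginCosineProduct

head = MarginCosineProduct(512, 7)  # 512-d embeddings, 7 classes
feats = torch.randn(4, 512)
labels = torch.tensor([0, 3, 6, 1])
logits = head(feats, labels)        # s * (cos - one_hot * m)
print(logits.shape)                 # torch.Size([4, 7])
```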
/models/basic_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | class ModulationConvBlock(nn.Module):
7 |     def __init__(self, input_dim, output_dim, kernel_size, stride=1,
8 |                  padding=0, norm='none', activation='relu', pad_type='zero'):
9 |         super(ModulationConvBlock, self).__init__()
10 |         self.in_c = input_dim
11 |         self.out_c = output_dim
12 |         self.ksize = kernel_size
13 |         self.stride = 1  # the stride argument is ignored; this block always uses stride 1
14 |         self.padding = kernel_size // 2
15 | 
16 |         self.eps = 1e-8
17 |         weight_shape = (output_dim, input_dim, kernel_size, kernel_size)
18 |         fan_in = kernel_size * kernel_size * input_dim
19 |         wscale = 1.0 / np.sqrt(fan_in)
20 | 
21 |         self.weight = nn.Parameter(torch.randn(*weight_shape))
22 |         self.wscale = wscale
23 | 
24 |         self.bias = nn.Parameter(torch.zeros(output_dim))
25 | 
26 |         self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27 |         self.activate_scale = np.sqrt(2.0)
28 | 
29 |     def forward(self, x, code):
30 |         batch, in_channel, height, width = x.shape
31 |         weight = self.weight * self.wscale
32 |         _weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
33 |         _weight = _weight * code.view(batch, 1, 1, self.in_c, 1)
34 |         # demodulation
35 |         _weight_norm = torch.sqrt(torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
36 |         _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
37 |         # fused_modulate
38 |         x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
39 |         weight = _weight.permute(1, 2, 3, 0, 4).reshape(
40 |             self.ksize, self.ksize, self.in_c, batch * self.out_c)
41 |         # not use_conv2d_transpose
42 |         weight = weight.permute(3, 2, 0, 1)
43 |         x = F.conv2d(x,
44 |                      weight=weight,
45 |                      bias=None,
46 |                      stride=self.stride,
47 |                      padding=self.padding,
48 |                      groups=batch)  # grouped conv: one modulated weight set per sample
49 | 
50 |         # fused_modulate path (this branch is always taken)
51 |         x = x.view(batch, self.out_c, height, width)
52 |         x = x + self.bias.view(1, -1, 1, 1)
53 |         x = self.activate(x) * self.activate_scale
54 |         return x
55 | 
56 | 
57 | class AliasConvBlock(nn.Module):
58 |     def __init__(self, input_dim, output_dim, kernel_size, stride,
59 |                  padding=0, norm='none', activation='relu', pad_type='zero'):
60 |         super(AliasConvBlock, self).__init__()
61 |         self.use_bias = True
62 |         # initialize padding
63 |         if pad_type == 'reflect':
64 |             self.pad = nn.ReflectionPad2d(padding)
65 |         elif pad_type == 'replicate':
66 |             self.pad = nn.ReplicationPad2d(padding)
67 |         elif pad_type == 'zero':
68 |             self.pad = nn.ZeroPad2d(padding)
69 |         else:
70 |             assert 0, "Unsupported padding type: {}".format(pad_type)
71 | 
72 |         # initialize normalization
73 |         norm_dim = output_dim
74 |         if norm == 'bn':
75 |             self.norm = nn.BatchNorm2d(norm_dim)
76 |         elif norm == 'in':
77 |             # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
78 |             self.norm = nn.InstanceNorm2d(norm_dim)
79 |         elif norm == 'ln':
80 |             self.norm = LayerNorm(norm_dim)
81 |         elif norm == 'adain':
82 |             self.norm = AdaptiveInstanceNorm2d(norm_dim)
83 |         elif norm == 'none' or norm == 'sn':
84 |             self.norm = None
85 |         else:
86 |             assert 0, "Unsupported normalization: {}".format(norm)
87 | 
88 |         # initialize activation
89 |         if activation == 'relu':
90 |             self.activation = nn.ReLU(inplace=True)
91 |         elif activation == 'lrelu':
92 |             self.activation = nn.LeakyReLU(0.2, inplace=True)
93 |         elif activation == 'prelu':
94 |             self.activation = nn.PReLU()
95 |         elif activation == 'selu':
96 |             self.activation = nn.SELU(inplace=True)
97 |         elif activation == 'tanh':
98 |             self.activation = nn.Tanh()
99 |         elif activation == 'none':
100 |             self.activation = None
101 |         else:
102 |             assert 0, "Unsupported activation: {}".format(activation)
103 | 
104 |         # initialize convolution
105 |         if norm == 'sn':
106 |             self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
107 | 
108 |         else:
109 |             self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
110 | 
111 |     def forward(self, x):
112 |         x = self.conv(self.pad(x))
113 |         if self.norm:
114 |             x = self.norm(x)
115 |         if self.activation:
116 |             x = self.activation(x)
117 |         return x
118 | 
119 | class AliasResBlocks(nn.Module):
120 |     def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
121 |         super(AliasResBlocks, self).__init__()
122 |         self.model = []
123 |         for i in range(num_blocks):
124 |             self.model += [AliasResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
125 |         self.model = nn.Sequential(*self.model)
126 | 
127 |     def forward(self, x):
128 |         return self.model(x)
129 | class AliasResBlock(nn.Module):
130 |     def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
131 |         super(AliasResBlock, self).__init__()
132 | 
133 |         model = []
134 |         model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
135 |         model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
136 |         self.model = nn.Sequential(*model)
137 | 
138 |     def forward(self, x):
139 |         residual = x
140 |         out = self.model(x)
141 |         out += residual
142 |         return out
143 | ##################################################################################
144 | # Sequential Models
145 | ##################################################################################
146 | class ResBlocks(nn.Module):
147 |     def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
148 |         super(ResBlocks, self).__init__()
149 |         self.model = []
150 |         for i in range(num_blocks):
151 |             self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
152 |         self.model = nn.Sequential(*self.model)
153 | 
154 |     def forward(self, x):
155 |         return self.model(x)
156 | 
157 | 
158 | class MLP(nn.Module):
159 |     def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
160 |         super(MLP, self).__init__()
161 |         self.model = []
162 |         self.model += [linearBlock(input_dim, input_dim, norm=norm, activation=activ)]
163 |         self.model += [linearBlock(input_dim, dim, norm=norm, activation=activ)]
164 |         for i in range(n_blk - 2):
165 |             self.model += [linearBlock(dim, dim, norm=norm, activation=activ)]
166 |         self.model += [linearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
167 |         self.model = nn.Sequential(*self.model)
168 | 
169 |     # def forward(self, style0, style1, a=0):
170 |     #     return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
171 |     #         style1.view(style1.size(0), -1)))
172 |     def forward(self, style0, style1=None, a=0):
173 |         style1 = style0  # inference path: style1 is forced to style0, so the blend reduces to self.model(style0)
174 |         return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
175 |             style1.view(style1.size(0), -1)))
176 | ##################################################################################
177 | # Basic Blocks
178 | ##################################################################################
179 | class ResBlock(nn.Module):
180 |     def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
181 |         super(ResBlock, self).__init__()
182 | 
183 |         model = []
184 |         model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
185 |         model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
186 |         self.model = nn.Sequential(*model)
187 | 
188 |     def forward(self, x):
189 |         residual = x
190 |         out = self.model(x)
191 |         out += residual
192 |         return out
193 | 
194 | 
195 | class ConvBlock(nn.Module):
196 |     def __init__(self, input_dim, output_dim, kernel_size, stride,
197 |                  padding=0, norm='none', activation='relu', pad_type='zero'):
198 |         super(ConvBlock, self).__init__()
199 |         self.use_bias = True
200 |         # initialize padding
201 |         if pad_type == 'reflect':
202 |             self.pad = nn.ReflectionPad2d(padding)
203 |         elif pad_type == 'replicate':
204 |             self.pad = nn.ReplicationPad2d(padding)
205 |         elif pad_type == 'zero':
206 |             self.pad = nn.ZeroPad2d(padding)
207 |         else:
208 |             assert 0, "Unsupported padding type: {}".format(pad_type)
209 | 
210 |         # initialize normalization
211 |         norm_dim = output_dim
212 |         if norm == 'bn':
213 |             self.norm = nn.BatchNorm2d(norm_dim)
214 |         elif norm == 'in':
215 |             # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
216 |             self.norm = nn.InstanceNorm2d(norm_dim)
217 |         elif norm == 'ln':
218 |             self.norm = LayerNorm(norm_dim)
219 |         elif norm == 'adain':
220 |             self.norm = AdaptiveInstanceNorm2d(norm_dim)
221 |         elif norm == 'none' or norm == 'sn':
222 |             self.norm = None
223 |         else:
224 |             assert 0, "Unsupported normalization: {}".format(norm)
225 | 
226 |         # initialize activation
227 |         if activation == 'relu':
228 |             self.activation = nn.ReLU(inplace=True)
229 |         elif activation == 'lrelu':
230 |             self.activation = nn.LeakyReLU(0.2, inplace=True)
231 |         elif activation == 'prelu':
232 |             self.activation = nn.PReLU()
233 |         elif activation == 'selu':
234 |             self.activation = nn.SELU(inplace=True)
235 |         elif activation == 'tanh':
236 |             self.activation = nn.Tanh()
237 |         elif activation == 'none':
238 |             self.activation = None
239 |         else:
240 |             assert 0, "Unsupported activation: {}".format(activation)
241 | 
242 |         # initialize convolution
243 |         if norm == 'sn':
244 |             self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
245 | 
246 |         else:
247 |             self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
248 | 
249 |     def forward(self, x):
250 |         x = self.conv(self.pad(x))
251 |         if self.norm:
252 |             x = self.norm(x)
253 |         if self.activation:
254 |             x = self.activation(x)
255 |         return x
256 | 
257 | class linearBlock(nn.Module):
258 |     def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
259 |         super(linearBlock, self).__init__()
260 |         use_bias = True
261 |         # initialize fully connected layer
262 |         if norm == 'sn':
263 |             self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
264 |         else:
265 |             self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
266 | 
267 |         # initialize normalization
268 |         norm_dim = output_dim
269 |         if norm == 'bn':
270 |             self.norm = nn.BatchNorm1d(norm_dim)
271 |         elif norm == 'in':
272 |             self.norm = nn.InstanceNorm1d(norm_dim)
273 |         elif norm == 'ln':
274 |             self.norm = LayerNorm(norm_dim)
275 |         elif norm == 'none' or norm == 'sn':
276 |             self.norm = None
277 |         else:
278 |             assert 0, "Unsupported normalization: {}".format(norm)
279 | 
280 |         # initialize activation
281 |         if activation == 'relu':
282 |             self.activation = nn.ReLU(inplace=True)
283 |         elif activation == 'lrelu':
284 |             self.activation = nn.LeakyReLU(0.2, inplace=True)
285 |         elif activation == 'prelu':
286 |             self.activation = nn.PReLU()
287 |         elif activation == 'selu':
288 |             self.activation = nn.SELU(inplace=True)
289 |         elif activation == 'tanh':
290 |             self.activation = nn.Tanh()
291 |         elif activation == 'none':
292 |             self.activation = None
293 |         else:
294 |             assert 0, "Unsupported activation: {}".format(activation)
295 | 
296 |     def forward(self, x):
297 |         out = self.fc(x)
298 |         if self.norm:
299 |             out = self.norm(out)
300 |         if self.activation:
301 |             out = self.activation(out)
302 |         return out
303 | ##################################################################################
304 | # Normalization layers
305 | ##################################################################################
306 | class AdaptiveInstanceNorm2d(nn.Module):
307 |     def __init__(self, num_features, eps=1e-5, momentum=0.1):
308 |         super(AdaptiveInstanceNorm2d, self).__init__()
309 |         self.num_features = num_features
310 |         self.eps = eps
311 |         self.momentum = momentum
312 |         # weight and bias are dynamically assigned
313 |         self.weight = None
314 |         self.bias = None
315 |         # just dummy buffers, not used
316 |         self.register_buffer('running_mean', torch.zeros(num_features))
317 |         self.register_buffer('running_var', torch.ones(num_features))
318 | 
319 |     def forward(self, x):
320 |         assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
321 |         b, c = x.size(0), x.size(1)
322 |         running_mean = self.running_mean.repeat(b)
323 |         running_var = self.running_var.repeat(b)
324 | 
325 |         # Apply instance norm
326 |         x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
327 | 
328 |         out = F.batch_norm(
329 |             x_reshaped, running_mean, running_var, self.weight, self.bias,
330 |             True, self.momentum, self.eps)
331 | 
332 |         return out.view(b, c, *x.size()[2:])
333 | 
334 |     def __repr__(self):
335 |         return self.__class__.__name__ + '(' + str(self.num_features) + ')'
336 | 
337 | 
338 | class LayerNorm(nn.Module):
339 |     def __init__(self, num_features, eps=1e-5, affine=True):
340 |         super(LayerNorm, self).__init__()
341 |         self.num_features = num_features
342 |         self.affine = affine
343 |         self.eps = eps
344 | 
345 |         if self.affine:
346 |             self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
347 |             self.beta = nn.Parameter(torch.zeros(num_features))
348 | 
349 |     def forward(self, x):
350 |         shape = [-1] + [1] * (x.dim() - 1)
351 |         # print(x.size())
352 |         if x.size(0) == 1:
353 |             # for a single sample this whole-tensor reduction runs much faster (in pytorch 0.4) than the per-sample branch below
354 |             mean = x.view(-1).mean().view(*shape)
355 |             std = x.view(-1).std().view(*shape)
356 |         else:
357 |             mean = x.view(x.size(0), -1).mean(1).view(*shape)
358 |             std = x.view(x.size(0), -1).std(1).view(*shape)
359 | 
360 |         x = (x - mean) / (std + self.eps)
361 | 
362 |         if self.affine:
363 |             shape = [1, -1] + [1] * (x.dim() - 2)
364 |             x = x * self.gamma.view(*shape) + self.beta.view(*shape)
365 |         return x
366 | 
367 | 
368 | def l2normalize(v, eps=1e-12):
369 |     return v / (v.norm() + eps)
370 | 
371 | 
372 | class SpectralNorm(nn.Module):
373 |     """
374 |     Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
375 |     and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
376 |     """
377 | 
378 |     def __init__(self, module, name='weight', power_iterations=1):
379 |         super(SpectralNorm, self).__init__()
380 |         self.module = module
381 |         self.name = name
382 |         self.power_iterations = power_iterations
383 |         if not self._made_params():
384 |             self._make_params()
385 | 
386 |     def _update_u_v(self):
387 |         u = getattr(self.module, self.name + "_u")
388 |         v = getattr(self.module, self.name + "_v")
389 |         w = getattr(self.module, self.name + "_bar")
390 | 
391 |         height = w.data.shape[0]
392 |         for _ in range(self.power_iterations):
393 |             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
394 |             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
395 | 
396 |         # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
397 |         sigma = u.dot(w.view(height, -1).mv(v))
398 |         setattr(self.module, self.name, w / sigma.expand_as(w))
399 | 
400 |     def _made_params(self):
401 |         try:
402 |             u = getattr(self.module, self.name + "_u")
403 |             v = getattr(self.module, self.name + "_v")
404 |             w = getattr(self.module, self.name + "_bar")
405 |             return True
406 |         except AttributeError:
407 |             return False
408 | 
409 |     def _make_params(self):
410 |         w = getattr(self.module, self.name)
411 | 
412 |         height = w.data.shape[0]
413 |         width = w.view(height, -1).data.shape[1]
414 | 
415 |         u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
416 |         v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
417 |         u.data = l2normalize(u.data)
418 |         v.data = l2normalize(v.data)
419 |         w_bar = nn.Parameter(w.data)
420 | 
421 |         del self.module._parameters[self.name]
422 | 
423 |         self.module.register_parameter(self.name + "_u", u)
424 |         self.module.register_parameter(self.name + "_v", v)
425 |         self.module.register_parameter(self.name + "_bar", w_bar)
426 | 
427 |     def forward(self, *args):
428 |         self._update_u_v()
429 |         return self.module.forward(*args)
430 | 
--------------------------------------------------------------------------------
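Both spectral-norm implementations above estimate the weight's largest singular value with one power-iteration step per forward pass (the pair of l2-normalized mat-vec products in `_update_u_v`/`compute_weight`). The idea in isolation, run to convergence:

```python
# Power iteration for the top singular value, as used by SpectralNorm (sketch).
import torch

w = torch.randn(64, 128)
u = torch.randn(64)
for _ in range(50):  # SpectralNorm does a single step per forward pass
    v = torch.mv(w.t(), u)
    v = v / (v.norm() + 1e-12)
    u = torch.mv(w, v)
    u = u / (u.norm() + 1e-12)
sigma = u.dot(torch.mv(w, v))
print(sigma.item(), torch.linalg.svdvals(w)[0].item())  # should roughly agree
```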