├── README.md
├── datasets.py
├── datasets
│   └── download_pix2pix_dataset.sh
├── models.py
├── optimizer.py
├── options.py
├── pix2pix_train.py
├── result
│   ├── 0.png
│   ├── 199.png
│   └── 99.png
└── utils.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pix2pix-pytorch
A PyTorch implementation of pix2pix.

## Requirements
- CUDA 8.0+
- pytorch 0.3.1
- torchvision

## Datasets
- Download a pix2pix dataset (e.g. facades):
```bash
bash ./datasets/download_pix2pix_dataset.sh facades
```

## Train a model
```bash
python pix2pix_train.py --data_root 'your data directory' --which_direction "BtoA"
```

## Result examples
### epoch-0
![image](https://github.com/TeeyoHuang/pix2pix-pytorch/blob/master/result/0.png)
### epoch-99
![image](https://github.com/TeeyoHuang/pix2pix-pytorch/blob/master/result/99.png)
### epoch-199
![image](https://github.com/TeeyoHuang/pix2pix-pytorch/blob/master/result/199.png)

## Reference
[1] [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004.pdf)
```
@inproceedings{isola2017image,
  title={Image-to-Image Translation with Conditional Adversarial Networks},
  author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
  booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on},
  year={2017}
}
```

## Personal Blog
[teeyohuang](https://blog.csdn.net/Teeyohuang/article/details/82699781)
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import glob
import os

import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


class ImageDataset(Dataset):
    def __init__(self, args, root, transforms_=None, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.args = args
        self.files = sorted(glob.glob(os.path.join(root, mode) + '/*.*'))

    def __getitem__(self, index):
        # Each file stores the paired images side by side: A on the left half, B on the right half
        img = Image.open(self.files[index])
        w, h = img.size

        if self.args.which_direction == 'AtoB':
            img_A = img.crop((0, 0, w // 2, h))
            img_B = img.crop((w // 2, 0, w, h))
        else:
            img_B = img.crop((0, 0, w // 2, h))
            img_A = img.crop((w // 2, 0, w, h))

        # Random horizontal flip, applied jointly so A and B stay aligned
        if np.random.random() < 0.5:
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], 'RGB')
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], 'RGB')

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {'A': img_A, 'B': img_B}

    def __len__(self):
        return len(self.files)


# Configure dataloaders
def Get_dataloader(args):
    transforms_ = [transforms.Resize((args.img_height, args.img_width), Image.BICUBIC),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    train_dataloader = DataLoader(
        ImageDataset(args, "%s/%s" % (args.data_root, args.dataset_name), transforms_=transforms_, mode='train'),
        batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu, drop_last=True)

    test_dataloader = DataLoader(
        ImageDataset(args, "%s/%s" % (args.data_root, args.dataset_name), transforms_=transforms_, mode='test'),
        batch_size=10, shuffle=True, num_workers=1, drop_last=True)

    val_dataloader = DataLoader(
        ImageDataset(args, "%s/%s" % (args.data_root, args.dataset_name), transforms_=transforms_, mode='val'),
        batch_size=10, shuffle=True, num_workers=1, drop_last=True)

    return train_dataloader, test_dataloader, val_dataloader
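

# --- Hedged smoke test (added for illustration, not part of the original training path) ---
# Assumes ./datasets/facades exists with train/ test/ val/ splits, as produced by
# datasets/download_pix2pix_dataset.sh; the Namespace mirrors the defaults in options.py.
if __name__ == '__main__':
    from argparse import Namespace
    _args = Namespace(data_root='./datasets', dataset_name='facades',
                      which_direction='BtoA', img_height=256, img_width=256,
                      batch_size=1, n_cpu=0)
    _train, _, _ = Get_dataloader(_args)
    _batch = next(iter(_train))
    # Both halves come back as (1, 3, 256, 256) tensors normalized to [-1, 1]
    print(_batch['A'].shape, _batch['B'].shape)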
--------------------------------------------------------------------------------
/datasets/download_pix2pix_dataset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
FILE=$1
URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./datasets/$FILE.tar.gz
TARGET_DIR=./datasets/$FILE/
wget -N "$URL" -O "$TAR_FILE"
mkdir -p "$TARGET_DIR"
tar -zxvf "$TAR_FILE" -C ./datasets/
rm "$TAR_FILE"
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant(m.bias.data, 0.0)


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

##############################
#           U-NET
##############################

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                  nn.InstanceNorm2d(out_size),
                  nn.ReLU(inplace=True)]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x


class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        # Input channels of each up-block double because of the skip concatenation
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        '''
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh()
        )
        '''
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.up8(u7)
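

# Hedged shape check (added for illustration): the 8-level encoder halves
# 256x256 down to 1x1, and the decoder mirrors it back up, so output and
# input spatial sizes match.
def _generator_shape_check():
    net = GeneratorUNet()
    x = torch.randn(1, 3, 256, 256)
    print(net(x).shape)  # expected: torch.Size([1, 3, 256, 256])
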
##############################
#        Discriminator
##############################

class Discriminator_n_layers(nn.Module):
    def __init__(self, args):
        super(Discriminator_n_layers, self).__init__()

        n_layers = args.n_D_layers
        in_channels = args.out_channels

        def discriminator_block(in_filters, out_filters, k=4, s=2, p=1, norm=True, sigmoid=False):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=k, stride=s, padding=p)]
            if norm:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if sigmoid:
                layers.append(nn.Sigmoid())
                print('use sigmoid')
            return layers

        sequence = [*discriminator_block(in_channels*2, 64, norm=False)]  # (1,64,128,128)

        assert n_layers <= 5

        if n_layers == 1:
            # when n_layers==1, the patch_size is (16x16)
            out_filters = 64 * 2**(n_layers-1)

        elif 1 < n_layers <= 4:
            '''
            when n_layers==2, the patch_size is (34x34)
            when n_layers==3, the patch_size is (70x70), this is the size used in the paper
            when n_layers==4, the patch_size is (142x142)
            '''
            for k in range(1, n_layers):  # k=1,2,3
                sequence += [*discriminator_block(2**(5+k), 2**(6+k))]
            out_filters = 64 * 2**(n_layers-1)

        elif n_layers == 5:
            '''
            when n_layers==5, the patch_size is (286x286), which is larger than the img_size (256),
            so this is the whole-image condition
            '''
            for k in range(1, 4):  # k=1,2,3
                sequence += [*discriminator_block(2**(5+k), 2**(6+k))]
            # k=4
            sequence += [*discriminator_block(2**9, 2**9)]
            out_filters = 2**9

        num_of_filter = min(2*out_filters, 2**9)

        sequence += [*discriminator_block(out_filters, num_of_filter, k=4, s=1, p=1)]
        sequence += [*discriminator_block(num_of_filter, 1, k=4, s=1, p=1, norm=False, sigmoid=True)]

        self.model = nn.Sequential(*sequence)

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
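

# Hedged check (added for illustration): with the default n_D_layers=3 the
# PatchGAN scores 70x70 receptive fields and emits a 30x30 map per image,
# matching D_out_size = 256 // 2**3 - 2 as computed in pix2pix_train.py.
def _discriminator_shape_check():
    from argparse import Namespace
    d = Discriminator_n_layers(Namespace(n_D_layers=3, out_channels=3))
    a = torch.randn(2, 3, 256, 256)
    print(d(a, a).shape)  # expected: torch.Size([2, 1, 30, 30])
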
####################################################
# Initialize generator and discriminator
####################################################
def Create_nets(args):
    generator = GeneratorUNet()
    discriminator = Discriminator_n_layers(args)

    if torch.cuda.is_available():
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    if args.epoch_start != 0:
        # Load pretrained models (paths match the checkpoints written by pix2pix_train.py)
        generator.load_state_dict(torch.load('%s-%s/%s/generator_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, args.epoch_start)))
        discriminator.load_state_dict(torch.load('%s-%s/%s/discriminator_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, args.epoch_start)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    print_network(generator)
    print_network(discriminator)

    return generator, discriminator
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
import torch

# Optimizers
def Get_optimizers(args, generator, discriminator):
    optimizer_G = torch.optim.Adam(
        generator.parameters(),
        lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(
        discriminator.parameters(),
        lr=args.lr, betas=(args.b1, args.b2))
    return optimizer_G, optimizer_D

# Loss functions
def Get_loss_func(args):
    # BCELoss expects probabilities, which the discriminator's final sigmoid layer provides
    criterion_GAN = torch.nn.BCELoss()
    criterion_pixelwise = torch.nn.L1Loss()
    if torch.cuda.is_available():
        criterion_GAN.cuda()
        criterion_pixelwise.cuda()
    return criterion_GAN, criterion_pixelwise
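

# Hedged sketch (added for illustration) of the objective these helpers feed:
# the generator minimizes BCE(D(G(A), A), 1) + lambda_pixel * L1(G(A), B).
# Dummy tensors stand in for network outputs; the (1, 1, 30, 30) map matches
# the default n_D_layers=3 patch size.
if __name__ == '__main__':
    bce, l1 = torch.nn.BCELoss(), torch.nn.L1Loss()
    pred_fake = torch.rand(1, 1, 30, 30)   # sigmoid-activated D scores in (0, 1)
    valid = torch.ones(1, 1, 30, 30)
    fake_B, real_B = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
    loss_G = bce(pred_fake, valid) + 100 * l1(fake_B, real_B)  # lambda_pixel = 100
    print(loss_G)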
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
import argparse
import os

class TrainOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        self.parser.add_argument('--exp_name', type=str, default="Exp0", help='the name of the experiment')
        self.parser.add_argument('--epoch_start', type=int, default=0, help='epoch to start training from')
        self.parser.add_argument('--epoch_num', type=int, default=200, help='number of epochs of training')
        self.parser.add_argument('--data_root', type=str, default='../../data/', help='dir of the dataset')
        self.parser.add_argument('--dataset_name', type=str, default="facades", help='name of the dataset')
        self.parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
        self.parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
        self.parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
        self.parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay')
        self.parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
        self.parser.add_argument('--img_height', type=int, default=256, help='size of image height')
        self.parser.add_argument('--img_width', type=int, default=256, help='size of image width')
        self.parser.add_argument('--in_channels', type=int, default=3, help='number of input image channels')
        self.parser.add_argument('--out_channels', type=int, default=3, help='number of output image channels')
        self.parser.add_argument('--sample_interval', type=int, default=200, help='interval between sampling of images from generators')
        self.parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
        self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
        self.parser.add_argument('--n_D_layers', type=int, default=3, help='sets the patch size of the discriminator; must be <= 5')
        self.parser.add_argument('--lambda_pixel', type=int, default=100, help='loss weight of L1 pixel-wise loss between translated image and real image')
        self.parser.add_argument('--img_result_dir', type=str, default='result_images', help='where to save the result images')
        self.parser.add_argument('--model_result_dir', type=str, default='saved_models', help='where to save the checkpoints')
        self.initialized = True

    def parse(self):
        if not self.initialized:
            self.initialize()
        args = self.parser.parse_args()

        os.makedirs('%s-%s/%s' % (args.exp_name, args.dataset_name, args.img_result_dir), exist_ok=True)
        os.makedirs('%s-%s/%s' % (args.exp_name, args.dataset_name, args.model_result_dir), exist_ok=True)

        print('------------ Options -------------')
        with open("./%s-%s/args.log" % (args.exp_name, args.dataset_name), "w") as args_log:
            for k, v in sorted(vars(args).items()):
                print('%s: %s' % (str(k), str(v)))
                args_log.write('%s: %s \n' % (str(k), str(v)))
        print('-------------- End ----------------')

        self.args = args
        return self.args
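

# Hedged usage sketch (added for illustration): TrainOptions reads sys.argv,
# so to drive it outside the command line one can patch argv first.
if __name__ == '__main__':
    import sys
    sys.argv = ['options.py', '--dataset_name', 'facades', '--which_direction', 'BtoA']
    _args = TrainOptions().parse()  # also creates Exp0-facades/{result_images,saved_models}
    print(_args.which_direction)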
--------------------------------------------------------------------------------
/pix2pix_train.py:
--------------------------------------------------------------------------------
import numpy as np
import time
import datetime
import sys

import torch
from torch.autograd import Variable

from models import Create_nets
from datasets import Get_dataloader
from options import TrainOptions
from optimizer import *
from utils import sample_images, LambdaLR


# Load the args
args = TrainOptions().parse()
# Calculate output size of the image discriminator (PatchGAN)
D_out_size = 256 // (2**args.n_D_layers) - 2
print('PatchGAN output size: %d' % D_out_size)
patch = (1, D_out_size, D_out_size)

# Initialize generator and discriminator
generator, discriminator = Create_nets(args)
# Loss functions
criterion_GAN, criterion_pixelwise = Get_loss_func(args)
# Optimizers
optimizer_G, optimizer_D = Get_optimizers(args, generator, discriminator)

# Configure dataloaders
train_dataloader, test_dataloader, _ = Get_dataloader(args)


# ----------
#  Training
# ----------
prev_time = time.time()
for epoch in range(args.epoch_start, args.epoch_num):
    for i, batch in enumerate(train_dataloader):

        # Model inputs
        real_A = Variable(batch['A'].type(torch.FloatTensor).cuda())
        real_B = Variable(batch['B'].type(torch.FloatTensor).cuda())

        # Adversarial ground truths
        valid = Variable(torch.FloatTensor(np.ones((real_A.size(0), *patch))).cuda(), requires_grad=False)
        fake = Variable(torch.FloatTensor(np.zeros((real_A.size(0), *patch))).cuda(), requires_grad=False)

        # Update learning rates
        #lr_scheduler_G.step(epoch)
        #lr_scheduler_D.step(epoch)

        # ------------------
        #  Train Generator
        # ------------------
        optimizer_G.zero_grad()

        # GAN loss
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = loss_GAN + args.lambda_pixel * loss_pixel
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss (detach so no gradients flow back into the generator)
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------
        # Determine approximate time left
        batches_done = epoch * len(train_dataloader) + i
        batches_left = args.epoch_num * len(train_dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write("\r[Epoch%d/%d]-[Batch%d/%d]-[Dloss:%f]-[Gloss:%f, loss_pixel:%f, adv:%f] ETA:%s" %
                         (epoch+1, args.epoch_num,
                          i, len(train_dataloader),
                          loss_D.data.cpu(), loss_G.data.cpu(),
                          loss_pixel.data.cpu(), loss_GAN.data.cpu(),
                          time_left))

        # If at sample interval, save a batch of generated images
        if batches_done % args.sample_interval == 0:
            sample_images(generator, test_dataloader, args, epoch, batches_done)

    if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
        # Save model checkpoints (into the directories created by options.py)
        torch.save(generator.state_dict(), '%s-%s/%s/generator_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, epoch))
        torch.save(discriminator.state_dict(), '%s-%s/%s/discriminator_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, epoch))
--------------------------------------------------------------------------------
/result/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeeyoHuang/pix2pix-pytorch/0a1924f397b3f39d60e37fa40111036927bfbf01/result/0.png
--------------------------------------------------------------------------------
/result/199.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeeyoHuang/pix2pix-pytorch/0a1924f397b3f39d60e37fa40111036927bfbf01/result/199.png
--------------------------------------------------------------------------------
/result/99.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeeyoHuang/pix2pix-pytorch/0a1924f397b3f39d60e37fa40111036927bfbf01/result/99.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
from torchvision.utils import save_image

def sample_images(generator, test_dataloader, args, epoch, batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(test_dataloader))
    real_A = Variable(imgs['A'].type(torch.FloatTensor).cuda())
    real_B = Variable(imgs['B'].type(torch.FloatTensor).cuda())
    fake_B = generator(real_A)
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    save_image(img_sample, '%s-%s/%s/%s-%s.png' % (args.exp_name, args.dataset_name, args.img_result_dir, batches_done, epoch), nrow=5, normalize=True)
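

# Grid layout note (hedged, added for illustration): concatenating along
# dim=-2 stacks real_A / fake_B / real_B vertically per sample, so with
# nrow=5 each cell of the saved grid reads input -> output -> ground truth.
def _grid_layout_demo():
    a = torch.zeros(10, 3, 256, 256)
    print(torch.cat((a, a, a), -2).shape)  # torch.Size([10, 3, 768, 256])
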
class LambdaLR():
    def __init__(self, epoch_num, epoch_start, decay_start_epoch):
        assert (epoch_num - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.epoch_num = epoch_num
        self.epoch_start = epoch_start
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Constant factor of 1.0 until decay_start_epoch, then linear decay to 0 at epoch_num
        return 1.0 - max(0, epoch + 1 + self.epoch_start - self.decay_start_epoch) / (self.epoch_num - self.decay_start_epoch)
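

# Hedged wiring sketch (added for illustration), matching the commented-out
# lr_scheduler_* lines in pix2pix_train.py: wrap step() in torch's own
# LambdaLR so the optimizer's lr is scaled by the decay factor each epoch.
if __name__ == '__main__':
    model = torch.nn.Linear(1, 1)
    opt = torch.optim.Adam(model.parameters(), lr=2e-4)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=LambdaLR(200, 0, 100).step)
    for e in (0, 99, 100, 150, 199):
        print(e, LambdaLR(200, 0, 100).step(e))  # 1.0 until epoch 99, then linear to 0.0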