├── .gitignore ├── README.md ├── input └── train │ ├── 000000000009.jpg │ ├── 000000000025.jpg │ ├── 000000000030.jpg │ ├── 000000000034.jpg │ ├── 000000000036.jpg │ ├── 000000000042.jpg │ ├── 000000000049.jpg │ ├── 000000000061.jpg │ ├── 000000000064.jpg │ └── 000000000071.jpg └── src ├── config.py ├── dataset.py ├── engine.py ├── model.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | input/ 131 | models/ 132 | checkpoint/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | This is an experimental implementation of **ChromaGAN: An Adversarial Approach for Picture Colorization** in PyTorch. 
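3 | 
4 | A minimal way to launch training, assuming the sample images are in `input/train/` (the paths in `src/config.py` are relative to `src/`, so run from there):
5 | 
6 | ```bash
7 | cd src
8 | python train.py
9 | ```
10 | 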
11 | You can find the official implementation at https://github.com/pvitoria/ChromaGAN
12 | You can also read the research paper at https://arxiv.org/pdf/1907.09837.pdf
--------------------------------------------------------------------------------
/input/train/000000000009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000009.jpg
--------------------------------------------------------------------------------
/input/train/000000000025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000025.jpg
--------------------------------------------------------------------------------
/input/train/000000000030.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000030.jpg
--------------------------------------------------------------------------------
/input/train/000000000034.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000034.jpg
--------------------------------------------------------------------------------
/input/train/000000000036.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000036.jpg
--------------------------------------------------------------------------------
/input/train/000000000042.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000042.jpg
--------------------------------------------------------------------------------
/input/train/000000000049.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000049.jpg
--------------------------------------------------------------------------------
/input/train/000000000061.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000061.jpg
--------------------------------------------------------------------------------
/input/train/000000000064.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000064.jpg
--------------------------------------------------------------------------------
/input/train/000000000071.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sambaths/ChromaGAN_PyTorch/8379274ce876da3385c2d0dbc7053be172122f2e/input/train/000000000071.jpg
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | USE_TPU = False
2 | MULTI_CORE = False
3 | 
4 | import os
5 | import torch
6 | 
7 | DATA_DIR = '../input/'
8 | OUT_DIR = '../result/'
9 | MODEL_DIR = '../models/'
10 | CHECKPOINT_DIR = '../checkpoint/'
11 | 
12 | TRAIN_DIR = DATA_DIR + "train/"  # UPDATE to point at your training images
13 | TEST_DIR = DATA_DIR + "test/"    # UPDATE to point at your test images
14 | 
15 | os.makedirs(TRAIN_DIR, exist_ok=True)
16 | os.makedirs(TEST_DIR, exist_ok=True)
17 | os.makedirs(MODEL_DIR, exist_ok=True)
18 | os.makedirs(CHECKPOINT_DIR, exist_ok=True)
19 | os.makedirs(OUT_DIR, exist_ok=True)
20 | 
21 | # DATA INFORMATION
22 | IMAGE_SIZE = 224
23 | BATCH_SIZE = 1
24 | GRADIENT_PENALTY_WEIGHT = 10
25 | NUM_EPOCHS = 10
26 | KEEP_CKPT = 2
27 | 
28 | if USE_TPU:
29 |     import torch_xla.core.xla_model as xm
30 |     if not MULTI_CORE:
31 |         DEVICE = xm.xla_device()
32 |     # With MULTI_CORE, each TPU process picks its own device in train.py.
33 | else:
34 |     if torch.cuda.is_available():
35 |         DEVICE = torch.device('cuda')
36 |     else:
37 |         DEVICE = torch.device('cpu')
38 | 
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import config
4 | import numpy as np
5 | 
6 | 
7 | class DATA():
8 |     def __init__(self, dirname, max_len=None):
9 |         self.dir_path = dirname
10 |         self.filelist = os.listdir(self.dir_path)[:max_len]
11 |         self.batch_size = config.BATCH_SIZE
12 |         self.size = len(self.filelist)
13 |         self.data_index = 0
14 | 
15 |     def __len__(self):
16 |         return len(self.filelist)
17 | 
18 |     def __getitem__(self, item):
19 |         filename = os.path.join(self.dir_path, self.filelist[item])
20 |         greyimg, colorimg = self.read_img(filename)
21 |         img = np.asarray(greyimg) / 255     # values between 0 and 1
22 |         label = np.asarray(colorimg) / 255  # values between 0 and 1
23 |         return img, label, self.filelist[item]
24 | 
25 |     def read_img(self, filename):
26 |         # Centre-crop to a square, resize, and split into L and ab channels.
27 |         img = cv2.imread(filename, cv2.IMREAD_COLOR)
28 |         height, width, channels = img.shape
29 |         min_hw = int(min(height, width) / 2)
30 |         img = img[int(height/2)-min_hw:int(height/2)+min_hw, int(width/2)-min_hw:int(width/2)+min_hw, :]
31 |         # OpenCV loads images in BGR order, so BGR2Lab is the correct conversion.
32 |         labimg = cv2.cvtColor(cv2.resize(img, (config.IMAGE_SIZE, config.IMAGE_SIZE)), cv2.COLOR_BGR2Lab)
33 |         return np.reshape(labimg[:, :, 0], (1, config.IMAGE_SIZE, config.IMAGE_SIZE)), np.reshape(labimg[:, :, 1:], (2, config.IMAGE_SIZE, config.IMAGE_SIZE))
34 | 
35 |     def generate_batch(self):
36 |         batch = []
37 |         labels = []
38 |         filelist = []
39 |         for i in range(self.batch_size):
40 |             filename = os.path.join(self.dir_path, self.filelist[self.data_index])
41 |             filelist.append(self.filelist[self.data_index])
42 |             greyimg, colorimg = self.read_img(filename)
43 |             batch.append(greyimg)
44 |             labels.append(colorimg)
45 |             self.data_index = (self.data_index + 1) % self.size
46 |         batch = np.asarray(batch) / 255    # values between 0 and 1
47 |         labels = np.asarray(labels) / 255  # values between 0 and 1
48 |         return batch, labels, filelist
49 | 
--------------------------------------------------------------------------------
/src/engine.py:
--------------------------------------------------------------------------------
1 | import config
2 | 
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | 
10 | if config.USE_TPU:
11 |     import torch_xla
12 |     import torch_xla.core.xla_model as xm
13 |     import torch_xla.distributed.parallel_loader as pl
14 |     import torch_xla.distributed.xla_multiprocessing as xmp
15 | 
16 | 
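17 | # Training step for one epoch: the generator minimises a weighted sum of an L2
18 | # colour loss, a KL loss tying its class vector to VGG16's prediction
19 | # (weight 0.003) and a WGAN adversarial term (weight 0.1); the discriminator
20 | # (critic) is trained with WGAN-GP using config.GRADIENT_PENALTY_WEIGHT.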
21 | def train(train_loader, GAN_Model, netD, VGG_MODEL, optG, optD, device, losses):
22 |     batch = 0
23 | 
24 |     def wgan_loss(prediction, real_or_not):
25 |         # WGAN critic loss: maximise scores on real samples, minimise on fakes.
26 |         if real_or_not:
27 |             return -torch.mean(prediction.float())
28 |         else:
29 |             return torch.mean(prediction.float())
30 | 
31 |     def gp_loss(y_pred, averaged_samples, gradient_penalty_weight):
32 |         # WGAN-GP gradient penalty on samples interpolated between real and fake.
33 |         gradients = torch.autograd.grad(y_pred, averaged_samples,
34 |                                         grad_outputs=torch.ones(y_pred.size(), device=device),
35 |                                         create_graph=True, retain_graph=True, only_inputs=True)[0]
36 |         gradients = gradients.view(gradients.size(0), -1)
37 |         gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - 1) ** 2).mean() * gradient_penalty_weight
38 |         return gradient_penalty
39 | 
40 |     for trainL, trainAB, _ in tqdm(train_loader):
41 |         batch += 1
42 | 
43 |         trainL_3 = trainL.repeat(1, 3, 1, 1).to(device)  # replicate the L channel to 3 channels for VGG16
44 |         trainL = trainL.to(device).double()
45 |         trainAB = trainAB.to(device).double()
46 | 
47 |         predictVGG = F.softmax(VGG_MODEL(trainL_3), dim=1)
48 | 
49 |         ############ GAN MODEL (Training Generator) ####################
50 |         optG.zero_grad()
51 |         predAB, classVector, discpred = GAN_Model(trainL, trainL_3)
52 |         D_G_z1 = discpred.mean().item()
53 |         # 'batchmean' gives the KL divergence summed over classes and averaged
54 |         # over the batch (size_average='False' is a truthy string, not False).
55 |         Loss_KLD = nn.KLDivLoss(reduction='batchmean')(classVector.log().float(), predictVGG.detach().float()) * 0.003
56 |         Loss_MSE = nn.MSELoss()(predAB.float(), trainAB.float())
57 |         Loss_WL = wgan_loss(discpred.float(), True) * 0.1
58 |         Loss_G = Loss_KLD + Loss_MSE + Loss_WL
59 |         Loss_G.backward()
60 | 
61 |         if config.USE_TPU:
62 |             if config.MULTI_CORE:
63 |                 xm.optimizer_step(optG)
64 |             else:
65 |                 xm.optimizer_step(optG, barrier=True)
66 |         else:
67 |             optG.step()
68 | 
69 |         losses['G_losses'].append(Loss_G.item())
70 |         losses['EPOCH_G_losses'].append(Loss_G.item())
71 | 
72 |         ################################################################
73 | 
74 |         ############### Discriminator Training #########################
75 |         for param in netD.parameters():
76 |             param.requires_grad = True
77 | 
78 |         optD.zero_grad()
79 |         predLAB = torch.cat([trainL, predAB], dim=1)
80 |         discpred = netD(predLAB.detach())
81 |         D_G_z2 = discpred.mean().item()
82 |         realLAB = torch.cat([trainL, trainAB], dim=1)
83 |         discreal = netD(realLAB)
84 |         D_x = discreal.mean().item()
85 | 
86 |         # Uniform weights in [0, 1) for the real/fake interpolation of WGAN-GP.
87 |         weights = torch.rand((trainAB.size(0), 1, 1, 1), device=device)
88 |         averaged_samples = (weights * trainAB) + ((1 - weights) * predAB.detach())
89 |         averaged_samples.requires_grad_(True)
90 |         avg_img = torch.cat([trainL, averaged_samples], dim=1)
91 |         discavg = netD(avg_img)
92 | 
93 |         Loss_D_Fake = wgan_loss(discpred, False)
94 |         Loss_D_Real = wgan_loss(discreal, True)
95 |         Loss_D_avg = gp_loss(discavg, averaged_samples, config.GRADIENT_PENALTY_WEIGHT)
96 | 
97 |         Loss_D = Loss_D_Fake + Loss_D_Real + Loss_D_avg
98 |         Loss_D.backward()
99 |         if config.USE_TPU:
100 |             if config.MULTI_CORE:
101 |                 xm.optimizer_step(optD)
102 |             else:
103 |                 xm.optimizer_step(optD, barrier=True)
104 |         else:
105 |             optD.step()
106 | 
107 |         losses['D_losses'].append(Loss_D.item())
108 |         losses['EPOCH_D_losses'].append(Loss_D.item())
109 |         # Output training stats
110 |         if batch % 100 == 0:
111 |             print('Loss_D: %.8f | Loss_G: %.8f | D(x): %.8f | D(G(z)): %.8f / %.8f | MSE: %.8f | KLD: %.8f | WGAN_F(G): %.8f | WGAN_F(D): %.8f | WGAN_R(D): %.8f | WGAN_A(D): %.8f'
112 |                   % (Loss_D.item(), Loss_G.item(), D_x, D_G_z1, D_G_z2, Loss_MSE.item(), Loss_KLD.item(), Loss_WL.item(), Loss_D_Fake.item(), Loss_D_Real.item(), Loss_D_avg.item()))
113 | 
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | 
6 | 
7 | bias = True
8 | 
9 | class discriminator_model(nn.Module):
10 | 
11 |     def __init__(self):
12 |         super(discriminator_model, self).__init__()
13 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=(4,4), padding=1, stride=(2,2), bias=bias)    # 64, 112, 112
14 |         self.conv2 = nn.Conv2d(64, 128, kernel_size=(4,4), padding=1, stride=(2,2), bias=bias)  # 128, 56, 56
15 |         self.conv3 = nn.Conv2d(128, 256, kernel_size=(4,4), padding=1, stride=(2,2), bias=bias) # 256, 28, 28
16 |         self.conv4 = nn.Conv2d(256, 512, kernel_size=(4,4), padding=3, stride=(1,1), bias=bias) # 512, 31, 31
17 |         self.conv5 = nn.Conv2d(512, 1, kernel_size=(4,4), padding=3, stride=(1,1), bias=bias)   # 1, 34, 34
18 |         self.leaky_relu = nn.LeakyReLU(0.3)
19 | 
20 |     def forward(self, input):
21 | 
22 |         net = self.conv1(input)     #[-1, 64, 112, 112]
23 |         net = self.leaky_relu(net)  #[-1, 64, 112, 112]
24 |         net = self.conv2(net)       #[-1, 128, 56, 56]
25 |         net = self.leaky_relu(net)  #[-1, 128, 56, 56]
26 |         net = self.conv3(net)       #[-1, 256, 28, 28]
27 |         net = self.leaky_relu(net)  #[-1, 256, 28, 28]
28 |         net = self.conv4(net)       #[-1, 512, 31, 31], since (28 + 2*3 - 4) + 1 = 31
29 |         net = self.leaky_relu(net)  #[-1, 512, 31, 31]
30 |         net = self.conv5(net)       #[-1, 1, 34, 34]
31 |         return net
32 | 
33 | class colorization_model(nn.Module):
34 |     def __init__(self):
35 |         super(colorization_model, self).__init__()
36 | 
37 |         self.VGG_model = torchvision.models.vgg16(pretrained=True)
38 |         self.VGG_model = nn.Sequential(*list(self.VGG_model.features.children())[:-8]) #[None, 512, 28, 28]
39 |         self.VGG_model = self.VGG_model.double()
40 |         self.relu = nn.ReLU()
41 |         self.lrelu = nn.LeakyReLU(0.3)
42 |         self.global_features_conv1 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(2,2), bias=bias) #[None, 512, 14, 14]
43 |         self.global_features_bn1 = nn.BatchNorm2d(512, eps=0.001, momentum=0.01)  # PyTorch momentum = 1 - Keras momentum (0.99)
44 |         self.global_features_conv2 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias) #[None, 512, 14, 14]
45 |         self.global_features_bn2 = nn.BatchNorm2d(512, eps=0.001, momentum=0.01)
46 |         self.global_features_conv3 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(2,2), bias=bias) #[None, 512, 7, 7]
47 |         self.global_features_bn3 = nn.BatchNorm2d(512, eps=0.001, momentum=0.01)
48 |         self.global_features_conv4 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias) #[None, 512, 7, 7]
49 |         self.global_features_bn4 = nn.BatchNorm2d(512, eps=0.001, momentum=0.01)
50 | 
51 |         self.global_features2_flatten = nn.Flatten()
52 |         self.global_features2_dense1 = nn.Linear(512*7*7, 1024)
53 |         self.midlevel_conv1 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias) #[None, 512, 28, 28]
54 |         self.global_features2_dense2 = nn.Linear(1024, 512)
55 |         self.midlevel_bn1 = nn.BatchNorm2d(512, eps=0.001, momentum=0.01)
56 |         self.global_features2_dense3 = nn.Linear(512, 256)
57 |         self.midlevel_conv2 = nn.Conv2d(512, 256, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
58 | 
59 |         self.midlevel_bn2 = nn.BatchNorm2d(256, eps=0.001, momentum=0.01)
60 | 
61 |         #[None, 256, 28, 28]
62 | 
63 | 
64 |         self.global_featuresClass_flatten = nn.Flatten()
65 |         self.global_featuresClass_dense1 = nn.Linear(512*7*7, 4096)
66 |         self.global_featuresClass_dense2 = nn.Linear(4096, 4096)
67 |         self.global_featuresClass_dense3 = nn.Linear(4096, 1000)
68 |         self.softmax = nn.Softmax(dim=1)  # distribution over the 1000 classes
69 | 
70 |         self.outputmodel_conv1 = nn.Conv2d(512, 256, kernel_size=(1,1), padding=0, stride=(1,1), bias=bias)
71 |         self.outputmodel_conv2 = nn.Conv2d(256, 128, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
72 |         self.outputmodel_conv3 = nn.Conv2d(128, 64, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
73 |         self.outputmodel_conv4 = nn.Conv2d(64, 64, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
74 |         self.outputmodel_conv5 = nn.Conv2d(64, 32, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
75 |         self.outputmodel_conv6 = nn.Conv2d(32, 2, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
76 |         self.outputmodel_upsample = nn.Upsample(scale_factor=(2,2))
77 |         self.outputmodel_bn1 = nn.BatchNorm2d(128)
78 |         self.outputmodel_bn2 = nn.BatchNorm2d(64)
79 |         self.sigmoid = nn.Sigmoid()
80 |         self.tanh = nn.Tanh()
81 | 
82 |     def forward(self, input_img):
83 | 
84 |         # VGG16 backbone without the top layers
85 | 
86 |         vgg_out = self.VGG_model(input_img.double())  # input_img is already a tensor
87 | 
88 |         # Global Features
89 | 
90 |         global_features = self.relu(self.global_features_conv1(vgg_out)) #[None, 512, 14, 14]
91 |         global_features = self.global_features_bn1(global_features) #[None, 512, 14, 14]
92 |         global_features = self.relu(self.global_features_conv2(global_features)) #[None, 512, 14, 14]
93 |         global_features = self.global_features_bn2(global_features) #[None, 512, 14, 14]
94 | 
95 |         global_features = self.relu(self.global_features_conv3(global_features)) #[None, 512, 7, 7]
96 |         global_features = self.global_features_bn3(global_features) #[None, 512, 7, 7]
97 |         global_features = self.relu(self.global_features_conv4(global_features)) #[None, 512, 7, 7]
98 |         global_features = self.global_features_bn4(global_features) #[None, 512, 7, 7]
99 | 
100 |         global_features2 = self.global_features2_flatten(global_features) #[None, 512*7*7]
101 | 
102 |         global_features2 = self.global_features2_dense1(global_features2) #[None, 1024]
103 |         global_features2 = self.global_features2_dense2(global_features2) #[None, 512]
104 |         global_features2 = self.global_features2_dense3(global_features2) #[None, 256]
105 |         global_features2 = global_features2.unsqueeze(2).expand(-1, 256, 28*28) #[None, 256, 784]
106 |         global_features2 = global_features2.view((-1, 256, 28, 28)) #[None, 256, 28, 28]
107 | 
108 |         global_featureClass = self.global_featuresClass_flatten(global_features) #[None, 512*7*7]
109 |         global_featureClass = self.global_featuresClass_dense1(global_featureClass) #[None, 4096]
110 |         global_featureClass = self.global_featuresClass_dense2(global_featureClass) #[None, 4096]
111 |         global_featureClass = self.softmax(self.global_featuresClass_dense3(global_featureClass)) #[None, 1000]
112 | 
113 |         # Mid-Level Features
114 |         midlevel_features = self.midlevel_conv1(vgg_out.double()) #[None, 512, 28, 28]
115 |         midlevel_features = self.midlevel_bn1(midlevel_features) #[None, 512, 28, 28]
116 |         midlevel_features = self.midlevel_conv2(midlevel_features) #[None, 256, 28, 28]
117 |         midlevel_features = self.midlevel_bn2(midlevel_features) #[None, 256, 28, 28]
118 | 
119 |         # Fusion of (VGG16 + MidLevel) and (VGG16 + Global)
120 | 
121 |         modelFusion = torch.cat([midlevel_features,
global_features2],dim=1) 122 | 123 | # Fusion Colorization 124 | 125 | outputmodel = self.relu(self.outputmodel_conv1(modelFusion)) # None, 256, 28, 28 126 | outputmodel = self.relu(self.outputmodel_conv2(outputmodel)) # None, 128, 28, 28 127 | 128 | outputmodel = self.outputmodel_upsample(outputmodel) # None, 128, 56, 56 129 | outputmodel = self.outputmodel_bn1(outputmodel) # None, 128, 56, 56 130 | outputmodel = self.relu(self.outputmodel_conv3(outputmodel)) # None, 64, 56, 56 131 | outputmodel = self.relu(self.outputmodel_conv4(outputmodel)) # None, 64, 56, 56 132 | 133 | outputmodel = self.outputmodel_upsample(outputmodel) # None, 64, 112, 112 134 | outputmodel = self.outputmodel_bn2(outputmodel) # None, 64, 112, 112 135 | outputmodel = self.relu(self.outputmodel_conv5(outputmodel)) # None, 32, 112, 112 136 | outputmodel = self.sigmoid(self.outputmodel_conv6(outputmodel)) # None, 2, 112, 112 137 | outputmodel = self.outputmodel_upsample(outputmodel) # None, 2, 224, 224 138 | 139 | return outputmodel, global_featureClass 140 | 141 | 142 | class GAN(nn.Module): 143 | def __init__(self, netG, netD): 144 | super(GAN, self).__init__() 145 | 146 | self.netG = netG 147 | self.netD = netD 148 | 149 | def forward(self, trainL, trainL_3): 150 | 151 | for param in self.netD.parameters(): 152 | param.requires_grad= False 153 | 154 | predAB, classVector = self.netG(trainL_3) 155 | predLAB = torch.cat([trainL, predAB], dim=1) 156 | discpred = self.netD(predLAB) 157 | 158 | return predAB, classVector, discpred 159 | 160 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | 2 | import model 3 | import config 4 | import dataset 5 | import utils 6 | import engine 7 | 8 | import time 9 | import torch 10 | import torchvision 11 | import warnings 12 | warnings.filterwarnings('ignore') 13 | 14 | import gc 15 | 16 | if config.USE_TPU: 17 | import torch_xla 18 | import torch_xla.core.xla_model as xm 19 | import torch_xla.distributed.parallel_loader as pl 20 | import torch_xla.distributed.xla_multiprocessing as xmp 21 | 22 | 23 | def map_fn(index=None, flags=None): 24 | torch.set_default_tensor_type('torch.FloatTensor') 25 | torch.manual_seed(1234) 26 | 27 | train_data = dataset.DATA(config.TRAIN_DIR) 28 | 29 | if config.MULTI_CORE: 30 | train_sampler = torch.utils.data.distributed.DistributedSampler( 31 | train_data, 32 | num_replicas=xm.xrt_world_size(), 33 | rank=xm.get_ordinal(), 34 | shuffle=True) 35 | else: 36 | train_sampler = torch.utils.data.RandomSampler(train_data) 37 | 38 | train_loader = torch.utils.data.DataLoader( 39 | train_data, 40 | batch_size=flags['batch_size'] if config.MULTI_CORE else config.BATCH_SIZE, 41 | sampler=train_sampler, 42 | num_workers=flags['num_workers'] if config.MULTI_CORE else 4, 43 | drop_last=True, 44 | pin_memory=True) 45 | 46 | if config.MULTI_CORE: 47 | DEVICE = xm.xla_device() 48 | else: 49 | DEVICE = config.DEVICE 50 | 51 | 52 | netG = model.colorization_model().double() 53 | netD = model.discriminator_model().double() 54 | 55 | VGG_modelF = torchvision.models.vgg16(pretrained=True).double() 56 | VGG_modelF.requires_grad_(False) 57 | 58 | netG = netG.to(DEVICE) 59 | netD = netD.to(DEVICE) 60 | 61 | VGG_modelF = VGG_modelF.to(DEVICE) 62 | 63 | optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999)) 64 | optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999)) 65 | 66 | ## Trains 67 | train_start = time.time() 
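68 |     # The losses dict collects per-iteration G/D losses for the whole run plus
69 |     # per-epoch buffers that are reset below; engine.train() appends into it.
70 |     # Training resumes from the newest checkpoint in CHECKPOINT_DIR if one exists.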
71 |     losses = {
72 |         'G_losses' : [],
73 |         'D_losses' : [],
74 |         'EPOCH_G_losses' : [],
75 |         'EPOCH_D_losses' : [],
76 |         'G_losses_eval' : []
77 |     }
78 | 
79 |     netG, optG, netD, optD, epoch_checkpoint = utils.load_checkpoint(config.CHECKPOINT_DIR, netG, optG, netD, optD, DEVICE)
80 |     netGAN = model.GAN(netG, netD)
81 |     for epoch in range(epoch_checkpoint, flags['num_epochs']+1 if config.MULTI_CORE else config.NUM_EPOCHS+1):
82 |         print('\n')
83 |         print('#'*8, f'EPOCH-{epoch}', '#'*8)
84 |         losses['EPOCH_G_losses'] = []
85 |         losses['EPOCH_D_losses'] = []
86 |         if config.MULTI_CORE:
87 |             para_train_loader = pl.ParallelLoader(train_loader, [DEVICE]).per_device_loader(DEVICE)
88 |             engine.train(para_train_loader, netGAN, netD, VGG_modelF, optG, optD, device=DEVICE, losses=losses)
89 |             elapsed_train_time = time.time() - train_start
90 |             print("Process", index, "finished training. Train time was:", elapsed_train_time)
91 |         else:
92 |             engine.train(train_loader, netGAN, netD, VGG_modelF, optG, optD, device=DEVICE, losses=losses)
93 |         ######################### CHECKPOINTING ################################
94 |         utils.create_checkpoint(epoch, netG, optG, netD, optD, max_checkpoint=config.KEEP_CKPT, save_path=config.CHECKPOINT_DIR)
95 |         ########################################################################
96 |         utils.plot_some(train_data, netG, DEVICE, epoch)
97 |         gc.collect()
98 | 
99 | 
100 | def run():
101 |     # Configures training (and evaluation) parameters.
102 |     if config.MULTI_CORE:
103 |         flags = {}
104 |         flags['batch_size'] = config.BATCH_SIZE
105 |         flags['num_workers'] = 8
106 |         flags['num_epochs'] = config.NUM_EPOCHS
107 |         flags['seed'] = 1234
108 |         xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')
109 |     else:
110 |         map_fn()
111 | 
112 | 
113 | if __name__ == '__main__':
114 |     run()
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import cv2
4 | import torch
5 | import config
6 | import glob
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | 
10 | if config.USE_TPU:
11 |     import torch_xla.core.xla_model as xm
12 | 
13 | def preprocess(imgs):
14 |     try:
15 |         imgs = imgs.detach().numpy()
16 |     except AttributeError:
17 |         pass  # already a numpy array
18 |     imgs = imgs * 255
19 |     imgs[imgs > 255] = 255
20 |     imgs[imgs < 0] = 0
21 |     return imgs.astype(np.uint8)
22 | 
23 | def reconstruct(batchX, predictedY, filelist):
24 | 
25 |     batchX = batchX.reshape(224, 224, 1)
26 |     predictedY = predictedY.reshape(224, 224, 2)
27 |     result = np.concatenate((batchX, predictedY), axis=2)
28 |     result = cv2.cvtColor(result, cv2.COLOR_Lab2BGR)  # cv2.imwrite expects BGR
29 |     save_results_path = config.OUT_DIR
30 |     if not os.path.exists(save_results_path):
31 |         os.makedirs(save_results_path)
32 |     save_path = os.path.join(save_results_path, filelist + "_reconstructed.jpg")
33 |     cv2.imwrite(save_path, result)
34 |     return result
35 | 
36 | def reconstruct_no(batchX, predictedY):
37 | 
38 |     batchX = batchX.reshape(224, 224, 1)
39 |     predictedY = predictedY.reshape(224, 224, 2)
40 |     result = np.concatenate((batchX, predictedY), axis=2)
41 |     result = cv2.cvtColor(result, cv2.COLOR_Lab2RGB)  # RGB for matplotlib display
42 |     return result
43 | 
44 | 
45 | def image_grid(axrow, orig, batchL, preds, epoch):
46 |     fig, ax = plt.subplots(1, 3, figsize=(15, 15))
47 |     ax[0].imshow(orig)
48 |     ax[0].set_title('Original Image')
49 | 
50 |     ax[1].imshow(np.tile(batchL, (1, 1, 3)))
51 |     ax[1].set_title('L Image with Channels repeated (Input)')
52 | 
53 |     ax[2].imshow(preds)
54 |     ax[2].set_title('Predicted Image')
55 |     plt.savefig(f'sample_preds_{epoch}')
56 |     plt.close()
57 | 
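58 | # plot_some draws a fixed trio of training samples (indexes 0, 2 and 9, so the
59 | # training folder needs at least 10 images) and saves a sample_preds_{epoch}.png
60 | # grid via image_grid after every epoch.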
61 | def plot_some(test_data, colorization_model, device, epoch):
62 |     with torch.no_grad():
63 |         indexes = [0, 2, 9]
64 |         for idx in indexes:
65 |             batchL, realAB, filename = test_data[idx]
66 |             filepath = config.TRAIN_DIR + filename
67 |             batchL = batchL.reshape(1, 1, 224, 224)
68 |             realAB = realAB.reshape(1, 2, 224, 224)
69 |             batchL_3 = torch.tensor(np.tile(batchL, [1, 3, 1, 1]))
70 |             batchL_3 = batchL_3.to(device)
71 |             batchL = torch.tensor(batchL).to(device).double()
72 |             realAB = torch.tensor(realAB).to(device).double()
73 | 
74 |             colorization_model.eval()
75 |             batch_predAB, _ = colorization_model(batchL_3)
76 |             batch_predAB = batch_predAB.cpu().numpy().reshape((224, 224, 2))
77 |             batchL = batchL.cpu().numpy().reshape((224, 224, 1))
78 |             realAB = realAB.cpu().numpy().reshape((224, 224, 2))
79 |             orig = cv2.imread(filepath)
80 |             orig = cv2.resize(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB), (224, 224))
81 |             preds = reconstruct_no(preprocess(batchL), preprocess(batch_predAB))
82 |             image_grid(0, orig, batchL, preds, epoch)
83 |             plt.show()
84 | 
85 | def create_checkpoint(epoch, netG, optG, netD, optD, max_checkpoint, save_path=config.CHECKPOINT_DIR):
86 |     print('Saving Model and Optimizer weights.....')
87 |     checkpoint = {
88 |         'epoch' : epoch,
89 |         'generator_state_dict' : netG.state_dict(),
90 |         'generator_optimizer' : optG.state_dict(),
91 |         'discriminator_state_dict' : netD.state_dict(),
92 |         'discriminator_optimizer' : optD.state_dict()
93 |     }
94 |     if config.USE_TPU:
95 |         xm.save(checkpoint, f'{save_path}{epoch}_checkpoint.pt')
96 |     else:
97 |         torch.save(checkpoint, f'{save_path}{epoch}_checkpoint.pt')
98 |     print('Weights Saved !!')
99 |     del checkpoint
100 |     # Keep only the newest max_checkpoint files.
101 |     files = glob.glob(os.path.expanduser(f"{save_path}*"))
102 |     sorted_files = sorted(files, key=lambda t: -os.stat(t).st_mtime)
103 |     if len(sorted_files) > max_checkpoint:
104 |         os.remove(sorted_files[-1])
105 | 
106 | 
107 | def load_checkpoint(checkpoint_directory, netG, optG, netD, optD, device):
108 |     load_from_checkpoint = False
109 |     files = glob.glob(os.path.expanduser(f"{checkpoint_directory}*"))
110 |     for file in files:
111 |         if file.endswith('.pt'):
112 |             load_from_checkpoint = True
113 |             break
114 | 
115 |     if load_from_checkpoint:
116 |         print('Loading Model and optimizer states from checkpoint....')
117 |         sorted_files = sorted(files, key=lambda t: -os.stat(t).st_mtime)
118 |         checkpoint = torch.load(sorted_files[0])
119 |         epoch_checkpoint = checkpoint['epoch'] + 1
120 |         netG.load_state_dict(checkpoint['generator_state_dict'])
121 |         netG.to(device)
122 |         optG.load_state_dict(checkpoint['generator_optimizer'])
123 |         netD.load_state_dict(checkpoint['discriminator_state_dict'])
124 |         netD.to(device)
125 |         optD.load_state_dict(checkpoint['discriminator_optimizer'])
126 |         print('Loaded states!')
127 |         print(f'It looks like these states belong to epoch {epoch_checkpoint-1},')
128 |         print(f'so the model will train for {config.NUM_EPOCHS - (epoch_checkpoint-1)} more epochs.')
129 |         print('If you want to train for more epochs, change "NUM_EPOCHS" in config.py!')
130 |         return netG, optG, netD, optD, epoch_checkpoint
131 |     else:
132 |         print('There are no checkpoints in the given directory; the model will train from scratch.')
133 |         epoch_checkpoint = 1
134 |         return netG, optG, netD, optD, epoch_checkpoint
135 | 
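136 | 
137 | # plot_gan_loss is not called anywhere in the repo; it can be invoked after
138 | # training, e.g. plot_gan_loss(losses['G_losses'], losses['D_losses'], epoch).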
139 | def plot_gan_loss(G_losses, D_losses, epoch):
140 |     plt.figure(figsize=(10, 5))
141 |     plt.title("Generator and Discriminator Loss During Training")
142 |     plt.plot(G_losses, label="G")
143 |     plt.plot(D_losses, label="D")
144 |     plt.xlabel("iterations")
145 |     plt.ylabel("Loss")
146 |     plt.legend()
147 |     plt.savefig(f'GANLOSS{epoch}.pdf')  # plt.savefig takes no figsize argument
--------------------------------------------------------------------------------
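
The repository ships a training entry point but no inference script. The sketch below shows one way to colorize a single image with a trained generator, reusing the repo's own `dataset.DATA`, `utils.preprocess` and `utils.reconstruct_no`; the file name `src/colorize.py` is a hypothetical suggestion, not part of the repo, and it assumes it is run from `src/` with at least one checkpoint in `../checkpoint/`.

```python
# src/colorize.py -- hypothetical single-image inference sketch (not in the repo).
import glob
import os

import cv2
import numpy as np
import torch

import config
import dataset
import model
import utils

# Build the generator and load the generator weights from the newest checkpoint.
netG = model.colorization_model().double().to(config.DEVICE)
ckpts = sorted(glob.glob(f'{config.CHECKPOINT_DIR}*.pt'), key=os.path.getmtime)
state = torch.load(ckpts[-1], map_location=config.DEVICE)
netG.load_state_dict(state['generator_state_dict'])
netG.eval()

# DATA yields (L, ab, filename) with L shaped (1, 224, 224), scaled to [0, 1].
data = dataset.DATA(config.TRAIN_DIR)
batchL, _, filename = data[0]
batchL_3 = torch.tensor(np.tile(batchL.reshape(1, 1, 224, 224), [1, 3, 1, 1])).to(config.DEVICE)

with torch.no_grad():
    predAB, _ = netG(batchL_3)  # predicted ab channels, (1, 2, 224, 224)

# Stitch L and predicted ab back into an RGB image and save it (BGR for imwrite).
rgb = utils.reconstruct_no(utils.preprocess(batchL),
                           utils.preprocess(predAB.cpu().numpy().reshape((224, 224, 2))))
cv2.imwrite(f'{config.OUT_DIR}{filename}_colorized.jpg', cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
```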