├── .gitignore ├── LICENSE ├── README.md ├── data ├── .DS_Store ├── .ipynb_checkpoints │ └── ThinPlateSpline-checkpoint.ipynb └── tag2pix │ ├── keras_train │ ├── 3175012.png │ ├── 3175015.png │ ├── 3175017.png │ ├── 3175018.png │ ├── 3175027.png │ ├── 3175036.png │ ├── 3175044.png │ ├── 3175064.png │ ├── 3175131.png │ ├── 3175133.png │ ├── 3175134.png │ ├── 3175156.png │ ├── 3175209.png │ ├── 3175210.png │ ├── 3175211.png │ ├── 3175233.png │ ├── 3175234.png │ ├── 3175261.png │ ├── 999257.png │ └── 999262.png │ ├── rgb_cropped │ ├── 3175012.png │ ├── 3175015.png │ ├── 3175017.png │ ├── 3175018.png │ ├── 3175027.png │ ├── 3175036.png │ ├── 3175044.png │ ├── 3175064.png │ ├── 3175131.png │ ├── 3175133.png │ ├── 3175134.png │ ├── 3175156.png │ ├── 3175209.png │ ├── 3175210.png │ ├── 3175211.png │ ├── 3175233.png │ ├── 3175234.png │ ├── 3175261.png │ ├── 999257.png │ └── 999262.png │ └── xdog_train │ ├── 3175012.png │ ├── 3175015.png │ ├── 3175017.png │ ├── 3175018.png │ ├── 3175027.png │ ├── 3175036.png │ ├── 3175044.png │ ├── 3175064.png │ ├── 3175131.png │ ├── 3175133.png │ ├── 3175134.png │ ├── 3175156.png │ ├── 3175209.png │ ├── 3175210.png │ ├── 3175211.png │ ├── 3175233.png │ ├── 3175234.png │ ├── 3175261.png │ ├── 999257.png │ └── 999262.png ├── data_loader.py ├── loss.py ├── main.py ├── model.py ├── preprocess ├── 17938.png ├── __init__.py ├── __pycache__ │ └── xdog_blend.cpython-37.pyc ├── sketch_keras_util.py ├── xdog.py └── xdog_blend.py ├── scripts ├── train_tag2pix_keras.sh └── train_tag2pix_xdog.sh ├── solver.py ├── tmp.png ├── tps ├── __pycache__ │ └── numpy.cpython-37.pyc ├── numpy.py └── pytorch.py └── vgg.py /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | udon.sh 3 | deploy.sh 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Daichi Horita 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence [Lee+, CVPR20] 2 | https://openaccess.thecvf.com/content_CVPR_2020/papers/Lee_Reference-Based_Sketch_Image_Colorization_Using_Augmented-Self_Reference_and_Dense_Semantic_CVPR_2020_paper.pdf 3 | 4 | **Note that this is an ongoing re-implementation and I cannot fully reproduce the results. Suggestions and PRs are welcome!** 5 | 6 | # Requirements 7 | + Python 3.6+ 8 | + PyTorch 0.4+ 9 | 10 | # Usage 11 | 1. Download [a Tag2Pix dataset from the official repository](https://github.com/blandocs/Tag2Pix). 12 | 2. Put it in `./datasets/tag2pix`. 13 | 3. Run `bash scripts/train_tag2pix_xdog.sh baseline` to train with sketches extracted by XDoG. 14 | 4. Run `bash scripts/train_tag2pix_keras.sh baseline` to train with sketches extracted by SketchKeras. 15 | 16 | 17 | # LICENSE 18 | All code is licensed under the MIT license. 19 | 20 | # RELATED WORKS 21 | + Re-implementation: https://github.com/SerialLain3170/Colorization/tree/master/scft 22 | + https://github.com/MarkMoHR/Awesome-Image-Colorization 23 | 24 | # Acknowledgements 25 | This repository is based on https://github.com/yunjey/stargan. 26 | 27 | Additionally, if you use this repository, please cite the original paper: 28 | ``` 29 | @InProceedings{lee2020referencebased, 30 | title={Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence}, 31 | author={Junsoo Lee and Eungyeup Kim and Yunsung Lee and Dongjun Kim and Jaehyuk Chang and Jaegul Choo}, 32 | year={2020}, 33 | booktitle = {Proc.
IEEE Computer Vision and Pattern Recognition (CVPR)} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/.DS_Store -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175012.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175015.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175017.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175018.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175027.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175027.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175036.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175044.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175064.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175064.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175131.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175131.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175133.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175133.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175134.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175134.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175156.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175156.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175209.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175209.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175210.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175210.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175211.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175211.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175233.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175233.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175234.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175234.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/3175261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/3175261.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/999257.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/999257.png -------------------------------------------------------------------------------- /data/tag2pix/keras_train/999262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/keras_train/999262.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175012.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175015.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175017.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175018.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175027.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175027.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175036.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175044.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175064.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175064.png 
-------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175131.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175131.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175133.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175133.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175134.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175134.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175156.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175156.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175209.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175209.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175210.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175210.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175211.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175211.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175233.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175233.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175234.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175234.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/3175261.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/3175261.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/999257.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/999257.png -------------------------------------------------------------------------------- /data/tag2pix/rgb_cropped/999262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/rgb_cropped/999262.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175012.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175015.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175017.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175018.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175027.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175027.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175036.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175044.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175064.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175064.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175131.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175131.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175133.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175133.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175134.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175134.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175156.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175156.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175209.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175209.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175210.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175210.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175211.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175211.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175233.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175233.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175234.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175234.png 
-------------------------------------------------------------------------------- /data/tag2pix/xdog_train/3175261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/3175261.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/999257.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/999257.png -------------------------------------------------------------------------------- /data/tag2pix/xdog_train/999262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/data/tag2pix/xdog_train/999262.png -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torchvision.transforms import functional as tvF 6 | from torch.utils import data 7 | from torchvision import transforms as T 8 | from torchvision.datasets import ImageFolder 9 | from PIL import Image 10 | import torch 11 | import os 12 | import random 13 | import albumentations as A 14 | from sys import exit 15 | import math 16 | from tps.numpy import warp_image_cv 17 | 18 | def rot_crop(x): 19 | """return maximum width ratio of rotated image without letterbox""" 20 | x = abs(x) 21 | deg45 = math.pi * 0.25 22 | deg135 = math.pi * 0.75 23 | x = x * math.pi / 180 24 | a = (math.sin(deg135 - x) - math.sin(deg45 - x))/(math.cos(deg135-x)-math.cos(deg45-x)) 25 | return math.sqrt(2) * (math.sin(deg45-x) - a*math.cos(deg45-x)) / (1-a) 26 | 27 | class RandomFRC(T.RandomResizedCrop): 28 | """RandomHorizontalFlip + RandomRotation + RandomResizedCrop 2 images""" 29 | def __call__(self, img1, img2): 30 | img1 = tvF.resize(img1, self.size, interpolation=Image.LANCZOS) 31 | img2 = tvF.resize(img2, self.size, interpolation=Image.LANCZOS) 32 | if random.random() < 0.5: 33 | img1 = tvF.hflip(img1) 34 | img2 = tvF.hflip(img2) 35 | if random.random() < 0.5: 36 | rot = random.uniform(-10, 10) 37 | crop_ratio = rot_crop(rot) 38 | img1 = tvF.rotate(img1, rot, resample=Image.BILINEAR) 39 | img2 = tvF.rotate(img2, rot, resample=Image.BILINEAR) 40 | img1 = tvF.center_crop(img1, int(img1.size[0] * crop_ratio)) 41 | img2 = tvF.center_crop(img2, int(img2.size[0] * crop_ratio)) 42 | 43 | i, j, h, w = self.get_params(img1, self.scale, self.ratio) 44 | 45 | # return the image with the same transformation 46 | return (tvF.resized_crop(img1, i, j, h, w, self.size, self.interpolation), 47 | tvF.resized_crop(img2, i, j, h, w, self.size, self.interpolation)) 48 | 49 | class Dataset(data.Dataset): 50 | def __init__(self, image_dir, line_dir, transform_common, transform_a, transform_line, transform_original): 51 | 52 | self.image_dir = image_dir 53 | self.line_dir = line_dir 54 | 55 | self.transform_common = transform_common 56 | self.transform_a = transform_a 57 | self.transform_line = transform_line 58 | self.transform_original = transform_original 59 | 60 | self.ids = 
[f.split('/')[-1] for f in glob(os.path.join(line_dir, '*.png'))] 61 | 62 | def __getitem__(self, index): 63 | filename = self.ids[index] 64 | 65 | image_path = os.path.join(self.image_dir, filename) 66 | line_path = os.path.join(self.line_dir, filename) 67 | 68 | image = Image.open(image_path).convert('RGB') 69 | line = Image.open(line_path).convert('L') 70 | 71 | image_ori, line = self.transform_common(image, line) 72 | 73 | I_original = self.transform_original(image_ori) 74 | 75 | I_gt = self.transform_a(image_ori) 76 | 77 | line = self.transform_line(line) 78 | 79 | if random.random() <= 0.9: 80 | # I_r = TPS(I_gt) 81 | I_r = TPS(I_gt.unsqueeze(0)).squeeze() 82 | # I_r = I_gt.clone() 83 | else: 84 | I_r = torch.zeros(I_gt.size()) 85 | 86 | return I_original, I_gt, I_r, line 87 | 88 | 89 | def __len__(self): 90 | return len(self.ids) 91 | 92 | # def TPS(x): 93 | # Pseudo original implementation... But, I do not believe this is correct. 94 | # c,h,w = x.size() 95 | # x = x.numpy() 96 | # common = np.random.rand(4, 2) 97 | # A = np.random.rand(2, 2) 98 | # B = np.random.rand(2, 2) 99 | # c_src = np.concatenate([common, A], 0) 100 | # c_dst = np.concatenate([common, B], 0) 101 | # warped = warp_image_cv(x.transpose(2, 1, 0), c_src, c_dst, dshape=(h, w)).transpose((2, 0, 1)) # HWC -> CHW 102 | # return torch.from_numpy(warped) 103 | 104 | 105 | def TPS(x): 106 | """ 107 | http://www.mech.tohoku-gakuin.ac.jp/rde/contents/course/robotics/coordtrans.html 108 | """ 109 | def affine_transform(x, theta): 110 | theta = theta.view(-1, 2, 3) 111 | # grid = F.affine_grid(theta, x.size(), align_corners=True) 112 | # x = F.grid_sample(x, grid, align_corners=True) 113 | grid = F.affine_grid(theta, x.size()) 114 | x = F.grid_sample(x, grid) 115 | return x 116 | 117 | theta1 = np.zeros(9) 118 | theta1[0:6] = np.random.randn(6) * 0.15 119 | theta1 = theta1 + np.array([1,0,0,0,1,0,0,0,1]) 120 | affine1 = np.reshape(theta1, (3,3)) 121 | affine1 = np.reshape(affine1, -1)[0:6] 122 | affine1 = torch.from_numpy(affine1).type(torch.FloatTensor) 123 | x = affine_transform(x, affine1) # source image 124 | 125 | return x 126 | 127 | 128 | def get_loader(crop_size=256, image_size=266, batch_size=16, dataset='CelebA', mode='train', num_workers=8, line_type='xdog', ROOT='./datasets'): 129 | """Build and return a data loader.""" 130 | transform_common = [] 131 | transform_a = [] 132 | transform_line = [] 133 | transform_original = [] 134 | 135 | transform_common = RandomFRC(crop_size, scale=(0.9, 1.0), ratio=(0.95, 1.05), interpolation=Image.LANCZOS) 136 | 137 | transform_a = T.Compose([ 138 | T.ColorJitter(brightness=0.05, contrast=0.2, saturation=0.4, hue=0.4), 139 | T.Resize((crop_size, crop_size), interpolation=Image.LANCZOS), 140 | T.ToTensor(), 141 | T.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)) 142 | ]) 143 | 144 | transform_line = T.Compose([ 145 | T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), 146 | T.Resize((crop_size, crop_size), interpolation=Image.LANCZOS), 147 | T.ToTensor(), 148 | # T.RandomErasing(p=0.9, value=1., scale=(0.02, 0.1)) # For experimental 149 | ]) 150 | 151 | transform_original = T.Compose([ 152 | T.Resize((crop_size, crop_size), interpolation=Image.LANCZOS), 153 | T.ToTensor(), 154 | T.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)) 155 | ]) 156 | 157 | if dataset == 'line_art': 158 | image_dir = os.path.join(ROOT, 'line_art/train/color') 159 | line_dir = os.path.join(ROOT, 'line_art/train/xdog') 160 | elif dataset == 'tag2pix': 161 | image_dir = os.path.join(ROOT, 
'tag2pix/rgb_cropped') 162 | if line_type == 'xdog': 163 | line_dir = os.path.join(ROOT, 'tag2pix/xdog_train') 164 | elif line_type == 'keras': 165 | line_dir = os.path.join(ROOT, 'tag2pix/keras_train') 166 | 167 | dataset = Dataset(image_dir, line_dir, 168 | transform_common, transform_a, transform_line, transform_original) 169 | 170 | data_loader = data.DataLoader(dataset=dataset, 171 | batch_size=batch_size, 172 | shuffle=(mode=='train'), 173 | num_workers=num_workers, 174 | pin_memory=True, 175 | drop_last=True) 176 | return data_loader 177 | 178 | if __name__ == '__main__': 179 | 180 | def denorm(x): 181 | """Convert the range from [-1, 1] to [0, 1].""" 182 | out = (x + 1) / 2 183 | return out.clamp_(0, 1) 184 | 185 | from torchvision.utils import save_image 186 | loader = get_loader( 187 | crop_size=256, 188 | image_size=256, 189 | batch_size=20, 190 | dataset='tag2pix', mode='test', num_workers=4, line_type='xdog', 191 | ROOT='./data' 192 | ) 193 | 194 | loader = iter(loader) 195 | 196 | I_ori, I_gt, I_r, I_s = next(loader) 197 | 198 | I_concat = denorm(torch.cat([I_ori, I_gt, I_r], dim=2)) 199 | I_concat = torch.cat([I_concat, I_s.repeat(1,3,1,1)], dim=2) 200 | 201 | save_image(I_concat, 'tmp.png') -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from vgg import VGG16FeatureExtractor 4 | 5 | 6 | def gram_matrix(feat): 7 | # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py 8 | (b, ch, h, w) = feat.size() 9 | feat = feat.view(b, ch, h * w) 10 | feat_t = feat.transpose(1, 2) 11 | gram = torch.bmm(feat, feat_t) / (ch * h * w) 12 | return gram 13 | 14 | class VGGLoss(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | self.l1 = nn.L1Loss() 18 | self.extractor = VGG16FeatureExtractor() 19 | 20 | def forward(self, real, fake): 21 | loss_dict = {} 22 | 23 | feat_real = self.extractor(real) 24 | feat_fake = self.extractor(fake) 25 | 26 | L_prec = 0. 27 | L_style = 0. 28 | for i in range(len(feat_real)): 29 | L_prec += self.l1(feat_real[i], feat_fake[i]) 30 | L_style += self.l1(gram_matrix(feat_real[i]), gram_matrix(feat_fake[i])) 31 | 32 | L_prec = L_prec.mean() 33 | L_style = L_style.mean() 34 | 35 | return L_prec, L_style 36 | 37 | 38 | if __name__ == '__main__': 39 | loss = VGGLoss() 40 | 41 | x = torch.randn(1,3,256,256) 42 | 43 | l1, l2 = loss(x,x) 44 | 45 | print(l1, l2) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | warnings.simplefilter('ignore') 5 | from solver import Solver 6 | from data_loader import get_loader 7 | from torch.backends import cudnn 8 | 9 | 10 | def str2bool(v): 11 | return v.lower() in ('true') 12 | 13 | def main(config): 14 | # For fast training. 15 | cudnn.benchmark = True 16 | 17 | # Create directories if not exist. 
18 | os.makedirs(config.log_dir, exist_ok=True) 19 | os.makedirs(config.model_save_dir, exist_ok=True) 20 | os.makedirs(config.sample_dir, exist_ok=True) 21 | 22 | 23 | data_loader = get_loader(config.crop_size, config.image_size, config.batch_size, 24 | config.dataset, config.mode, config.num_workers, config.line_type) 25 | 26 | solver = Solver(data_loader, config) 27 | 28 | if config.mode == 'train': 29 | solver.train() 30 | # elif config.mode == 'test': 31 | # solver.test() 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | 37 | # Model configuration. 38 | parser.add_argument('--crop_size', type=int, default=256, help='crop size for training images') 39 | parser.add_argument('--image_size', type=int, default=276, help='image resolution') 40 | parser.add_argument('--g_conv_dim', type=int, default=16, help='number of conv filters in the first layer of G') 41 | parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D') 42 | parser.add_argument('--d_channel', type=int, default=448) 43 | parser.add_argument('--channel_1x1', type=int, default=256) 44 | parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D') 45 | parser.add_argument('--lambda_rec', type=float, default=30, help='weight for reconstruction loss') 46 | parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') 47 | parser.add_argument('--lambda_perc', type=float, default=0.01) 48 | parser.add_argument('--lambda_style', type=float, default=50) 49 | parser.add_argument('--lambda_tr', type=float, default=1) 50 | 51 | # Training configuration. 52 | parser.add_argument('--dataset', type=str, default='line_art') # , choices=['line_art', 'tag2pix'] 53 | parser.add_argument('--line_type', type=str, default='xdog') # , choices=['xdog', 'keras'] 54 | parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size') 55 | parser.add_argument('--num_epoch', type=int, default=200, help='number of training epochs') 56 | parser.add_argument('--num_epoch_decay', type=int, default=100, help='epoch from which to start decaying the learning rate') 57 | parser.add_argument('--g_lr', type=float, default=0.0002, help='learning rate for G') # Note that the original paper uses 0.0001. 58 | parser.add_argument('--d_lr', type=float, default=0.0002, help='learning rate for D') 59 | parser.add_argument('--n_critic', type=int, default=1, help='number of D updates per each G update') 60 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') 61 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') 62 | 63 | # Test configuration. 64 | parser.add_argument('--test_epoch', type=int, default=200000, help='epoch of the saved model to test') 65 | 66 | # Miscellaneous. 67 | parser.add_argument('--num_workers', type=int, default=8) 68 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 69 | 70 | # Directories. 71 | parser.add_argument('--result_dir', type=str, default='results') 72 | parser.add_argument('--exp_name', type=str, default='baseline') 73 | 74 | # Step size.
75 | parser.add_argument('--log_step', type=int, default=200) 76 | parser.add_argument('--sample_epoch', type=int, default=1) 77 | parser.add_argument('--model_save_step', type=int, default=40) 78 | 79 | config = parser.parse_args() 80 | config.log_dir = os.path.join(config.result_dir, config.exp_name, 'log') 81 | config.sample_dir = os.path.join(config.result_dir, config.exp_name, config.exp_name) 82 | config.model_save_dir = os.path.join(config.result_dir, config.exp_name, 'model') 83 | print(config) 84 | main(config) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # The architectural details are described in the supplementary material below. 2 | # https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Lee_Reference-Based_Sketch_Image_CVPR_2020_supplemental.pdf 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from torch.nn.utils import spectral_norm 8 | 9 | 10 | class ResidualBlock(nn.Module): 11 | """Residual Block with instance normalization.""" 12 | def __init__(self, dim_in, dim_out): 13 | super(ResidualBlock, self).__init__() 14 | self.main = nn.Sequential( 15 | spectral_norm(nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False)), 16 | nn.InstanceNorm2d(dim_out, affine=True), 17 | nn.ReLU(inplace=True), 18 | spectral_norm(nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False)), 19 | nn.InstanceNorm2d(dim_out, affine=True)) 20 | 21 | def forward(self, x): 22 | return x + self.main(x) 23 | 24 | 25 | class Generator(nn.Module): 26 | """Generator network.""" 27 | def __init__(self, conv_dim=16, d_channel=448, channel_1x1=256): 28 | super(Generator, self).__init__() 29 | 30 | self.enc_r = Encoder(in_channel=3, conv_dim=conv_dim) 31 | self.enc_s = Encoder(in_channel=1, conv_dim=conv_dim) 32 | 33 | self.scft = SCFT(d_channel=d_channel) 34 | 35 | self.conv1x1 = spectral_norm(nn.Conv2d(d_channel, channel_1x1, kernel_size=1, stride=1, padding=0)) 36 | 37 | self.resblocks = nn.Sequential( 38 | ResidualBlock(channel_1x1, channel_1x1), 39 | ResidualBlock(channel_1x1, channel_1x1), 40 | ResidualBlock(channel_1x1, channel_1x1), 41 | ResidualBlock(channel_1x1, channel_1x1), 42 | ) 43 | 44 | 45 | self.decoder = Decoder(in_channel=channel_1x1) 46 | self.activation = nn.Tanh() 47 | 48 | def forward(self, I_r, I_s, IsGTrain=False): 49 | v_r = self.enc_r(I_r) 50 | v_s, I_s_f1_9 = self.enc_s(I_s) 51 | 52 | f_scft, L_tr = self.scft(v_r, v_s) 53 | 54 | f_encoded = self.conv1x1(f_scft) 55 | 56 | f_encoded = self.resblocks(f_encoded) + f_encoded # [1, 256, 32, 32] (channel_1x1 channels) 57 | 58 | f_out = self.decoder(f_encoded, I_s_f1_9) 59 | 60 | if IsGTrain: 61 | return self.activation(f_out), L_tr 62 | 63 | return self.activation(f_out) 64 | 65 | class Discriminator(nn.Module): 66 | """Discriminator network with PatchGAN.""" 67 | def __init__(self, image_size=128, conv_dim=64, repeat_num=6): 68 | super(Discriminator, self).__init__() 69 | layers = [] 70 | layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))) 71 | layers.append(nn.LeakyReLU(0.01)) 72 | 73 | curr_dim = conv_dim 74 | for i in range(1, repeat_num): 75 | layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))) 76 | layers.append(nn.LeakyReLU(0.01)) 77 | curr_dim = curr_dim * 2 78 | 79 | kernel_size = int(image_size / np.power(2, repeat_num)) 80 | self.main = nn.Sequential(*layers) 81 | # self.conv1 =
nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=kernel_size, bias=False)) 83 | 84 | def forward(self, x): 85 | h = self.main(x) 86 | out = self.conv1(h) 87 | return out.squeeze() 88 | 89 | 90 | class Encoder(nn.Module): 91 | """Encoder network.""" 92 | def __init__(self, in_channel=1, conv_dim=16): 93 | super(Encoder, self).__init__() 94 | 95 | self.conv1 = self.conv_block(in_channel, conv_dim, K=3, S=1, P=1) 96 | self.conv2 = self.conv_block(conv_dim, conv_dim, K=3, S=1, P=1) 97 | self.conv3 = self.conv_block(conv_dim, conv_dim*2, K=3, S=2, P=1) 98 | self.conv4 = self.conv_block(conv_dim*2, conv_dim*2, K=3, S=1, P=1) 99 | self.conv5 = self.conv_block(conv_dim*2, conv_dim*4, K=3, S=1, P=1) 100 | self.conv6 = self.conv_block(conv_dim*4, conv_dim*4, K=3, S=1, P=1) 101 | self.conv7 = self.conv_block(conv_dim*4, conv_dim*8, K=3, S=2, P=1) 102 | self.conv8 = self.conv_block(conv_dim*8, conv_dim*8, K=3, S=1, P=1) 103 | self.conv9 = self.conv_block(conv_dim*8, conv_dim*16, K=3, S=2, P=1) 104 | self.conv10 = self.conv_block(conv_dim*16, conv_dim*16, K=3, S=1, P=1) 105 | 106 | if in_channel == 3: 107 | self.mode = 'E_r' 108 | elif in_channel == 1: 109 | self.mode = 'E_s' 110 | else: 111 | raise NotImplementedError 112 | 113 | def conv_block(self, C_in, C_out, K=3, S=1, P=1): 114 | return nn.Sequential( 115 | # nn.Conv2d(C_in, C_out, kernel_size=K, stride=S, padding=P), 116 | spectral_norm(nn.Conv2d(C_in, C_out, kernel_size=K, stride=S, padding=P)), 117 | nn.InstanceNorm2d(C_out, affine=True), 118 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 119 | ) 120 | 121 | def forward(self, x): # x: [1,3,256,256] 122 | bs = x.size(0) 123 | 124 | f1 = self.conv1(x) # [1, 16, 256, 256] 125 | f2 = self.conv2(f1) # [1, 16, 256, 256] 126 | f3 = self.conv3(f2) # [1, 32, 128, 128] 127 | f4 = self.conv4(f3) # [1, 32, 128, 128] 128 | f5 = self.conv5(f4) # [1, 64, 128, 128] 129 | f6 = self.conv6(f5) # [1, 64, 128, 128] 130 | f7 = self.conv7(f6) # [1, 128, 64, 64] 131 | f8 = self.conv8(f7) # [1, 128, 64, 64] 132 | f9 = self.conv9(f8) # [1, 256, 32, 32] 133 | f10 = self.conv10(f9) # [1, 256, 32, 32] 134 | 135 | fs = (f1,f2,f3,f4,f5,f6,f7,f8,f9,f10) 136 | 137 | # f1 = F.interpolate(f1, scale_factor=f10.size(2)/f1.size(2), mode='nearest', recompute_scale_factor=True) # [1, 16, 32, 32] 138 | # f2 = F.interpolate(f2, scale_factor=f10.size(2)/f2.size(2), mode='nearest', recompute_scale_factor=True) # [1, 16, 32, 32] 139 | # f3 = F.interpolate(f3, scale_factor=f10.size(2)/f3.size(2), mode='nearest', recompute_scale_factor=True) # [1, 32, 32, 32] 140 | # f4 = F.interpolate(f4, scale_factor=f10.size(2)/f4.size(2), mode='nearest', recompute_scale_factor=True) # [1, 32, 32, 32] 141 | # f5 = F.interpolate(f5, scale_factor=f10.size(2)/f5.size(2), mode='nearest', recompute_scale_factor=True) # [1, 64, 32, 32] 142 | # f6 = F.interpolate(f6, scale_factor=f10.size(2)/f6.size(2), mode='nearest', recompute_scale_factor=True) # [1, 64, 32, 32] 143 | # f7 = F.interpolate(f7, scale_factor=f10.size(2)/f7.size(2), mode='nearest', recompute_scale_factor=True) # [1, 128, 32, 32] 144 | # f8 = F.interpolate(f8, scale_factor=f10.size(2)/f8.size(2), mode='nearest', recompute_scale_factor=True) # [1, 128, 32, 32] 145 | # f9 = F.interpolate(f9, scale_factor=f10.size(2)/f9.size(2), mode='nearest', recompute_scale_factor=True) # [1, 256, 32, 32] 146 | 147 | f1 = F.interpolate(f1, scale_factor=f10.size(2)/f1.size(2), mode='nearest') # [1, 16, 32, 32] 148 | f2 = 
F.interpolate(f2, scale_factor=f10.size(2)/f2.size(2), mode='nearest') # [1, 16, 32, 32] 149 | f3 = F.interpolate(f3, scale_factor=f10.size(2)/f3.size(2), mode='nearest') # [1, 32, 32, 32] 150 | f4 = F.interpolate(f4, scale_factor=f10.size(2)/f4.size(2), mode='nearest') # [1, 32, 32, 32] 151 | f5 = F.interpolate(f5, scale_factor=f10.size(2)/f5.size(2), mode='nearest') # [1, 64, 32, 32] 152 | f6 = F.interpolate(f6, scale_factor=f10.size(2)/f6.size(2), mode='nearest') # [1, 64, 32, 32] 153 | f7 = F.interpolate(f7, scale_factor=f10.size(2)/f7.size(2), mode='nearest') # [1, 128, 32, 32] 154 | f8 = F.interpolate(f8, scale_factor=f10.size(2)/f8.size(2), mode='nearest') # [1, 128, 32, 32] 155 | f9 = F.interpolate(f9, scale_factor=f10.size(2)/f9.size(2), mode='nearest') # [1, 256, 32, 32] 156 | 157 | V = torch.cat([f6,f8,f10], dim=1) # Eq.(1) # 64+128+256=448, [1, 448, 32, 32] 158 | V_bar = V.view(bs, V.size(1), -1) # [1, 448, 1024] 159 | 160 | 161 | if self.mode == 'E_r': 162 | return V_bar 163 | elif self.mode == 'E_s': 164 | return V_bar, fs 165 | 166 | class SCFT(nn.Module): 167 | def __init__(self, d_channel=448): 168 | super(SCFT, self).__init__() 169 | 170 | self.W_v = nn.Parameter(torch.randn(d_channel, d_channel)) # [448, 448] 171 | self.W_k = nn.Parameter(torch.randn(d_channel, d_channel)) # [448, 448] 172 | self.W_q = nn.Parameter(torch.randn(d_channel, d_channel)) # [448, 448] 173 | self.coef = d_channel ** .5 174 | 175 | self.gamma = 12. 176 | 177 | def forward(self, V_r, V_s): 178 | 179 | wq_vs = torch.matmul(self.W_q, V_s) # [1, 448, 1024] 180 | wk_vr = torch.matmul(self.W_k, V_r).permute(0, 2, 1) # [1, 448, 1024] 181 | alpha = F.softmax(torch.matmul(wq_vs, wk_vr) / self.coef, dim=-1) # Eq.(2) 182 | 183 | wv_vr = torch.matmul(self.W_v, V_r) 184 | v_asta = torch.matmul(alpha, wv_vr) # [1, 448, 1024] # Eq.(3) 185 | 186 | c_i = V_s + v_asta # [1, 448, 1024] # Eq.(4) 187 | 188 | bs,c,hw = c_i.size() 189 | spatial_c_i = torch.reshape(c_i.unsqueeze(-1), (bs,c,int(hw**0.5), int(hw**0.5))) # [1, 448, 32, 32] 190 | 191 | # Similarity-Based Triplet Loss 192 | a = wk_vr[0, :, :].detach().clone() 193 | b = wk_vr[1:, :, :].detach().clone() 194 | wk_vr_neg = torch.cat((b, a.unsqueeze(0))) 195 | alpha_negative = F.softmax(torch.matmul(wq_vs, wk_vr_neg) / self.coef, dim=-1) 196 | v_negative = torch.matmul(alpha_negative, wv_vr) 197 | 198 | L_tr = F.relu(-v_asta + v_negative + self.gamma).mean() 199 | 200 | return spatial_c_i, L_tr 201 | 202 | 203 | class Decoder(nn.Module): 204 | def __init__(self, in_channel=256, out_channel=3): 205 | super(Decoder, self).__init__() 206 | 207 | self.deconv1 = self.conv_block(in_channel, in_channel) 208 | self.deconv2 = self.conv_block(in_channel, in_channel//2, upsample=True) 209 | in_channel //= 2 210 | self.deconv3 = self.conv_block(in_channel, in_channel) 211 | self.deconv4 = self.conv_block(in_channel, in_channel//2, upsample=True) 212 | in_channel //= 2 213 | self.deconv5 = self.conv_block(in_channel, in_channel) 214 | self.deconv6 = self.conv_block(in_channel, in_channel//2) 215 | in_channel //= 2 216 | self.deconv7 = self.conv_block(in_channel, in_channel) 217 | self.deconv8 = self.conv_block(in_channel, in_channel//2, upsample=True) 218 | in_channel //= 2 219 | self.deconv9 = self.conv_block(in_channel, in_channel) 220 | self.deconv10 = self.conv_block(in_channel, in_channel) 221 | 222 | self.to_rgb = nn.Sequential( 223 | nn.InstanceNorm2d(in_channel, affine=True), 224 | nn.LeakyReLU(0.2), 225 | # nn.Conv2d(in_channel, 3, 1, 1, 0) 226 | 
spectral_norm(nn.Conv2d(in_channel, 3, 1, 1, 0)) 227 | ) 228 | 229 | 230 | def conv_block(self, C_in, C_out, K=3, S=1, P=1, upsample=False): 231 | layers = [spectral_norm(nn.Conv2d(C_in, C_out, kernel_size=K, stride=S, padding=P))] 232 | if upsample: 233 | layers += [nn.Upsample(scale_factor=2)] 234 | layers += [ 235 | nn.InstanceNorm2d(C_out, affine=True), 236 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 237 | ] 238 | return nn.Sequential(*layers) 239 | 240 | def forward(self, f_encoded, fs): 241 | f1,f2,f3,f4,f5,f6,f7,f8,f9,f10 = fs 242 | 243 | d1 = self.deconv1(f_encoded) # [1, 256, 32, 32] 244 | 245 | d1 = d1 + f9 # [1, 256, 32, 32] 246 | d2 = self.deconv2(d1) # [1, 128, 64, 64] 247 | 248 | d2 = d2 + f8 # [1, 128, 64, 64] 249 | d3 = self.deconv3(d2) # [1, 128, 64, 64] 250 | 251 | d3 = d3 + f7 # [1, 128, 64, 64] 252 | d4 = self.deconv4(d3) # [1, 64, 64, 64] 253 | 254 | d4 = d4 + f6 # [1, 64, 128, 128] 255 | d5 = self.deconv5(d4) # [1, 64, 64, 64] 256 | 257 | d5 = d5 + f5 # [1, 64, 128, 128] 258 | d6 = self.deconv6(d5) # [1, 32, 128, 128] 259 | 260 | d6 = d6 + f4 # [1, 32, 128, 128] 261 | d7 = self.deconv7(d6) # [1, 32, 128, 128] 262 | 263 | d7 = d7 + f3 # [1, 32, 128, 128] 264 | d8 = self.deconv8(d7) # [1, 16, 256, 256] 265 | 266 | d8 = d8 + f2 # [1, 16, 256, 256] 267 | d9 = self.deconv9(d8) # [1, 16, 256, 256] 268 | 269 | d9 = d9 + f1 # [1, 16, 256, 256] 270 | d10 = self.deconv10(d9) # [1, 16, 256, 256] 271 | 272 | out = self.to_rgb(d10) 273 | 274 | return out 275 | 276 | 277 | if __name__ == '__main__': 278 | I_s = torch.randn(4,1,256,256) 279 | I_r = torch.randn(4,3,256,256) 280 | 281 | G = Generator() 282 | G(I_r, I_s) -------------------------------------------------------------------------------- /preprocess/17938.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/preprocess/17938.png -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/__pycache__/xdog_blend.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/preprocess/__pycache__/xdog_blend.cpython-37.pyc -------------------------------------------------------------------------------- /preprocess/sketch_keras_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from https://github.com/lllyasviel/sketchKeras 3 | """ 4 | import cv2 5 | import numpy as np 6 | import os 7 | from keras.models import load_model 8 | from scipy import ndimage 9 | from .xdog_blend import * 10 | 11 | 12 | def get_light_map_single(img): 13 | gray = img 14 | gray = gray[None] 15 | gray = gray.transpose((1,2,0)) 16 | blur = cv2.GaussianBlur(gray, (0, 0), 3) 17 | gray = gray.reshape((gray.shape[0],gray.shape[1])) 18 | highPass = gray.astype(int) - blur.astype(int) 19 | highPass = highPass.astype(np.float) 20 | highPass = highPass / 128.0 21 | return highPass 22 | 23 | def normalize_pic(img): 24 | img = img / 
np.max(img) 25 | return img 26 | 27 | def resize_img_512_3d(img): 28 | zeros = np.zeros((1,3,512,512), dtype=np.float) 29 | zeros[0 , 0 : img.shape[0] , 0 : img.shape[1] , 0 : img.shape[2]] = img 30 | return zeros.transpose((1,2,3,0)) 31 | 32 | def to_keras_enhanced(img): 33 | mat = img.astype(np.float) 34 | mat[mat<0.1] = 0 35 | mat = - mat + 1 36 | mat = mat * 255.0 37 | mat[mat < 0] = 0 38 | mat[mat > 255] = 255 39 | mat=mat.astype(np.uint8) 40 | mat = ndimage.median_filter(mat, 1) 41 | 42 | return mat 43 | 44 | def get_light_map(img): 45 | from_mat = img 46 | width = float(from_mat.shape[1]) 47 | height = float(from_mat.shape[0]) 48 | new_width = 0 49 | new_height = 0 50 | 51 | if (width > height): 52 | if width != 512: 53 | from_mat = cv2.resize(from_mat, (512, int(512 / width * height)), interpolation=cv2.INTER_AREA) 54 | new_width = 512 55 | new_height = int(512 / width * height) 56 | else: 57 | if height != 512: 58 | from_mat = cv2.resize(from_mat, (int(512 / height * width), 512), interpolation=cv2.INTER_AREA) 59 | new_width = int(512 / height * width) 60 | new_height = 512 61 | 62 | from_mat = from_mat.transpose((2, 0, 1)) 63 | light_map = np.zeros(from_mat.shape, dtype=np.float) 64 | for channel in range(3): 65 | light_map[channel] = get_light_map_single(from_mat[channel]) 66 | 67 | light_map = normalize_pic(light_map) 68 | light_map = resize_img_512_3d(light_map) 69 | 70 | return new_height, new_width, light_map 71 | 72 | mod = None 73 | def get_mod(): 74 | global mod 75 | if mod is None: 76 | mod = load_model(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sketchKeras.h5')) 77 | return mod 78 | 79 | def get_keras_enhanced(img): 80 | mod = get_mod() 81 | new_height, new_width, light_map = get_light_map(img) 82 | 83 | # TODO: batch this! 
84 | line_mat = mod.predict(light_map, batch_size=1) 85 | 86 | line_mat = line_mat.transpose((3, 1, 2, 0))[0] 87 | line_mat = line_mat[0:int(new_height), 0:int(new_width), :] 88 | line_mat = np.amax(line_mat, 2) 89 | 90 | keras_enhanced = to_keras_enhanced(line_mat) 91 | return keras_enhanced 92 | 93 | def get_keras_high_intensity(img, intensity=1.7): 94 | keras_img = get_keras_enhanced(img) 95 | return add_intensity(keras_img, intensity) 96 | 97 | def batch_keras_enhanced(img_list): 98 | mod = get_mod() 99 | light_maps = list(map(get_light_map, img_list)) 100 | 101 | hw_list = [(h, w) for h, w, _ in light_maps] 102 | light_map = [l for _, _, l in light_maps] 103 | 104 | batch_light_map = np.concatenate(light_map, axis=0) 105 | batch_line_mat = mod.predict(batch_light_map, batch_size=len(img_list)) 106 | 107 | mat_list = np.array_split(batch_line_mat, len(img_list)) 108 | batch_line_mat = map(lambda map: map.transpose((3, 1, 2, 0))[0], mat_list) 109 | 110 | result_list = [] 111 | for line_mat, (new_height, new_width) in zip(batch_line_mat, hw_list): 112 | mat = line_mat[0:int(new_height), 0:int(new_width), :] 113 | mat = np.amax(mat, 2) 114 | keras_enhanced = to_keras_enhanced(mat) 115 | result_list.append(keras_enhanced) 116 | 117 | return result_list -------------------------------------------------------------------------------- /preprocess/xdog.py: -------------------------------------------------------------------------------- 1 | # This code is https://github.com/blandocs/Tag2Pix/blob/master/preprocessor/sketch_extractor.py 2 | import os, argparse 3 | import urllib3 4 | import shutil 5 | import random 6 | from glob import glob 7 | from multiprocessing import Pool 8 | from pathlib import Path 9 | from itertools import cycle 10 | from sys import exit 11 | import cv2 12 | from tqdm import tqdm 13 | try: 14 | from preprocess.xdog_blend import get_xdog_image, add_intensity 15 | except: 16 | from xdog_blend import get_xdog_image, add_intensity 17 | 18 | 19 | SKETCHKERAS_URL = 'http://github.com/lllyasviel/sketchKeras/releases/download/0.1/mod.h5' 20 | 21 | def make_xdog(img): 22 | s = 0.35 + 0.1 * random.random() 23 | # s = 0.7 24 | k = 2 + random.random() 25 | g = 0.95 26 | return get_xdog_image(img, sigma=s, k=k, gamma=g, epsilon=-0.5, phi=10**9) 27 | 28 | def download_sketchKeras(): 29 | curr_dir = Path(os.path.dirname(os.path.abspath(__file__))) 30 | save_path = curr_dir / 'utils' / 'sketchKeras.h5' 31 | 32 | if save_path.exists(): 33 | print('found sketchKeras.h5') 34 | return 35 | 36 | print('Downloading sketchKeras...') 37 | http = urllib3.PoolManager() 38 | 39 | with http.request('GET', SKETCHKERAS_URL, preload_content=False) as r, save_path.open('wb') as out_file: 40 | shutil.copyfileobj(r, out_file) 41 | 42 | print('Finished downloading sketchKeras.h5') 43 | 44 | def chunks(l, n): 45 | """Yield successive n-sized chunks from l.""" 46 | for i in range(0, len(l), n): 47 | yield l[i:i + n] 48 | 49 | def xdog_write(path_img): 50 | path, img, xdog_result_path = path_img 51 | img = cv2.imread(img, cv2.IMREAD_GRAYSCALE) 52 | xdog_img = make_xdog(img) 53 | print(str(xdog_result_path / path)) 54 | cv2.imwrite(str(xdog_result_path / path), xdog_img) 55 | 56 | if __name__=='__main__': 57 | desc = "XDoG extractor" 58 | 59 | parser = argparse.ArgumentParser(description=desc) 60 | 61 | parser.add_argument('--dataset_path', type=str, default='datasets/line_art/train/color') 62 | parser.add_argument('--xdog_result_path', type=str, default='datasets/line_art/train/xdog') 63 | 
parser.add_argument('--keras_result_path', type=str, default='datasets/line_art/train/keras') 64 | parser.add_argument('--xdog_only', action='store_true') 65 | # parser.add_argument('--keras_only', action='store_true') 66 | # parser.add_argument('--no_upscale', action='store_true', help='do not upscale keras_train') 67 | 68 | args = parser.parse_args() 69 | 70 | dataset_path = Path(args.dataset_path) 71 | xdog_result_path = Path(args.xdog_result_path) 72 | keras_result_path = Path(args.keras_result_path) 73 | 74 | # path = '17938.png' 75 | # img = '/Users/daichi/work/lab/aym/pfn/rbsic/data/xdog/17938.png' 76 | # out_dir = './' 77 | # img = cv2.imread(img, cv2.IMREAD_GRAYSCALE) 78 | # xdog_img = make_xdog(img) 79 | # cv2.imwrite(out_dir + path, xdog_img) 80 | # exit() 81 | 82 | if not keras_result_path.exists(): 83 | keras_result_path.mkdir() 84 | if not xdog_result_path.exists(): 85 | xdog_result_path.mkdir() 86 | 87 | print('reading images...') 88 | img_list = [] 89 | path_list = [] 90 | for img_f in (dataset_path).iterdir(): 91 | if not img_f.is_file(): 92 | continue 93 | if img_f.suffix.lower() != '.png': 94 | continue 95 | 96 | path_list.append(img_f.name) 97 | img_list.append(str(img_f)) 98 | print('images: ', len(path_list)) 99 | 100 | # if not args.xdog_only: 101 | # from utils.sketch_keras_util import batch_keras_enhanced 102 | 103 | # download_sketchKeras() 104 | 105 | # print('Extracting sketchKeras') 106 | # for p_list, chunk in tqdm(list(zip(chunks(path_list, 16), chunks(img_list, 16)))): 107 | # chunk = list(map(cv2.imread, chunk)) 108 | # krs = batch_keras_enhanced(chunk) 109 | 110 | # for name, sketch in zip(p_list, krs): 111 | # sketch = add_intensity(sketch, 1.4) 112 | # cv2.imwrite(str(keras_result_path / name), sketch) 113 | 114 | # if not args.no_upscale: 115 | # from crop_and_upscale import upscale_all 116 | # print('upscaling keras_train images...') 117 | # moved_temp_keras = dataset_path / 'temp_keras' 118 | # shutil.move(str(keras_result_path), str(moved_temp_keras)) 119 | # upscale_all(dataset_path, image_base=moved_temp_keras, save_path=keras_result_path) 120 | # shutil.rmtree(str(moved_temp_keras)) 121 | 122 | # print('extracting sketches from benchmark images...') 123 | # keras_test_dir = dataset_path / 'keras_test' 124 | # if not keras_test_dir.exists(): 125 | # keras_test_dir.mkdir() 126 | 127 | # benchmark_dir = dataset_path / 'benchmark' 128 | # bench_imgs = list(benchmark_dir.iterdir()) 129 | # for img_fs in tqdm(list(chunks(bench_imgs, 16))): 130 | # chunk = list(map(lambda x: cv2.imread(str(x)), img_fs)) 131 | # krs = batch_keras_enhanced(chunk) 132 | 133 | # for img_f, sketch in zip(img_fs, krs): 134 | # sketch = add_intensity(sketch, 1.4) 135 | # cv2.imwrite(str(keras_test_dir / img_f.name), sketch) 136 | 137 | print('Extracting XDoG with 8 threads') 138 | 139 | with Pool(8) as p: 140 | p.map(xdog_write, zip(path_list, img_list, cycle([xdog_result_path]))) -------------------------------------------------------------------------------- /preprocess/xdog_blend.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from scipy import ndimage 4 | 5 | def dog(img, size=(0,0), k=1.6, sigma=0.5, gamma=1): 6 | img1 = cv2.GaussianBlur(img, size, sigma) 7 | img2 = cv2.GaussianBlur(img, size, sigma * k) 8 | return (img1 - gamma * img2) 9 | 10 | def xdog(img, sigma=0.5, k=1.6, gamma=1, epsilon=1, phi=1): 11 | aux = dog(img, sigma=sigma, k=k, gamma=gamma) / 255 12 | for i in range(0, 
aux.shape[0]): 13 | for j in range(0, aux.shape[1]): 14 | if(aux[i, j] < epsilon): 15 | aux[i, j] = 1*255 16 | else: 17 | aux[i, j] = 255*(1 + np.tanh(phi * (aux[i, j]))) 18 | return aux 19 | 20 | def get_xdog_image(img, sigma=0.4, k=2.5, gamma=0.95, epsilon=-0.5, phi=10**9): 21 | xdog_image = xdog(img, sigma=sigma, k=k, gamma=gamma, epsilon=epsilon, phi=phi).astype(np.uint8) 22 | return xdog_image 23 | 24 | def add_intensity(img, intensity): 25 | if intensity == 1: 26 | return img 27 | inten_const = 255.0 ** (1 - intensity) 28 | return (inten_const * (img ** intensity)).astype(np.uint8) 29 | 30 | def blend_xdog_and_sketch(illust, sketch, intensity=1.7, degamma=(1/1.5), blend=0, **kwargs): 31 | gray_image = cv2.cvtColor(illust, cv2.COLOR_BGR2GRAY) 32 | gamma_sketch = add_intensity(sketch, intensity) 33 | 34 | if blend > 0: 35 | xdog_image = get_xdog_image(gray_image, **kwargs) 36 | xdog_blurred = cv2.GaussianBlur(xdog_image, (5, 5), 1) 37 | xdog_residual_blur = cv2.addWeighted(xdog_blurred, 0.75, xdog_image, 0.25, 0) 38 | 39 | if gamma_sketch.shape != xdog_residual_blur.shape: 40 | gamma_sketch = cv2.resize(gamma_sketch, (xdog_residual_blur.shape[1], xdog_residual_blur.shape[0]), interpolation=cv2.INTER_AREA)  # cv2.resize takes dsize as (width, height) 41 | 42 | blended_image = cv2.addWeighted(xdog_residual_blur, blend, gamma_sketch, (1-blend), 0) 43 | else: 44 | blended_image = gamma_sketch 45 | 46 | return add_intensity(blended_image, degamma) -------------------------------------------------------------------------------- /scripts/train_tag2pix_keras.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset tag2pix --line_type keras --exp_name tag2pix_keras_$1 -------------------------------------------------------------------------------- /scripts/train_tag2pix_xdog.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset tag2pix --line_type xdog --exp_name tag2pix_xdog_$1 -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | from model import Generator, Discriminator 2 | from torch.autograd import Variable 3 | from torchvision.utils import save_image 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import os 9 | import time 10 | import datetime 11 | from sys import exit 12 | 13 | from vgg import VGG16FeatureExtractor 14 | from loss import VGGLoss 15 | 16 | def weights_init(m): 17 | classname = m.__class__.__name__ 18 | if classname.find('Conv') != -1: 19 | nn.init.normal_(m.weight.data, 0.0, 0.02) 20 | elif classname.find('BatchNorm') != -1: 21 | nn.init.normal_(m.weight.data, 1.0, 0.02) 22 | nn.init.constant_(m.bias.data, 0) 23 | elif classname.find('InstanceNorm') != -1: 24 | nn.init.normal_(m.weight.data, 1.0, 0.02) 25 | nn.init.constant_(m.bias.data, 0) 26 | 27 | def r1_reg(d_out, x_in): 28 | # zero-centered gradient penalty for real images 29 | batch_size = x_in.size(0) 30 | grad_dout = torch.autograd.grad( 31 | outputs=d_out.sum(), inputs=x_in, 32 | create_graph=True, retain_graph=True, only_inputs=True 33 | )[0] 34 | grad_dout2 = grad_dout.pow(2) 35 | assert(grad_dout2.size() == x_in.size()) 36 | reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) 37 | return reg 38 | 39 | class Solver(object): 40 | 41 | def __init__(self, data_loader, config): 42 | """Initialize configurations.""" 43 | 44 | self.data_loader = data_loader 45
| self.config = config 46 | 47 | self.build_model(config) 48 | 49 | def build_model(self, config): 50 | """Create a generator and a discriminator.""" 51 | self.G = Generator(config.g_conv_dim, config.d_channel, config.channel_1x1) # 2 for mask vector. 52 | self.D = Discriminator(config.crop_size, config.d_conv_dim, config.d_repeat_num) 53 | 54 | self.G.apply(weights_init) 55 | self.D.apply(weights_init) 56 | 57 | self.G.cuda() 58 | self.D.cuda() 59 | 60 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), config.g_lr, [config.beta1, config.beta2]) 61 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), config.d_lr, [config.beta1, config.beta2]) 62 | 63 | self.G = nn.DataParallel(self.G) 64 | self.D = nn.DataParallel(self.D) 65 | 66 | self.VGGLoss = VGGLoss().eval() 67 | self.VGGLoss.cuda() 68 | self.VGGLoss = nn.DataParallel(self.VGGLoss) 69 | 70 | def adv_loss(self, logits, target): 71 | assert target in [1, 0] 72 | targets = torch.full_like(logits, fill_value=target) 73 | loss = F.binary_cross_entropy_with_logits(logits, targets) 74 | return loss 75 | 76 | def restore_model(self, resume_iters): 77 | """Restore the trained generator and discriminator.""" 78 | print('Loading the trained models from step {}...'.format(resume_iters)) 79 | G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) 80 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) 81 | self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage)) 82 | self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage)) 83 | 84 | def reset_grad(self): 85 | """Reset the gradient buffers.""" 86 | self.g_optimizer.zero_grad() 87 | self.d_optimizer.zero_grad() 88 | 89 | def update_lr(self, g_lr, d_lr): 90 | """Decay learning rates of the generator and discriminator.""" 91 | for param_group in self.g_optimizer.param_groups: 92 | param_group['lr'] = g_lr 93 | for param_group in self.d_optimizer.param_groups: 94 | param_group['lr'] = d_lr 95 | 96 | def denorm(self, x): 97 | """Convert the range from [-1, 1] to [0, 1].""" 98 | out = (x + 1) / 2 99 | return out.clamp_(0, 1) 100 | 101 | def gradient_penalty(self, y, x): 102 | """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" 103 | weight = torch.ones(y.size()).cuda() 104 | dydx = torch.autograd.grad(outputs=y, 105 | inputs=x, 106 | grad_outputs=weight, 107 | retain_graph=True, 108 | create_graph=True, 109 | only_inputs=True)[0] 110 | 111 | dydx = dydx.view(dydx.size(0), -1) 112 | dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1)) 113 | return torch.mean((dydx_l2norm-1)**2) 114 | 115 | def train(self): 116 | data_loader = self.data_loader 117 | config = self.config 118 | 119 | # Learning rate cache for decaying. 120 | g_lr = config.g_lr 121 | d_lr = config.d_lr 122 | 123 | # Label for lsgan 124 | real_target = torch.full((config.batch_size,), 1.).cuda() 125 | fake_target = torch.full((config.batch_size,), 0.).cuda() 126 | criterion = nn.MSELoss().cuda() 127 | 128 | # Start training. 
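# Training loop outline: each iteration first updates D with the BCE-with-logits adversarial
# loss on real (I_gt) and fake (G(I_r, I_s)) images plus the R1 gradient penalty on I_gt,
# then updates G with the adversarial term, the L1 reconstruction loss against I_gt (Eq. 6),
# the g_loss_tr term returned by the generator (presumably the paper's similarity-based
# triplet loss), and the VGG perceptual/style losses, weighted by config.lambda_rec,
# lambda_tr, lambda_perc and lambda_style.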
129 | print('Start training...') 130 | start_time = time.time() 131 | iteration = 0 132 | num_iters_decay = config.num_epoch_decay * len(data_loader) 133 | for epoch in range(config.num_epoch): 134 | 135 | for i, (I_ori, I_gt, I_r, I_s) in enumerate(data_loader): 136 | iteration += i 137 | 138 | I_ori = I_ori.cuda(non_blocking=True) 139 | I_gt = I_gt.cuda(non_blocking=True) 140 | I_r = I_r.cuda(non_blocking=True) 141 | I_s = I_s.cuda(non_blocking=True) 142 | 143 | # =================================================================================== # 144 | # 2. Train the discriminator # 145 | # =================================================================================== # 146 | 147 | # Compute loss with real images. 148 | I_gt.requires_grad_(requires_grad=True) 149 | out = self.D(I_gt) 150 | # d_loss_real = criterion(out, real_target) * 0.5 151 | d_loss_real = self.adv_loss(out, 1) 152 | d_loss_reg = r1_reg(out, I_gt) 153 | 154 | 155 | # Compute loss with fake images. 156 | I_fake = self.G(I_r, I_s) 157 | out = self.D(I_fake.detach()) 158 | # d_loss_fake = criterion(out, fake_target) * 0.5 159 | d_loss_fake = self.adv_loss(out, 0) 160 | 161 | # Backward and optimize. 162 | d_loss = d_loss_real + d_loss_fake + d_loss_reg 163 | self.reset_grad() 164 | d_loss.backward() 165 | self.d_optimizer.step() 166 | 167 | # Logging. 168 | loss = {} 169 | loss['D/loss_real'] = d_loss_real.item() 170 | loss['D/loss_fake'] = d_loss_fake.item() 171 | loss['D/loss_reg'] = d_loss_reg.item() 172 | 173 | # =================================================================================== # 174 | # 3. Train the generator # 175 | # =================================================================================== # 176 | I_gt.requires_grad_(requires_grad=False) 177 | # if (i+1) % config.n_critic == 0: 178 | I_fake, g_loss_tr = self.G(I_r, I_s, IsGTrain=True) 179 | out = self.D(I_fake) 180 | # g_loss_fake = criterion(out, real_target) 181 | g_loss_fake = self.adv_loss(out, 1) 182 | 183 | g_loss_rec = torch.mean(torch.abs(I_fake - I_gt)) # Eq.(6) 184 | 185 | g_loss_prec, g_loss_style = self.VGGLoss(I_gt, I_fake) 186 | g_loss_prec *= config.lambda_perc 187 | g_loss_style *= config.lambda_style 188 | 189 | # Backward and optimize. 190 | g_loss = g_loss_fake + config.lambda_rec * g_loss_rec + config.lambda_tr * g_loss_tr + g_loss_prec + g_loss_style 191 | self.reset_grad() 192 | g_loss.backward() 193 | self.g_optimizer.step() 194 | 195 | # Logging. 196 | loss['G/loss_fake'] = g_loss_fake.item() 197 | loss['G/loss_rec'] = g_loss_rec.item() 198 | loss['G/loss_tr'] = g_loss_tr.item() 199 | loss['G/loss_prec'] = g_loss_prec.item() 200 | loss['G/loss_style'] = g_loss_style.item() 201 | 202 | # =================================================================================== # 203 | # 4. Miscellaneous # 204 | # =================================================================================== # 205 | 206 | # Print out training information. 207 | if (i+1) % config.log_step == 0: 208 | et = time.time() - start_time 209 | et = str(datetime.timedelta(seconds=et))[:-7] 210 | log = "Elapsed [{}], Epoch [{}/{}], Iteration [{}/{}], g_lr {:.5f}, d_lr {:.5f}".format( 211 | et, epoch, config.num_epoch, i+1, len(data_loader), 212 | g_lr, d_lr) 213 | for tag, value in loss.items(): 214 | log += ", {}: {:.4f}".format(tag, value) 215 | print(log) 216 | 217 | # Decay learning rates. 
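# Linear decay: once the epoch index exceeds num_epoch_decay, every iteration subtracts
# initial_lr / (num_epoch_decay * len(data_loader)) from both learning rates, so they would
# reach zero after a further num_epoch_decay epochs. Hypothetical numbers for illustration:
# g_lr=1e-4, num_epoch_decay=50, len(data_loader)=1000 gives num_iters_decay=50000 and a
# per-iteration step of 2e-9.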
218 | if (epoch+1) > config.num_epoch_decay: 219 | g_lr -= (config.g_lr / float(num_iters_decay)) 220 | d_lr -= (config.d_lr / float(num_iters_decay)) 221 | self.update_lr(g_lr, d_lr) 222 | # print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 223 | 224 | 225 | # Translate fixed images for debugging. 226 | if (epoch+1) % config.sample_epoch == 0: 227 | with torch.no_grad(): 228 | 229 | I_fake_ori = self.G(I_ori, I_s) 230 | I_fake_zero = self.G(torch.zeros(I_ori.size()), I_s) 231 | 232 | sample_path = os.path.join(config.sample_dir, '{}.jpg'.format(epoch)) 233 | I_concat = self.denorm(torch.cat([I_ori, I_gt, I_r, I_fake, I_fake_ori, I_fake_zero], dim=2)) 234 | I_concat = torch.cat([I_concat, I_s.repeat(1,3,1,1)], dim=2) 235 | save_image(I_concat.data.cpu(), sample_path) 236 | print('Saved real and fake images into {}...'.format(sample_path)) 237 | 238 | 239 | G_path = os.path.join(config.model_save_dir, '{}-G.ckpt'.format(epoch+1)) 240 | torch.save(self.G.state_dict(), G_path) 241 | print('Saved model checkpoints into {}...'.format(config.model_save_dir)) 242 | 243 | 244 | def test(self): 245 | """Translate images using StarGAN trained on a single dataset.""" 246 | # Load the trained generator. 247 | self.restore_model(self.test_iters) 248 | 249 | # Set data loader. 250 | if self.dataset == 'CelebA': 251 | data_loader = self.celeba_loader 252 | elif self.dataset == 'RaFD': 253 | data_loader = self.rafd_loader 254 | 255 | with torch.no_grad(): 256 | for i, (x_real, c_org) in enumerate(data_loader): 257 | 258 | # Prepare input images and target domain labels. 259 | x_real = x_real.to(self.device) 260 | c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) 261 | 262 | # Translate images. 263 | x_fake_list = [x_real] 264 | for c_trg in c_trg_list: 265 | x_fake_list.append(self.G(x_real, c_trg)) 266 | 267 | # Save the translated images. 268 | x_concat = torch.cat(x_fake_list, dim=3) 269 | result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) 270 | save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) 271 | print('Saved real and fake images into {}...'.format(result_path)) 272 | -------------------------------------------------------------------------------- /tmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/tmp.png -------------------------------------------------------------------------------- /tps/__pycache__/numpy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UdonDa/Reference-Based-Sketch-Image-Colorization/4b5ff190d2fe9b08f41da9abed428891c3e4da9b/tps/__pycache__/numpy.cpython-37.pyc -------------------------------------------------------------------------------- /tps/numpy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 
2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import numpy as np 7 | import cv2 8 | 9 | class TPS: 10 | @staticmethod 11 | def fit(c, lambd=0., reduced=False): 12 | n = c.shape[0] 13 | 14 | U = TPS.u(TPS.d(c, c)) 15 | K = U + np.eye(n, dtype=np.float32)*lambd 16 | 17 | P = np.ones((n, 3), dtype=np.float32) 18 | P[:, 1:] = c[:, :2] 19 | 20 | v = np.zeros(n+3, dtype=np.float32) 21 | v[:n] = c[:, -1] 22 | 23 | A = np.zeros((n+3, n+3), dtype=np.float32) 24 | A[:n, :n] = K 25 | A[:n, -3:] = P 26 | A[-3:, :n] = P.T 27 | 28 | theta = np.linalg.solve(A, v) # p has structure w,a 29 | return theta[1:] if reduced else theta 30 | 31 | @staticmethod 32 | def d(a, b): 33 | return np.sqrt(np.square(a[:, None, :2] - b[None, :, :2]).sum(-1)) 34 | 35 | @staticmethod 36 | def u(r): 37 | return r**2 * np.log(r + 1e-6) 38 | 39 | @staticmethod 40 | def z(x, c, theta): 41 | x = np.atleast_2d(x) 42 | U = TPS.u(TPS.d(x, c)) 43 | w, a = theta[:-3], theta[-3:] 44 | reduced = theta.shape[0] == c.shape[0] + 2 45 | if reduced: 46 | w = np.concatenate((-np.sum(w, keepdims=True), w)) 47 | b = np.dot(U, w) 48 | return a[0] + a[1]*x[:, 0] + a[2]*x[:, 1] + b 49 | 50 | def uniform_grid(shape): 51 | '''Uniform grid coordinates. 52 | 53 | Params 54 | ------ 55 | shape : tuple 56 | HxW defining the number of height and width dimension of the grid 57 | 58 | Returns 59 | ------- 60 | points: HxWx2 tensor 61 | Grid coordinates over [0,1] normalized image range. 62 | ''' 63 | 64 | H,W = shape[:2] 65 | c = np.empty((H, W, 2)) 66 | c[..., 0] = np.linspace(0, 1, W, dtype=np.float32) 67 | c[..., 1] = np.expand_dims(np.linspace(0, 1, H, dtype=np.float32), -1) 68 | 69 | return c 70 | 71 | def tps_theta_from_points(c_src, c_dst, reduced=False): 72 | delta = c_src - c_dst 73 | 74 | cx = np.column_stack((c_dst, delta[:, 0])) 75 | cy = np.column_stack((c_dst, delta[:, 1])) 76 | 77 | theta_dx = TPS.fit(cx, reduced=reduced) 78 | theta_dy = TPS.fit(cy, reduced=reduced) 79 | 80 | return np.stack((theta_dx, theta_dy), -1) 81 | 82 | 83 | def tps_grid(theta, c_dst, dshape): 84 | ugrid = uniform_grid(dshape) 85 | 86 | reduced = c_dst.shape[0] + 2 == theta.shape[0] 87 | 88 | dx = TPS.z(ugrid.reshape((-1, 2)), c_dst, theta[:, 0]).reshape(dshape[:2]) 89 | dy = TPS.z(ugrid.reshape((-1, 2)), c_dst, theta[:, 1]).reshape(dshape[:2]) 90 | dgrid = np.stack((dx, dy), -1) 91 | 92 | grid = dgrid + ugrid 93 | 94 | return grid # H'xW'x2 grid[i,j] in range [0..1] 95 | 96 | def tps_grid_to_remap(grid, sshape): 97 | '''Convert a dense grid to OpenCV's remap compatible maps. 98 | 99 | Params 100 | ------ 101 | grid : HxWx2 array 102 | Normalized flow field coordinates as computed by compute_densegrid. 103 | sshape : tuple 104 | Height and width of source image in pixels. 
105 | 106 | 107 | Returns 108 | ------- 109 | mapx : HxW array 110 | mapy : HxW array 111 | ''' 112 | 113 | mx = (grid[:, :, 0] * sshape[1]).astype(np.float32) 114 | my = (grid[:, :, 1] * sshape[0]).astype(np.float32) 115 | 116 | return mx, my 117 | 118 | def warp_image_cv(img, c_src, c_dst, dshape=None): 119 | dshape = dshape or img.shape 120 | theta = tps_theta_from_points(c_src, c_dst, reduced=True) 121 | grid = tps_grid(theta, c_dst, dshape) 122 | mapx, mapy = tps_grid_to_remap(grid, img.shape) 123 | # print('=========================', img.shape) 124 | return cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC) -------------------------------------------------------------------------------- /tps/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import torch 7 | 8 | def tps(theta, ctrl, grid): 9 | '''Evaluate the thin-plate-spline (TPS) surface at xy locations arranged in a grid. 10 | The TPS surface is a minimum bend interpolation surface defined by a set of control points. 11 | The function value for a x,y location is given by 12 | 13 | TPS(x,y) := theta[-3] + theta[-2]*x + theta[-1]*y + \sum_t=0,T theta[t] U(x,y,ctrl[t]) 14 | 15 | This method computes the TPS value for multiple batches over multiple grid locations for 2 16 | surfaces in one go. 17 | 18 | Params 19 | ------ 20 | theta: Nx(T+3)x2 tensor, or Nx(T+2)x2 tensor 21 | Batch size N, T+3 or T+2 (reduced form) model parameters for T control points in dx and dy. 22 | ctrl: NxTx2 tensor or Tx2 tensor 23 | T control points in normalized image coordinates [0..1] 24 | grid: NxHxWx3 tensor 25 | Grid locations to evaluate with homogeneous 1 in first coordinate. 26 | 27 | Returns 28 | ------- 29 | z: NxHxWx2 tensor 30 | Function values at each grid location in dx and dy. 31 | ''' 32 | 33 | N, H, W, _ = grid.size() 34 | 35 | if ctrl.dim() == 2: 36 | ctrl = ctrl.expand(N, *ctrl.size()) 37 | 38 | T = ctrl.shape[1] 39 | 40 | diff = grid[...,1:].unsqueeze(-2) - ctrl.unsqueeze(1).unsqueeze(1) 41 | D = torch.sqrt((diff**2).sum(-1)) 42 | U = (D**2) * torch.log(D + 1e-6) 43 | 44 | w, a = theta[:, :-3, :], theta[:, -3:, :] 45 | 46 | reduced = T + 2 == theta.shape[1] 47 | if reduced: 48 | w = torch.cat((-w.sum(dim=1, keepdim=True), w), dim=1) 49 | 50 | # U is NxHxWxT 51 | b = torch.bmm(U.view(N, -1, T), w).view(N,H,W,2) 52 | # b is NxHxWx2 53 | z = torch.bmm(grid.view(N,-1,3), a).view(N,H,W,2) + b 54 | 55 | return z 56 | 57 | def tps_grid(theta, ctrl, size): 58 | '''Compute a thin-plate-spline grid from parameters for sampling. 59 | 60 | Params 61 | ------ 62 | theta: Nx(T+3)x2 tensor 63 | Batch size N, T+3 model parameters for T control points in dx and dy. 64 | ctrl: NxTx2 tensor, or Tx2 tensor 65 | T control points in normalized image coordinates [0..1] 66 | size: tuple 67 | Output grid size as NxCxHxW. C unused. This defines the output image 68 | size when sampling. 69 | 70 | Returns 71 | ------- 72 | grid : NxHxWx2 tensor 73 | Grid suitable for sampling in pytorch containing source image 74 | locations for each output pixel. 75 | ''' 76 | N, _, H, W = size 77 | 78 | grid = theta.new(N, H, W, 3) 79 | grid[:, :, :, 0] = 1. 
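# Channel 0 holds the homogeneous 1 that tps() expects in the first coordinate; channels 1
# and 2 (filled next) hold the normalized x (width) and y (height) grid positions in [0, 1].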
80 | grid[:, :, :, 1] = torch.linspace(0, 1, W) 81 | grid[:, :, :, 2] = torch.linspace(0, 1, H).unsqueeze(-1) 82 | 83 | z = tps(theta, ctrl, grid) 84 | return (grid[...,1:] + z)*2-1 # [-1,1] range required by F.grid_sample 85 | 86 | def tps_sparse(theta, ctrl, xy): 87 | if xy.dim() == 2: 88 | xy = xy.expand(theta.shape[0], *xy.size()) 89 | 90 | N, M = xy.shape[:2] 91 | grid = xy.new(N, M, 3) 92 | grid[..., 0] = 1. 93 | grid[..., 1:] = xy 94 | 95 | z = tps(theta, ctrl, grid.view(N,M,1,3)) 96 | return xy + z.view(N, M, 2) 97 | 98 | def uniform_grid(shape): 99 | '''Uniformly places control points arranged in a grid across normalized image coordinates. 100 | 101 | Params 102 | ------ 103 | shape : tuple 104 | HxW defining the number of control points in height and width dimension 105 | 106 | Returns 107 | ------- 108 | points: HxWx2 tensor 109 | Control points over [0,1] normalized image range. 110 | ''' 111 | H,W = shape[:2] 112 | c = torch.zeros(H, W, 2) 113 | c[..., 0] = torch.linspace(0, 1, W) 114 | c[..., 1] = torch.linspace(0, 1, H).unsqueeze(-1) 115 | return c 116 | 117 | if __name__ == '__main__': 118 | c = torch.tensor([ 119 | [0., 0], 120 | [1., 0], 121 | [1., 1], 122 | [0, 1], 123 | ]).unsqueeze(0) 124 | theta = torch.zeros(1, 4+3, 2) 125 | size = (1,1,6,3) 126 | print(tps_grid(theta, c, size).shape) -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class VGG16FeatureExtractor(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | vgg16 = models.vgg16(pretrained=True) 9 | self.enc_1 = nn.Sequential(*vgg16.features[:5]) 10 | self.enc_2 = nn.Sequential(*vgg16.features[5:10]) 11 | self.enc_3 = nn.Sequential(*vgg16.features[10:17]) 12 | 13 | # fix the encoder 14 | for i in range(3): 15 | for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters(): 16 | param.requires_grad = False 17 | 18 | def forward(self, image): 19 | results = [image] 20 | for i in range(3): 21 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 22 | results.append(func(results[-1])) 23 | return results[1:] 24 | 25 | 26 | 27 | if __name__ == '__main__': 28 | import torch 29 | x = torch.randn(1,3,256,256) 30 | 31 | model = VGG16FeatureExtractor() 32 | out = model(x) 33 | 34 | 35 | print(len(out)) 36 | for o in out: 37 | print(o.size()) 38 | 39 | print(model.enc_1) --------------------------------------------------------------------------------