├── app
│   └── src
│       ├── build
│       │   └── settings
│       │       ├── mac.json
│       │       ├── linux.json
│       │       └── base.json
│       └── main
│           ├── icons
│           │   ├── Icon.ico
│           │   ├── base
│           │   │   ├── 16.png
│           │   │   ├── 24.png
│           │   │   ├── 32.png
│           │   │   ├── 48.png
│           │   │   └── 64.png
│           │   ├── mac
│           │   │   ├── 128.png
│           │   │   ├── 256.png
│           │   │   ├── 512.png
│           │   │   └── 1024.png
│           │   ├── linux
│           │   │   ├── 128.png
│           │   │   ├── 256.png
│           │   │   ├── 512.png
│           │   │   └── 1024.png
│           │   └── README.md
│           ├── resources
│           │   └── base
│           │       ├── eraser.png
│           │       └── MAINUI.ui
│           └── python
│               ├── view
│               │   ├── __init__.py
│               │   ├── err_log_dialog.py
│               │   ├── painter.py
│               │   └── winodw.py
│               ├── main.py
│               ├── transforms.py
│               └── inference.py
├── src
│   ├── 01.png
│   ├── 02.jpg
│   ├── 03.jpg
│   ├── 04.jpg
│   ├── 05.png
│   ├── 06.png
│   └── web_01.png
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── ISSUE_TEMPLATE
│       ├── BUG_ISSUE.md
│       └── FEATURE_REQUEST.md
├── data
│   ├── __init__.py
│   ├── dataloader.py
│   ├── imageProcessing.py
│   └── dataset.py
├── trainer
│   ├── __init__.py
│   ├── opt.py
│   ├── autoencoder_trainer.py
│   ├── colorization_trainer.py
│   └── draftmodel_trainer.py
├── .gitignore
├── models
│   ├── __init__.py
│   ├── discriminator.py
│   ├── autoencoder.py
│   ├── generator.py
│   └── opt.py
├── main.py
├── LICENSE
├── convert_onnx_model.py
├── hyperparameters.yml
└── README.md
/app/src/build/settings/mac.json: -------------------------------------------------------------------------------- 1 | { 2 | "mac_bundle_identifier": "" 3 | } -------------------------------------------------------------------------------- /src/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/01.png -------------------------------------------------------------------------------- /src/02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/02.jpg -------------------------------------------------------------------------------- /src/03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/03.jpg -------------------------------------------------------------------------------- /src/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/04.jpg -------------------------------------------------------------------------------- /src/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/05.png -------------------------------------------------------------------------------- /src/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/06.png -------------------------------------------------------------------------------- /src/web_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/src/web_01.png -------------------------------------------------------------------------------- /app/src/main/icons/Icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/Icon.ico -------------------------------------------------------------------------------- 
/app/src/main/icons/base/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/base/16.png -------------------------------------------------------------------------------- /app/src/main/icons/base/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/base/24.png -------------------------------------------------------------------------------- /app/src/main/icons/base/32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/base/32.png -------------------------------------------------------------------------------- /app/src/main/icons/base/48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/base/48.png -------------------------------------------------------------------------------- /app/src/main/icons/base/64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/base/64.png -------------------------------------------------------------------------------- /app/src/main/icons/mac/128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/mac/128.png -------------------------------------------------------------------------------- /app/src/main/icons/mac/256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/mac/256.png -------------------------------------------------------------------------------- /app/src/main/icons/mac/512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/mac/512.png -------------------------------------------------------------------------------- /app/src/main/icons/linux/128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/linux/128.png -------------------------------------------------------------------------------- /app/src/main/icons/linux/256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/linux/256.png -------------------------------------------------------------------------------- /app/src/main/icons/linux/512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/linux/512.png -------------------------------------------------------------------------------- /app/src/main/icons/mac/1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/mac/1024.png 
-------------------------------------------------------------------------------- /app/src/main/icons/linux/1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/icons/linux/1024.png -------------------------------------------------------------------------------- /app/src/build/settings/linux.json: -------------------------------------------------------------------------------- 1 | { 2 | "categories": "Utility;", 3 | "description": "", 4 | "author_email": "", 5 | "url": "" 6 | } -------------------------------------------------------------------------------- /app/src/main/resources/base/eraser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rapidrabbit76/SketchColorization/HEAD/app/src/main/resources/base/eraser.png -------------------------------------------------------------------------------- /app/src/main/python/view/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .painter import Painter 3 | from .err_log_dialog import ErrLogDialog 4 | from .winodw import Window 5 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | 3 | content 4 | 5 | # Work 6 | 7 | content 8 | 9 | # Related issues [optional] 10 | 11 | write related issues 12 | -------------------------------------------------------------------------------- /app/src/build/settings/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "app_name": "SketchColorization", 3 | "author": "yslee", 4 | "main_module": "src/main/python/main.py", 5 | "version": "0.0.0" 6 | } -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .imageProcessing import dilate_abs_line, xdog, DraftArgumentation, Tensor2Image, Denormalize, RandomCrop, random_flip 2 | from .dataloader import create_data_loader 3 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder_trainer import AutoEncoderTrainer 2 | from .draftmodel_trainer import DraftModelTrainer 3 | from.colorization_trainer import ColorizationModelTrainer 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .vscode/ 3 | 4 | # ETC 5 | __pycache__/ 6 | tensorboard/ 7 | 8 | # models 9 | *.zip 10 | *.pt 11 | *.pth 12 | *.onnx 13 | 14 | # TEST 15 | test/ 16 | target/ 17 | TEST.py -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/BUG_ISSUE.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report Template 3 | about: 버그 리포트 템플릿 4 | title: "" 5 | assignees: "yslee" 6 | --- 7 | 8 | # System info 9 | 10 | # Describe of bug 11 | 12 | # Code example 13 | 14 | # Log or Screenshot 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/FEATURE_REQUEST.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request Template 3 | about: 기능 및 추가 요청을 위한 템플릿 4 | title: "" 5 | assignees: "yslee" 6 | --- 7 | 8 | # Description 9 | 10 | # TODO 11 | 12 | - [ ] todo 13 | - [ ] todo 14 | 15 | # ETC 16 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .opt import UpBlock, DownBlock, kaiming_normal, xavier_normal, ResNeXtBottleneck, Flatten 2 | from .autoencoder import AutoEncoder 3 | from .generator import Generator, SketchColorizationModel 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /app/src/main/icons/README.md: -------------------------------------------------------------------------------- 1 | ![Sample app icon](linux/128.png) 2 | 3 | This directory contains the icons that are displayed for your app. Feel free to 4 | change them. 5 | 6 | The difference between the icons on Mac and the other platforms is that on Mac, 7 | they contain a ~5% transparent margin. This is because otherwise they look too 8 | big (eg. in the Dock or in the app switcher). 9 | 10 | You can create Icon.ico from the .png files with 11 | [an online tool](http://icoconvert.com/Multi_Image_to_one_icon/). -------------------------------------------------------------------------------- /app/src/main/python/view/err_log_dialog.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtWidgets, QtCore, uic 2 | 3 | 4 | class ErrLogDialog(QtWidgets.QDialog): 5 | def __init__(self): 6 | super(ErrLogDialog, self).__init__() 7 | self.__status_label = QtWidgets.QTextEdit() 8 | self.__status_label.setStyleSheet(self.styleSheet()) 9 | self.setLayout(QtWidgets.QVBoxLayout()) 10 | self.layout().addWidget(self.__status_label) 11 | self.setWindowFlag(QtCore.Qt.FramelessWindowHint) 12 | 13 | def exec_(self, title: str = "Err!", contents: str = "") -> int: 14 | self.setWindowTitle(title) 15 | self.__status_label.setText(contents) 16 | super(ErrLogDialog, self).exec_() 17 | return -1 18 | 19 | def hideEvent(self, a0: QtGui.QHideEvent) -> None: 20 | self.__status_label.clear() 21 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from trainer import AutoEncoderTrainer, DraftModelTrainer, ColorizationModelTrainer 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser(description='Select Model') 7 | parser.add_argument('--mode', '-M', type=str, 8 | help='(draft, colorization, autoencoder)') 9 | args = parser.parse_args() 10 | 11 | with open('hyperparameters.yml') as yml: 12 | hp = yaml.load(yml, Loader=yaml.FullLoader) 13 | 14 | trainer = None 15 | 16 | if args.mode == 'draft': 17 | trainer = DraftModelTrainer(hp) 18 | elif args.mode == 'colorization': 19 | trainer = ColorizationModelTrainer(hp) 20 | elif args.mode == 'autoencoder': 21 | trainer = AutoEncoderTrainer(hp) 22 | else: 23 | raise NotImplementedError('mode : %s' % args.mode) 24 | 25 | trainer.train() 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 yslee 4 | 5 | 
Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /convert_onnx_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import SketchColorizationModel 3 | 4 | 5 | def convert_onnx_model(path: str, save_dir: str): 6 | model_ts = torch.jit.load(path) 7 | model = SketchColorizationModel() 8 | model.load_state_dict(model_ts.state_dict()) 9 | 10 | sample_data_line = torch.randn(1, 1, 512, 512) 11 | sample_data_line_draft = torch.randn(1, 1, 128, 128) 12 | sample_data_hint = torch.randn(1, 4, 128, 128) 13 | 14 | model.eval() 15 | 16 | model.forward(sample_data_line, 17 | sample_data_line_draft, 18 | sample_data_hint) 19 | 20 | torch.onnx.export(model, 21 | (sample_data_line, 22 | sample_data_line_draft, 23 | sample_data_hint), 24 | save_dir, 25 | opset_version=11, 26 | export_params=True, 27 | input_names=['line', 28 | 'line_draft', 29 | 'hint'], 30 | output_names=['colored']) 31 | 32 | 33 | if __name__ == "__main__": 34 | model_path = './SketchColorizationModel.zip' 35 | save_dir = './SketchColorizationModel.onnx' 36 | convert_onnx_model(model_path, save_dir) 37 | -------------------------------------------------------------------------------- /app/src/main/python/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import multiprocessing 4 | import os 5 | import qdarkstyle 6 | from fbs_runtime.application_context.PyQt5 import ApplicationContext, cached_property 7 | 8 | from view import Window 9 | 10 | import inference 11 | 12 | 13 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 14 | 15 | multiprocessing.set_start_method('forkserver', force=True) 16 | multiprocessing.freeze_support() 17 | 18 | 19 | class AppContext(ApplicationContext): 20 | __MAIN_UI__ = 'MAINUI.ui' 21 | __SETTINGDIALOG_UI__ = 'dialog.ui' 22 | __ERASER_ICON__ = 'eraser.png' 23 | __MODEL_PATH__ = 'SketchColorizationModel.onnx' 24 | 25 | def __init__(self): 26 | super(AppContext, self).__init__() 27 | 28 | @cached_property 29 | def main_ui(self): 30 | return self.get_resource(self.__MAIN_UI__) 31 | 32 | @cached_property 33 | def eraser_icon(self): 34 | return self.get_resource(self.__ERASER_ICON__) 35 | 36 | @cached_property 37 | def model_path(self): 38 | return self.get_resource(self.__MODEL_PATH__) 39 | 40 | 41 | if __name__ == '__main__': 42 | appctxt = AppContext() 43 | window = 
Window(appctxt) 44 | window.setStyleSheet(qdarkstyle.load_stylesheet_pyqt5()) 45 | window.show() 46 | inference.__HANDLER__ = inference.InferenceHandler(appctxt) 47 | exit_code = appctxt.app.exec_() 48 | sys.exit(exit_code) 49 | -------------------------------------------------------------------------------- /hyperparameters.yml: -------------------------------------------------------------------------------- 1 | { 2 | image_path: , 3 | line_path: , 4 | pin_memory: True, 5 | device: cuda:0, 6 | seed: 2331, 7 | logdir: ./tensorboard, 8 | 9 | autoencoder: { 10 | ckpt: , 11 | 12 | gf_dim: 64, 13 | ch_dim: 3, 14 | 15 | epoch: 12, 16 | batch_size: 32, 17 | 18 | lr: 0.0001, 19 | beta1: 0.5, 20 | beta2: 0.9, 21 | 22 | lr_milestones: 8, 23 | 24 | log_interval: 10, 25 | sampling_interval: 1000, 26 | validation_interval: 1000, 27 | }, 28 | 29 | draft: { 30 | ckpt: , 31 | autoencoder_path: , 32 | 33 | in_dim: 5, 34 | gf_dim: 64, 35 | df_dim: 64, 36 | ch_dim: 3, 37 | 38 | epoch: 18, 39 | batch_size: 64, 40 | 41 | lr: 0.0001, 42 | beta1: 0.5, 43 | beta2: 0.9, 44 | 45 | lr_milestones: 10, 46 | 47 | w_gan: 0.01, 48 | w_cont: 0.1, 49 | w_recon: 1.0, 50 | w_line: 1.0, 51 | 52 | log_interval: 10, 53 | sampling_interval: 1000, 54 | validation_interval: 1000, 55 | }, 56 | 57 | colorization: { 58 | ckpt: , 59 | draft_model_path: , 60 | 61 | in_dim: 4, 62 | gf_dim: 64, 63 | ch_dim: 3, 64 | 65 | epoch: 8, 66 | batch_size: 8, 67 | 68 | lr: 0.0001, 69 | beta1: 0.5, 70 | beta2: 0.9, 71 | 72 | lr_milestones: 5, 73 | 74 | log_interval: 10, 75 | sampling_interval: 100, 76 | validation_interval: 1000, 77 | }, 78 | 79 | 80 | } -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models import ResNeXtBottleneck, DownBlock, Flatten 4 | 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self, dim: int): 8 | super(Discriminator, self).__init__() 9 | 10 | self.main = nn.Sequential(*[ 11 | DownBlock(3, dim // 2, 4), 12 | DownBlock(dim // 2, dim // 2, 3), 13 | DownBlock(dim // 2, dim * 1, 4), 14 | ResNeXtBottleneck(dim * 1, dim * 1, cardinality=4, dilate=1), 15 | DownBlock(dim * 1, dim * 1, 3), 16 | DownBlock(dim * 1, dim * 2, 4), 17 | ResNeXtBottleneck(dim * 2, dim * 2, cardinality=4, dilate=1), 18 | DownBlock(dim * 2, dim * 2, 3), 19 | DownBlock(dim * 2, dim * 4, 4), 20 | ResNeXtBottleneck(dim * 4, dim * 4, cardinality=4, dilate=1), 21 | Flatten() 22 | ]) 23 | 24 | self.last = nn.Linear(256 * 8 * 8, 1, bias=False) 25 | self.sigmoid = nn.Sigmoid() 26 | 27 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 28 | """ Feed forward method of Discriminator 29 | 30 | Args: 31 | tensor (torch.Tensor): 4D(BCHW) RGB image tensor 32 | Returns: 33 | torch.Tensor: [description] 2D(BU) sigmoid output tensor 34 | """ 35 | tensor = self.main(tensor) 36 | tensor = self.last(tensor) 37 | return self.sigmoid(tensor) 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SketchColorization ([Web](https://omnissiah.ys2lee.com/)) 2 | [![web](./src/web_01.png)](https://omnissiah.ys2lee.com/) 3 | 4 | ![01](./src/01.png) 5 | # Model Structure 6 | 7 | ![02](./src/02.jpg) 8 | # Samples 9 | 10 | ![03](./src/03.jpg) 11 | 12 | ![04](./src/04.jpg) 13 | # GUI 14 | 15 | --- 16 | 17 | ![5](./src/06.png) 18 | 19 | 20 | 21 | # 
Requirements 22 | 23 | - torch==1.7.1 24 | - torchvision==0.8.2 25 | - numpy==1.19.1 26 | - tensorboard==2.3.0 27 | - tqdm==4.28.1 28 | - opencv_python==4.4.0.46 29 | - scipy==1.5.2 30 | - Pillow==7.2.0 31 | - scikit-learn==0.23.2 32 | - fbs==0.9.0 33 | - onnx==1.7.0 34 | - onnxruntime==1.5.1 35 | - PyQt5==5.15.1 36 | - QDarkStyle==2.8.1 37 | 38 | # Dataset 39 | - We crawled over 700,000 illustrations from [shuushuu-image-board](https://e-shuushuu.net/) and used them for training. 40 | 41 | - We filtered out noisy samples such as extreme aspect ratios, black-and-white images, and low-/high-key images. 42 | 43 | 44 | # Training 45 | 46 | - The training sequence is: 1. autoencoder, 2. draft, 3. colorization. 47 | 48 | - Set the paths (image_path, line_path, logdir) in hyperparameters.yml 49 | 50 | - Adjust the hyperparameters for each training step before starting 51 | 52 | - Run `python main.py -M {autoencoder | draft | colorization}` 53 | 54 | # Run APP with source code 55 | 56 | - Download the pretrained ONNX model [SketchColorizationModel.onnx](https://github.com/rapidrabbit76/SketchColorization/releases) 57 | - Copy the model to "app/src/main/resources/base/SketchColorizationModel.onnx" 58 | - `cd app` 59 | - `fbs run` 60 | -------------------------------------------------------------------------------- /app/src/main/python/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | class Lambda: 6 | def __init__(self, lambd): 7 | self.lambd = lambd 8 | 9 | def __call__(self, img): 10 | return self.lambd(img) 11 | 12 | 13 | class Compose: 14 | def __init__(self, transforms): 15 | self.transforms = transforms 16 | 17 | def __call__(self, img): 18 | for t in self.transforms: 19 | img = t(img) 20 | return img 21 | 22 | 23 | class ToTensor: 24 | def __call__(self, image: Image.Image) -> np.ndarray: 25 | image = np.array(image, dtype=np.float32) 26 | image /= 255.0 27 | if len(image.shape) == 2: 28 | image = np.expand_dims(image, -1) 29 | image = np.transpose(image, [2, 0, 1]) 30 | return image 31 | 32 | 33 | class Resize: 34 | def __init__(self, size, interpolation=Image.BICUBIC): 35 | self.size = size 36 | self.interpolation = interpolation 37 | 38 | def __call__(self, image: Image.Image) -> Image.Image: 39 | return image.resize(self.size, self.interpolation) 40 | 41 | 42 | class Normalize: 43 | def __init__(self, mean, std): 44 | self.mean = np.array(mean, dtype=np.float32) 45 | self.std = np.array(std, dtype=np.float32) 46 | 47 | def __call__(self, tensor: np.ndarray) -> np.ndarray: 48 | mean = self.mean 49 | std = self.std 50 | 51 | if mean.ndim == 1: 52 | mean = mean[:, None, None] 53 | if std.ndim == 1: 54 | std = std[:, None, None] 55 | tensor -= mean 56 | tensor /= std 57 | return tensor 58 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import data 2 | import os 3 | from glob import glob 4 | from torch.utils.data import DataLoader 5 | from data.dataset import AutoEncoderDataset, DraftModelDataset, ColorizationModelDataset 6 | 7 | __DATASET_CANDIDATE__ = ['draft', 'colorization', 'autoencoder'] 8 | 9 | 10 | def create_data_loader(hyperparameters: dict, 11 | dataset: str) -> (DataLoader, DataLoader): 12 | """ Create Data Loader. "dataset" must be one of the candidates: 13 | 'draft', 'colorization', 'autoencoder' 14 | 15 | Args: 16 | hyperparameters (dict): 
hyperparameter dict(yml) 17 | dataset (str): one of dataset candidate 18 | 19 | Returns: 20 | [Tuple] : Dataloaders 21 | """ 22 | 23 | assert dataset in __DATASET_CANDIDATE__, \ 24 | "Dataset {} is not in {}".format( 25 | dataset, str(__DATASET_CANDIDATE__)) 26 | 27 | color_path = hyperparameters['image_path'] 28 | line_path = hyperparameters['line_path'] 29 | batch_size = hyperparameters[dataset]['batch_size'] 30 | 31 | color_paths = sorted(glob(os.path.join(color_path, '*'))) 32 | line_paths = sorted(glob(os.path.join(line_path, '*'))) 33 | 34 | pivot = int(len(color_paths) * 0.1) 35 | assert len(color_paths) != 0,\ 36 | "Image path {} is Empty".format(color_path) 37 | 38 | assert len(color_paths) == len(line_paths),\ 39 | "image, line count is not same" 40 | 41 | image_paths = [] 42 | 43 | for data in zip(color_paths, line_paths): 44 | image_paths.append(data) 45 | 46 | train_image_paths = image_paths[:-pivot] 47 | test_image_paths = image_paths[-pivot:] 48 | 49 | Dataset = None 50 | 51 | if dataset == 'draft': 52 | Dataset = DraftModelDataset 53 | elif dataset == 'colorization': 54 | Dataset = ColorizationModelDataset 55 | else: 56 | Dataset = AutoEncoderDataset 57 | 58 | # Create Dataset 59 | train_ds = Dataset(train_image_paths, training=True) 60 | test_ds = Dataset(test_image_paths, training=False) 61 | cpu_count = os.cpu_count() 62 | 63 | # Create DataLoader 64 | train_dl = DataLoader(train_ds, 65 | batch_size=batch_size, 66 | shuffle=True, 67 | num_workers=cpu_count, 68 | pin_memory=hyperparameters['pin_memory']) 69 | 70 | test_dl = DataLoader(test_ds, 71 | batch_size=8, 72 | shuffle=False, 73 | num_workers=cpu_count, 74 | pin_memory=hyperparameters['pin_memory']) 75 | 76 | return train_dl, test_dl 77 | -------------------------------------------------------------------------------- /trainer/opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from abc import ABCMeta, abstractmethod 4 | 5 | import torch 6 | from torch import nn 7 | from torch.utils import tensorboard 8 | from torchvision import models 9 | 10 | 11 | class TrainerBase(metaclass=ABCMeta): 12 | """ Abstract Class for Trainer """ 13 | 14 | def __init__(self, 15 | hp: dict, 16 | model_name: str): 17 | 18 | self.epoch = 0 19 | self.itr = 0 20 | self.hp = hp 21 | seed = hp['seed'] 22 | 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.backends.cudnn.benchmark = True 26 | 27 | self.device = torch.device( 28 | hp['device'] if torch.cuda.is_available()else 'cpu') 29 | 30 | start_time = time.strftime( 31 | '%Y-%m-%d-%H:00', 32 | time.localtime(time.time())) 33 | logdir = os.path.join( 34 | hp['logdir'], model_name, start_time) 35 | self.tb = tensorboard.SummaryWriter(logdir) 36 | os.makedirs(self.tb.log_dir, exist_ok=True) 37 | os.makedirs(os.path.join(logdir, 'image'), 38 | exist_ok=True) 39 | os.makedirs(os.path.join(logdir, 'ckpt'), 40 | exist_ok=True) 41 | os.makedirs(os.path.join(logdir, 'torch_script'), 42 | exist_ok=True) 43 | 44 | @abstractmethod 45 | def train(self): 46 | pass 47 | 48 | 49 | class GANLoss(nn.Module): 50 | def __init__(self): 51 | super(GANLoss, self).__init__() 52 | self.register_buffer('real_label', torch.tensor(1.0)) 53 | self.register_buffer('fake_label', torch.tensor(0.0)) 54 | self._loss = nn.BCELoss() 55 | 56 | def __call__(self, inputs: torch.Tensor, 57 | target_is_real: bool) -> torch.Tensor: 58 | """ 59 | Args: 60 | inputs (torch.Tensor): Tensor from Discriminator 61 | target_is_real (bool): bool 
flag 62 | 63 | Returns: 64 | torch.Tensor: BCE Loss """ 65 | 66 | target_tensor = self.real_label \ 67 | if target_is_real else self.fake_label 68 | target_tensor = target_tensor.expand_as(inputs) 69 | return self._loss(inputs, target_tensor) 70 | 71 | 72 | class ContentLoss(nn.Module): 73 | def __init__(self): 74 | super(ContentLoss, self).__init__() 75 | vgg16 = models.vgg16(pretrained=True) 76 | vgg16.features = nn.Sequential(*list(vgg16.features.children())[:9]) 77 | self.__model = vgg16.features 78 | self.__model.eval() 79 | 80 | self.register_buffer('mean', torch.FloatTensor( 81 | [0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1)) 82 | self.register_buffer('std', torch.FloatTensor( 83 | [0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 84 | 85 | for p in self.__model.parameters(): 86 | p.requires_grad = False 87 | 88 | self.__loss = torch.nn.MSELoss() 89 | 90 | def __pred(self, tensor: torch.Tensor) -> torch.Tensor: 91 | return self.__model((tensor * 0.5 - self.mean) / self.std) 92 | 93 | def forward(self, 94 | fake: torch.Tensor, 95 | target: torch.Tensor) -> torch.Tensor: 96 | """ Calculate Content loss 97 | MSE of VGG16 model's feature map between fake, target 98 | 99 | Args: 100 | fake (torch.Tensor): 4D Tensor of generated Image 101 | target (torch.Tensor): 4D Tensor of real dataset Image 102 | 103 | Returns: 104 | torch.Tensor: Content loss (MSE loss) 105 | """ 106 | with torch.no_grad(): 107 | _fake = self.__pred(fake) 108 | _target = self.__pred(target) 109 | 110 | loss = self.__loss(_fake, _target) 111 | return loss 112 | -------------------------------------------------------------------------------- /app/src/main/python/view/painter.py: -------------------------------------------------------------------------------- 1 | import io 2 | import numpy as np 3 | from PIL import Image 4 | from PyQt5 import QtGui, QtWidgets, QtCore, uic 5 | 6 | 7 | class Painter(QtWidgets.QWidget): 8 | def __init__(self, color_picker): 9 | super().__init__() 10 | self.color_picker = color_picker 11 | self.chosen_point = [] 12 | self._hint = QtGui.QPixmap(500, 500) 13 | self._hint.fill(QtCore.Qt.white) 14 | 15 | dummy = np.require( 16 | np.zeros(shape=[500, 500, 3], 17 | dtype=np.uint8), np.uint8, 'C') 18 | self.line_np = dummy 19 | dummy = QtGui.QImage(dummy, 500, 500, QtGui.QImage.Format_RGB888) 20 | self._line = QtGui.QPixmap(dummy) 21 | 22 | self.setMinimumSize(520, 520) 23 | self.last_x, self.last_y = None, None 24 | 25 | self.pen = QtGui.QPen() 26 | self.pen.setWidth(4) 27 | self.pen.setColor(QtCore.Qt.red) 28 | 29 | def QImageToCvMat(self, incomingImage): 30 | ''' Converts a QImage into an opencv MAT format ''' 31 | 32 | incomingImage = incomingImage.convertToFormat( 33 | QtGui.QImage.Format.Format_RGBA8888) 34 | 35 | width = incomingImage.width() 36 | height = incomingImage.height() 37 | 38 | ptr = incomingImage.constBits() 39 | ptr.setsize(height * width * 4) 40 | arr = np.frombuffer(ptr, np.uint8).reshape((height, width, 4)) 41 | return arr 42 | 43 | def QImageToImage(self, image: QtGui.QImage): 44 | buf = QtCore.QBuffer() 45 | image.save(buf, 'png') 46 | return Image.open(io.BytesIO(buf.data())) 47 | 48 | def get_image(self): 49 | size = self._line.size() 50 | hint_map = QtGui.QPixmap(size) 51 | hint_map.fill(QtCore.Qt.transparent) 52 | painter = QtGui.QPainter(hint_map) 53 | 54 | for pos in self.chosen_point: 55 | self.pen.setColor(pos['color']) 56 | self.pen.setWidth(pos['width']) 57 | painter.setPen(self.pen) 58 | painter.drawPoint(pos['pos']) 59 | 60 | painter.end() 61 | hint = 
self.QImageToImage(hint_map.toImage()) 62 | hint = np.array(hint) 63 | return self.line_np, hint 64 | 65 | @staticmethod 66 | def create_pixmap(image: np.ndarray): 67 | image = QtGui.QImage(image, image.shape[1], 68 | image.shape[0], image.shape[1] * 3, 69 | QtGui.QImage.Format_RGB888) 70 | return QtGui.QPixmap(image) 71 | 72 | def set_line(self, image: Image.Image, parent) -> None: 73 | image = image.convert('RGB') 74 | w, h = image.size 75 | image = np.array(image) 76 | 77 | self.line_np = image 78 | self._line = self.create_pixmap(image) 79 | self.setFixedHeight(h) 80 | self.setFixedWidth(w) 81 | parent.resize(parent.minimumSize()) 82 | 83 | def paintEvent(self, a0: QtGui.QPaintEvent) -> None: 84 | painter = QtGui.QPainter(self) 85 | painter.drawPixmap(self.rect(), self._line) 86 | color = self.pen.color() 87 | size = self.pen.width() 88 | painter.setRenderHint(QtGui.QPainter.Antialiasing, True) 89 | 90 | for pos in self.chosen_point: 91 | self.pen.setColor(pos['color']) 92 | self.pen.setWidth(pos['width']) 93 | painter.setPen(self.pen) 94 | painter.drawPoint(pos['pos']) 95 | 96 | self.pen.setColor(color) 97 | self.pen.setWidth(size) 98 | 99 | def mouseReleaseEvent(self, e: QtGui.QMouseEvent) -> None: 100 | if e.button() == QtCore.Qt.LeftButton: 101 | data = { 102 | 'pos': e.pos(), 103 | 'color': self.pen.color(), 104 | 'width': self.pen.width() 105 | } 106 | self.chosen_point.append(data) 107 | self.update() 108 | 109 | def mousePressEvent(self, e): 110 | if e.button() == QtCore.Qt.RightButton: 111 | self.remove() 112 | self.update() 113 | 114 | def remove(self): 115 | if len(self.chosen_point) > 0: 116 | self.chosen_point.pop() 117 | self.update() 118 | -------------------------------------------------------------------------------- /models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | class Block(nn.Module): 7 | """ Convolution Block Conv,Norm,Activate """ 8 | 9 | def __init__(self, 10 | in_channels: int, 11 | out_channels: int, 12 | kernel_size: int) -> None: 13 | super(Encoder.Block, self).__init__() 14 | stride = 1 if kernel_size == 3 else 2 15 | self.block = nn.Sequential( 16 | nn.Conv2d(in_channels=in_channels, 17 | out_channels=out_channels, 18 | kernel_size=kernel_size, 19 | stride=stride, 20 | padding=1, bias=False), 21 | nn.GroupNorm(4, out_channels), 22 | nn.ReLU()) 23 | 24 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 25 | return self.block(tensor) 26 | 27 | def __init__(self, dim: int): 28 | Block = self.Block 29 | super(Encoder, self).__init__() 30 | self.main = nn.Sequential( 31 | Block(3, dim // 2, 3), # 32 ,3,1 32 | Block(dim // 2, dim * 1, 4), # 64 ,4,2 33 | Block(dim * 1, dim * 1, 3), # 64 ,3,1 34 | Block(dim * 1, dim * 2, 4), # 128,4,2 35 | Block(dim * 2, dim * 2, 3), # 128,3,1 36 | Block(dim * 2, dim * 4, 4), # 256,4,2 37 | Block(dim * 4, dim * 4, 3), # 256,3,1 38 | Block(dim * 4, dim * 8, 4), # 512,4,2 39 | Block(dim * 8, dim * 8, 3), # 512,3,1 40 | ) 41 | 42 | def forward(self, tensor: torch.Tensor): 43 | return self.main(tensor) 44 | 45 | 46 | class Decoder(Encoder): 47 | class UpBlock(nn.Module): 48 | """ Convolution Block Conv,Norm,Activate """ 49 | 50 | def __init__(self, 51 | in_channels: int, 52 | out_channels: int, 53 | kernel_size: int) -> None: 54 | super(Decoder.UpBlock, self).__init__() 55 | stride = 1 if kernel_size == 3 else 2 56 | self.block = nn.Sequential( 57 | 
nn.ConvTranspose2d(in_channels=in_channels, 58 | out_channels=out_channels, 59 | kernel_size=kernel_size, 60 | stride=stride, 61 | padding=1, 62 | bias=False), 63 | nn.GroupNorm(4, out_channels), 64 | nn.ReLU()) 65 | 66 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 67 | return self.block(tensor) 68 | 69 | def __init__(self, dim: int) -> None: 70 | super(Decoder, self).__init__(dim) 71 | UpBlock = self.UpBlock 72 | Block = self.Block 73 | 74 | self.main = nn.Sequential( 75 | UpBlock(dim * 8, dim * 8, 4), 76 | Block(dim * 8, dim * 4, 3), 77 | UpBlock(dim * 4, dim * 4, 4), 78 | Block(dim * 4, dim * 2, 3), 79 | UpBlock(dim * 2, dim * 2, 4), 80 | Block(dim * 2, dim * 1, 3), 81 | UpBlock(dim * 1, dim * 1, 4), 82 | Block(dim * 1, dim // 2, 3), 83 | ) 84 | 85 | self.last = nn.Sequential( 86 | nn.Conv2d(dim // 2, 1, kernel_size=3, 87 | stride=1, padding=1, bias=False), 88 | nn.Tanh()) 89 | 90 | def forward(self, tensor: torch.Tensor): 91 | tensor = self.main(tensor) 92 | return self.last(tensor) 93 | 94 | 95 | class AutoEncoder(nn.Module): 96 | """ 97 | AutoEncoder 98 | input shape : 3x128x128 99 | output shape : 1x128x128 100 | """ 101 | 102 | def __init__(self, dim: int): 103 | super(AutoEncoder, self).__init__() 104 | self.encoder = Encoder(dim) 105 | self.decoder = Decoder(dim) 106 | 107 | def forward(self, tensor: torch.Tensor): 108 | encoder = self.encoder(tensor) 109 | return self.decoder(encoder) 110 | 111 | def weight_init(self): 112 | """ Model weight init method """ 113 | 114 | def init(m): 115 | kaiming = nn.init.kaiming_normal_ 116 | if isinstance(m, nn.Conv2d): 117 | kaiming(m.weight.data) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | nn.init.normal_(m.weight.data, 0., 0.02) 120 | elif isinstance(m, nn.Linear): 121 | kaiming(m.weight.data) 122 | m.bias.data.fill_(0.000001) 123 | self.apply(init) 124 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models import UpBlock, DownBlock, kaiming_normal, xavier_normal 4 | 5 | 6 | class Generator(nn.Module): 7 | """ Generator of draft model and colorization model 8 | draft model : in_dim: 5(line(1),hint(4)) 9 | colorization model : in_dim: 4(line(1), draft(3)) """ 10 | 11 | def __init__(self, in_dim: int, dim: int): 12 | super(Generator, self).__init__() 13 | 14 | self.e0 = DownBlock(in_dim, dim // 2, 3) # 128*128*32 15 | self.e1 = DownBlock(dim // 2, dim * 1, 4) # 64*64*64 16 | 17 | self.e2 = DownBlock(dim * 1, dim * 1, 3) # 64*64*64 18 | self.e3 = DownBlock(dim * 1, dim * 2, 4) # 32*32*128 19 | 20 | self.e4 = DownBlock(dim * 2, dim * 2, 3) # 32*32*128 21 | self.e5 = DownBlock(dim * 2, dim * 4, 4) # 16*16*256 22 | 23 | self.e6 = DownBlock(dim * 4, dim * 4, 3) # 16*16*256 24 | self.e7 = DownBlock(dim * 4, dim * 8, 4) # 8*8*512 25 | 26 | self.e8 = DownBlock(dim * 8, dim * 8, 3) # 8*8*512 27 | 28 | self.d8 = UpBlock(dim * 8 * 2, dim * 8) # 16*16*256 29 | self.d6 = UpBlock(dim * 4 * 2, dim * 4) # 32*32*128 30 | self.d4 = UpBlock(dim * 2 * 2, dim * 2) # 64*64*64 31 | self.d2 = UpBlock(dim * 1 * 2, dim * 1) # 128*128*32 32 | self.d1 = DownBlock(dim // 2, dim // 2, 3) # 128*128*32 33 | 34 | self.d0 = nn.Sequential( # 128*128*3 35 | nn.Conv2d(dim, 3, kernel_size=3, 36 | stride=1, padding=1, bias=False), 37 | nn.Tanh()) 38 | 39 | self.__relu_layers = [self.e0, self.e1, self.e2, self.e3, 40 | self.e4, self.e5, self.e6, self.e7, 41 | self.e8, self.d8, 
self.d6, self.d4, 42 | self.d2, self.d1] 43 | 44 | self.__than_layers = [self.d0] 45 | 46 | def forward(self, 47 | line: torch.Tensor, 48 | hint: torch.Tensor) -> torch.Tensor: 49 | """ Feed forward method of Generator(draft/colorizaton model) 50 | 51 | Args: 52 | line (torch.Tensor): 4D(BCHW) greyscale image tensor 53 | hint (torch.Tensor): 4D(BCHW) RGBA image tensor 54 | in image tensor RGB scale is -1 to 1 and Alpha scale is 0 to 1 55 | if colorization model hint is draft tensor. 56 | draft tensor color space is RGB and scale is -1 to 1. 57 | 58 | draft model : line, hint shape is N,1,128,128 (line) 59 | and N,4,128,128 (hint) 60 | colorization model : line, hint shape is N,1,512,512 (line) 61 | and N,3,512,512 (hint) 62 | 63 | Returns: 64 | torch.Tensor: draft or color image tensor (RGB Color space) """ 65 | 66 | inputs = torch.cat([line, hint], 1) 67 | e0 = self.e0(inputs) 68 | e2 = self.e2(self.e1(e0)) 69 | e4 = self.e4(self.e3(e2)) 70 | e6 = self.e6(self.e5(e4)) 71 | e7 = self.e7(e6) 72 | e8 = self.e8(e7) 73 | 74 | tensor = self.d8(torch.cat([e7, e8], 1)) 75 | tensor = self.d6(torch.cat([e6, tensor], 1)) 76 | tensor = self.d4(torch.cat([e4, tensor], 1)) 77 | tensor = self.d2(torch.cat([e2, tensor], 1)) 78 | tensor = self.d1(tensor) 79 | 80 | return self.d0(torch.cat([e0, tensor], 1)) 81 | 82 | def weight_init(self): 83 | for layer in self.__relu_layers: 84 | layer.apply(kaiming_normal) 85 | for layer in self.__than_layers: 86 | layer.apply(xavier_normal) 87 | 88 | 89 | class SketchColorizationModel(nn.Module): 90 | 91 | def __init__(self, dim=64): 92 | super(SketchColorizationModel, self).__init__() 93 | self.draft_model = Generator(5, 64) 94 | self.colorization_model = Generator(4, 64) 95 | self.resize = nn.Upsample(scale_factor=4, mode='nearest') 96 | 97 | def forward(self, 98 | line: torch.Tensor, 99 | line_draft: torch.Tensor, 100 | hint: torch.Tensor) -> torch.Tensor: 101 | """ Feed forward method 102 | 103 | Args: 104 | line (torch.Tensor): 4D(BCHW) greyscale image tensor 105 | line_draft (torch.Tensor): 4D(BCHW) greyscale image tensor 106 | hint (torch.Tensor): 4D(BCHW) RGBA image tensor 107 | in image tensor RGB scale is -1 to 1 and Alpha scale is 0 to 1 108 | 109 | Returns: 110 | torch.Tensor: colored image tensor (RGB Color space) """ 111 | 112 | draft = self.draft_model(line_draft, hint) 113 | draft = self.resize(draft) 114 | colored = self.colorization_model(line, draft) 115 | return colored 116 | -------------------------------------------------------------------------------- /app/src/main/python/inference.py: -------------------------------------------------------------------------------- 1 | import random 2 | from glob import glob 3 | from os import makedev, path 4 | import numpy as np 5 | from PIL import Image, ImageFilter, ImageChops, ImageOps 6 | import onnxruntime 7 | from transforms import Lambda, Compose, ToTensor, Resize, Normalize 8 | from image4layer import Image4Layer 9 | import torch 10 | from torch.nn.functional import interpolate 11 | 12 | 13 | class InferenceHandler: 14 | """ TorchServe Handler for PaintsTorch""" 15 | 16 | def __init__(self, content=None): 17 | 18 | get = content.get_resource 19 | 20 | self.__model = onnxruntime.InferenceSession( 21 | get(content.model_path)) 22 | 23 | self.line_transform = Compose([ 24 | Resize((512, 512)), 25 | ToTensor(), 26 | Normalize([0.5], [0.5]), 27 | Lambda(lambda img: np.expand_dims(img, 0)) 28 | ]) 29 | self.hint_transform = Compose([ 30 | # input must RGBA ! 
31 | Resize((128, 128), Image.NEAREST), 32 | Lambda(lambda img: img.convert(mode='RGB')), 33 | ToTensor(), 34 | Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 35 | Lambda(lambda img: np.expand_dims(img, 0)) 36 | ]) 37 | self.line_draft_transform = Compose([ 38 | Resize((128, 128)), 39 | ToTensor(), 40 | Normalize([0.5], [0.5]), 41 | Lambda(lambda img: np.expand_dims(img, 0)) 42 | ]) 43 | self.alpha_transform = Compose([ 44 | Lambda(lambda img: self.get_alpha(img)), 45 | ]) 46 | 47 | def convert_to_pil_image(self, image): 48 | image = np.transpose(image, (0, 2, 3, 1)) 49 | image = image * 0.5 + 0.5 50 | image = image * 255 51 | image = image.astype(np.uint8)[0] 52 | image = Image.fromarray(image).convert('RGB') 53 | return image 54 | 55 | def convert_to_pil_line(self, image, size=512): 56 | image = np.transpose(image, (0, 2, 3, 1)) 57 | image = image * 0.5 + 0.5 58 | image = image * 255 59 | image = image.astype(np.uint8)[0] 60 | image = np.reshape(image, (size, size)) 61 | image = Image.fromarray(image).convert('RGB') 62 | return image 63 | 64 | def get_alpha(self, hint: Image.Image): 65 | """ 66 | :param hint: 67 | :return: 68 | """ 69 | hint = hint.resize((128, 128), Image.NEAREST) 70 | hint = np.array(hint) 71 | alpha = hint[:, :, -1] 72 | alpha = np.expand_dims(alpha, 0) 73 | alpha = np.expand_dims(alpha, 0).astype(np.float32) 74 | alpha[alpha > 0] = 1.0 75 | alpha[alpha > 0] = 1.0 76 | alpha[alpha < 1.0] = 0 77 | return alpha 78 | 79 | def prepare(self, line: Image.Image, hint: Image.Image): 80 | """ 81 | :param req: 82 | :return: 83 | """ 84 | 85 | line = line.convert(mode='L') 86 | alpha = hint.convert(mode='RGBA') 87 | hint = hint.convert(mode='RGBA') 88 | 89 | w, h = line.size 90 | 91 | alpha = self.alpha_transform(alpha) 92 | line_draft = self.line_draft_transform(line) 93 | line = self.line_transform(line) 94 | hint = self.hint_transform(hint) 95 | hint = hint * alpha 96 | hint = np.concatenate([hint, alpha], 1) 97 | return line, line_draft, hint, (w, h) 98 | 99 | def inference(self, data, **kwargs): 100 | """ 101 | PaintsTorch inference 102 | colorization Line Art Image 103 | :param data: tuple (line, line_draft, hint, size) 104 | :return: tuple image, size(w,h) 105 | """ 106 | line, line_draft, hint = data 107 | 108 | inputs_tag = self.__model.get_inputs() 109 | inputs = { 110 | inputs_tag[0].name: line, 111 | inputs_tag[1].name: line_draft, 112 | inputs_tag[2].name: hint 113 | } 114 | image = self.__model.run(None, inputs)[0] 115 | return image 116 | 117 | def resize(self, image: Image.Image, size: tuple) -> Image.Image: 118 | """ 119 | Image resize to 512 120 | :param image: PIL Image data 121 | :param size: w,h tuple 122 | :return: resized Image 123 | """ 124 | (width, height) = size 125 | 126 | if width > height: 127 | rate = width / height 128 | new_height = 512 129 | new_width = int(512 * rate) 130 | else: 131 | rate = height / width 132 | new_width = 512 133 | new_height = int(rate * 512) 134 | 135 | return image.resize((new_width, new_height), Image.BICUBIC) 136 | 137 | def postprocess(self, data) -> Image.Image: 138 | """ 139 | POST processing from inference image Tensor 140 | :param data: tuple image, size(w,h) 141 | :return: processed Image json 142 | """ 143 | pred, size = data 144 | pred = self.convert_to_pil_image(pred) 145 | image = self.resize(pred, size) 146 | return image 147 | 148 | 149 | __HANDLER__: InferenceHandler = None 150 | 151 | 152 | def predict(line: Image.Image, hint: Image.Image, connect_info: tuple): 153 | line, line_draft, hint, size = 
__HANDLER__.prepare(line, hint) 154 | pred = __HANDLER__.inference((line, line_draft, hint)) 155 | image = __HANDLER__.postprocess((pred, size)) 156 | return image 157 | -------------------------------------------------------------------------------- /trainer/autoencoder_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from tqdm import tqdm 5 | import yaml 6 | 7 | import torch 8 | from torch import nn, optim 9 | from torchvision.utils import save_image, make_grid 10 | from torch.utils import tensorboard 11 | 12 | from models import AutoEncoder 13 | from trainer import opt 14 | from data import create_data_loader 15 | 16 | 17 | class AutoEncoderTrainer(opt.TrainerBase): 18 | def __init__(self, hp: dict, model_name='AutoEncoder'): 19 | super(AutoEncoderTrainer, self).__init__(hp, model_name) 20 | 21 | hyperparameters = hp['autoencoder'] 22 | self.__model = AutoEncoder( 23 | dim=hyperparameters['gf_dim']) 24 | self.__model = self.__model.to(self.device) 25 | 26 | self.__optimizer = optim.Adam( 27 | self.__model.parameters(), 28 | hyperparameters['lr'], 29 | betas=( 30 | hyperparameters['beta1'], 31 | hyperparameters['beta2'] 32 | )) 33 | 34 | self.__opti_scheduler = optim.lr_scheduler.MultiStepLR( 35 | self.__optimizer, 36 | milestones=[hyperparameters['lr_milestones']], 37 | gamma=0.1 38 | ) 39 | 40 | self.__l1_loss = nn.L1Loss() 41 | self.__l1_loss = self.__l1_loss.to(self.device) 42 | 43 | try: 44 | ckpt = torch.load(hyperparameters['ckpt']) 45 | self.epoch = ckpt['epoch'] 46 | self.itr = ckpt['itr'] 47 | self.__model.load_state_dict(ckpt['AutoEncoder']) 48 | self.__optimizer.load_state_dict(ckpt['adam']) 49 | self.__opti_scheduler.load_state_dict(ckpt['scheduler']) 50 | except Exception as e: 51 | pass 52 | finally: 53 | print("AutoEncoder Trainer Init Done") 54 | 55 | def train(self): 56 | """ train methods """ 57 | 58 | train_set, test_set = create_data_loader(self.hp, 'autoencoder') 59 | batch = next(iter(test_set)) 60 | image, line = [data.to(self.device) for data in batch] 61 | sample_batch = (image, line) 62 | hyperparametsers = self.hp['autoencoder'] 63 | 64 | while self.epoch < hyperparametsers['epoch']: 65 | p_bar = tqdm(train_set, total=len(train_set)) 66 | for batch in p_bar: 67 | loss = self._train_step(batch) 68 | 69 | if self.itr % hyperparametsers['sampling_interval'] == 0: 70 | self._test_step(sample_batch) 71 | 72 | msg = 'E:%d, Itr:%d, Loss:%0.4f' % ( 73 | self.epoch + 1, self.itr, loss) 74 | p_bar.set_description(msg) 75 | self.itr += 1 76 | 77 | self._check_point() 78 | self.__opti_scheduler.step() 79 | self.epoch += 1 80 | 81 | """ Model save as torch script """ 82 | file_name = os.path.join(self.tb.log_dir, 'torch_script') 83 | file_name = os.path.join(file_name, 'AutoEncoder_ts.zip') 84 | ts_model = torch.jit.script(self.__model.cpu(), 85 | torch.rand([1, 3, 128, 128])) 86 | ts_model.save(file_name) 87 | 88 | def _train_step(self, batch: tuple) -> float: 89 | image, line = [data.to(self.device) for data in batch] 90 | 91 | fake_line = self.__model(image) 92 | l1_loss = self.__l1_loss(fake_line, line) 93 | 94 | self.__model.zero_grad() 95 | l1_loss.backward() 96 | self.__optimizer.step() 97 | 98 | if self.itr % self.hp['autoencoder']['log_interval'] == 0: 99 | self.tb.add_scalar('TRAINING/L1_loss', l1_loss.item(), self.itr) 100 | self.tb.add_scalar('Learning Rate', 101 | self.__opti_scheduler.get_last_lr()[0], 102 | self.itr) 103 | 104 | if self.itr % 
self.hp['autoencoder']['sampling_interval'] == 0: 105 | log_image = [make_grid(image, image.size(0), 106 | 0, range=(-1, 1), normalize=True), 107 | make_grid(fake_line, image.size(0), 108 | 0, range=(-1, 1), normalize=True), 109 | make_grid(line, image.size(0), 110 | 0, range=(-1, 1), normalize=True)] 111 | 112 | self.tb.add_image('TRAINING/SampleImage', 113 | make_grid(log_image, 1, 0), 114 | self.itr) 115 | return l1_loss.item() 116 | 117 | @torch.no_grad() 118 | def _test_step(self, batch: tuple): 119 | """ Test step 120 | tensors in this section do not require gradient tracking 121 | 122 | Args: 123 | batch (tuple): batch data tuple (image, line) 124 | """ 125 | 126 | self.__model.eval() 127 | image, line = [data for data in batch] 128 | 129 | fake_line = self.__model.forward(image) 130 | l1_loss = self.__l1_loss(fake_line, line) 131 | 132 | pix_range = (-1, 1) 133 | image = make_grid(image, image.size(0), 134 | 0, True, range=pix_range) 135 | line = make_grid(line, line.size(0), 136 | 0, True, range=pix_range) 137 | fake_line = make_grid(fake_line, fake_line.size(0), 138 | 0, True, range=pix_range) 139 | 140 | sample_image = make_grid([image, fake_line, line], 1, 0, range=(0, 1)) 141 | 142 | file_name = 'sample_image_GS:%s_Loss:%0.4f.jpg' % ( 143 | self.itr, l1_loss.item()) 144 | 145 | file_name = os.path.join(self.tb.log_dir, 146 | 'image', 147 | file_name) 148 | 149 | save_image(sample_image, file_name) 150 | 151 | self.__model.train(True) 152 | 153 | def _check_point(self): 154 | """ Save checkpoint objects 155 | the checkpoint contains epoch, itr, and the model/optimizer/scheduler state dicts """ 156 | file_name = os.path.join(self.tb.log_dir, 157 | 'ckpt') 158 | file_name = os.path.join( 159 | file_name, 160 | 'AutoEncoder_E:%d_GS:%d.pth' % (self.epoch, self.itr)) 161 | 162 | # save state dicts so that __init__ can restore them via load_state_dict 163 | ckpt = {'epoch': self.epoch, 164 | 'itr': self.itr, 165 | 'AutoEncoder': self.__model.state_dict(), 166 | 'adam': self.__optimizer.state_dict(), 167 | 'scheduler': self.__opti_scheduler.state_dict()} 168 | torch.save(ckpt, file_name) 169 | -------------------------------------------------------------------------------- /models/opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | 7 | class DownBlock(nn.Module): 8 | """ Simple Convolution Block Conv->activation(leaky relu) """ 9 | 10 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int): 11 | super(DownBlock, self).__init__() 12 | 13 | stride = 1 if kernel_size == 3 else 2 14 | self.convolution = nn.Conv2d(in_channels=in_channels, 15 | out_channels=out_channels, 16 | kernel_size=kernel_size, 17 | stride=stride, 18 | padding=1, 19 | bias=False) 20 | self.act = nn.LeakyReLU(0.2, True) 21 | 22 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 23 | tensor = self.convolution(tensor) 24 | return self.act(tensor) 25 | 26 | 27 | class UpBlock(nn.Module): 28 | """ Up Convolution Block Using PixelShuffle 29 | conv->act-> resnextblocks -> conv -> pixelshuffle -> act """ 30 | 31 | def __init__(self, in_channels: int, out_channels: int): 32 | super(UpBlock, self).__init__() 33 | 34 | resnext = nn.Sequential( 35 | *[ResNeXtBottleneck(out_channels, 36 | out_channels, 37 | cardinality=32, 38 | dilate=1) for _ in range(10)]) 39 | 40 | self.block = nn.Sequential(nn.Conv2d(in_channels, 41 | out_channels, 42 | kernel_size=3, 43 | stride=1, 44 | padding=1, 45 | bias=False), 46 | nn.LeakyReLU(0.2, True), 47 | resnext, 48 | nn.Conv2d(out_channels, 49 | 
out_channels // 2 * 4, 50 | kernel_size=3, 51 | stride=1, 52 | padding=1, 53 | bias=False), 54 | nn.PixelShuffle(2), 55 | nn.LeakyReLU(0.2, True)) 56 | 57 | def forward(self, inputs: torch.Tensor): 58 | return self.block(inputs) 59 | 60 | 61 | class ResNeXtBottleneck(nn.Module): 62 | """ ResNext : 63 | (Aggregated Residual Transformations for Deep Neural Networks) """ 64 | 65 | def __init__(self, 66 | in_channels: int = 256, 67 | out_channels: int = 256, 68 | stride: int = 1, 69 | cardinality: int = 32, 70 | dilate: int = 1): 71 | super(ResNeXtBottleneck, self).__init__() 72 | D = out_channels // 2 73 | 74 | self.out_channels = out_channels 75 | self.conv_reduce = nn.Conv2d(in_channels, D, 76 | kernel_size=1, 77 | stride=1, 78 | padding=0, 79 | bias=False) 80 | 81 | self.conv_conv = nn.Conv2d(D, D, 82 | kernel_size=2 + stride, 83 | stride=stride, 84 | padding=dilate, 85 | dilation=dilate, 86 | groups=cardinality, 87 | bias=False) 88 | 89 | self.conv_expand = nn.Conv2d( 90 | D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 91 | self.shortcut = nn.Sequential() 92 | if stride != 1: 93 | self.shortcut.add_module('shortcut', 94 | nn.AvgPool2d(2, stride=2)) 95 | 96 | def forward(self, x): 97 | bottleneck = self.conv_reduce.forward(x) 98 | bottleneck = F.leaky_relu(bottleneck, 0.2, True) 99 | bottleneck = self.conv_conv.forward(bottleneck) 100 | bottleneck = F.leaky_relu(bottleneck, 0.2, True) 101 | bottleneck = self.conv_expand.forward(bottleneck) 102 | x = self.shortcut.forward(x) 103 | return x + bottleneck 104 | 105 | 106 | class Flatten(nn.Module): 107 | """ Flatten Layer """ 108 | 109 | def forward(self, tensor: torch.Tensor): 110 | """ 111 | :param tensor: 4D Tensor 112 | :return: 4D Tensor 113 | """ 114 | return tensor.view(tensor.size(0), -1) 115 | 116 | 117 | def kaiming_normal(module): 118 | r"""Fills the input `Tensor` with values according to the method 119 | described in `Delving deep into rectifiers: Surpassing human-level 120 | performance on ImageNet classification` - He, K. et al. (2015), using a 121 | normal distribution. The resulting tensor will have values sampled from 122 | :math:`\mathcal{N}(0, \text{std}^2)` where 123 | 124 | .. math:: 125 | \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} 126 | 127 | Also known as He initialization. 128 | 129 | Args: 130 | tensor: an n-dimensional `torch.Tensor` 131 | a: the negative slope of the rectifier used after this layer (only 132 | used with ``'leaky_relu'``) 133 | mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` 134 | preserves the magnitude of the variance of the weights in the 135 | forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the 136 | backwards pass. 137 | nonlinearity: the non-linear function (`nn.functional` name), 138 | recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). 139 | 140 | Examples: 141 | >>> w = torch.empty(3, 5) 142 | >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') 143 | """ 144 | 145 | if isinstance(module, nn.Conv2d): 146 | init.kaiming_normal_(module.weight.data, a=0.2) 147 | elif isinstance(module, nn.ConvTranspose2d): 148 | init.kaiming_normal_(module.weight.data, a=0.2) 149 | elif isinstance(module, nn.Linear): 150 | init.kaiming_normal_(module.weight.data, a=0.2) 151 | 152 | 153 | def xavier_normal(module): 154 | r"""Fills the input `Tensor` with values according to the method 155 | described in `Understanding the difficulty of training deep feedforward 156 | neural networks` - Glorot, X. 
& Bengio, Y. (2010), using a normal 157 | distribution. The resulting tensor will have values sampled from 158 | :math:`\mathcal{N}(0, \text{std}^2)` where 159 | 160 | .. math:: 161 | \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} 162 | 163 | Also known as Glorot initialization. 164 | 165 | Args: 166 | tensor: an n-dimensional `torch.Tensor` 167 | gain: an optional scaling factor 168 | 169 | Examples: 170 | >>> w = torch.empty(3, 5) 171 | >>> nn.init.xavier_normal_(w) 172 | """ 173 | if isinstance(module, nn.Conv2d): 174 | init.xavier_normal_(module.weight.data) 175 | elif isinstance(module, nn.ConvTranspose2d): 176 | init.xavier_normal_(module.weight.data) 177 | elif isinstance(module, nn.Linear): 178 | init.xavier_normal_(module.weight.data) 179 | -------------------------------------------------------------------------------- /app/src/main/python/view/winodw.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PyQt5 import QtGui, QtWidgets, QtCore, uic 3 | from inference import predict 4 | 5 | from . import Painter, ErrLogDialog 6 | 7 | import platform 8 | 9 | op_sys = platform.system() 10 | err_log_dialog = None 11 | 12 | 13 | def popup_err_dialog(title: str = "Err!", contents: str = ""): 14 | global err_log_dialog 15 | if err_log_dialog is None: 16 | err_log_dialog = ErrLogDialog() 17 | 18 | err_log_dialog.exec_(title, contents) 19 | 20 | 21 | class Window(QtWidgets.QMainWindow): 22 | def show(self) -> None: 23 | super(Window, self).show() 24 | 25 | def __init__(self, cfx): 26 | super(Window, self).__init__() 27 | self.ctx = cfx 28 | 29 | uic.loadUi(self.ctx.main_ui, self) 30 | self.eBtn.setIcon(QtGui.QIcon(self.ctx.eraser_icon)) 31 | 32 | self.color_picker = QtWidgets.QColorDialog() 33 | self.painter = Painter(self.color_picker) 34 | self.verticalLayout_3.addWidget(self.painter) 35 | 36 | self.resize(self.minimumSize()) 37 | screen = QtGui.QGuiApplication.screenAt(QtGui.QCursor().pos()) 38 | self.move(screen.size().width() // 3, screen.size().height() // 4) 39 | self.setWindowTitle("") 40 | self.setFixedSize(self.size()) 41 | self.__change_color(None) 42 | self.setAcceptDrops(True) 43 | 44 | self.__event_init() 45 | 46 | self.update() 47 | 48 | def __change_color(self, color): 49 | if color is None: 50 | color = self.color_picker.currentColor() 51 | 52 | rgb = list() 53 | rgba = color.getRgb() 54 | 55 | for index in range(3): 56 | rgb.append(rgba[index]) 57 | color = QtGui.QColor(rgb[0], rgb[1], rgb[2]) 58 | self.colorBtn.setStyleSheet("color:black;" 59 | "border-style: outset;" 60 | "border-width: 1px;" 61 | "border-radius: 5px;" 62 | "border-color: black;" 63 | "background-color: %s;" % color.name()) 64 | self.painter.pen.setColor(color) 65 | self.update() 66 | 67 | def __event_init(self): 68 | self.__liner_flag = False 69 | self.colorBtn.clicked.connect(self.__color_btn_clicked) 70 | self.eBtn.clicked.connect(self.painter.remove) 71 | self.runBtn.clicked.connect(self.__run_btn_clicked) 72 | self.fileOpen.triggered.connect(self.__file_open) 73 | self.fileSave.triggered.connect(self.__file_save) 74 | self.penSizeSlider.valueChanged.connect( 75 | lambda size: self.__set_pan_size(size)) 76 | 77 | self.status.setText("HI") 78 | 79 | def __color_btn_clicked(self): 80 | color = self.color_picker.getColor() 81 | self.__change_color(color) 82 | 83 | def __set_pan_size(self, size=2): 84 | self.penSizeLabel.setText(str(size)) 85 | self.painter.pen.setWidth(size) 86 | 87 | def 
__pen_btn_clicked(self): 88 | self.penSizeLabel.setText(str(2)) 89 | self.painter.setpen(pen_size=2, color=QtCore.Qt.black) 90 | 91 | def __get_pred_image(self): 92 | _, hint = self.painter.get_image() 93 | self.__status_update(1) 94 | line = self.origin_line 95 | hint = Image.fromarray(hint) 96 | img = predict(line, hint, None) 97 | 98 | if not isinstance(img, Image.Image): 99 | popup_err_dialog('Inference Err', img) 100 | 101 | return img 102 | 103 | def __run_btn_clicked(self): 104 | img = self.__get_pred_image() 105 | self.__status_update("Done") 106 | img.show() 107 | 108 | def __liner_btn_clicked(self): 109 | pass 110 | 111 | def __imread(self, path): 112 | img = Image.open(path) 113 | return img.convert('RGB') 114 | 115 | def dragEnterEvent(self, e): 116 | if e.mimeData().hasUrls: 117 | e.accept() 118 | else: 119 | e.ignore() 120 | 121 | def dragMoveEvent(self, e): 122 | if e.mimeData().hasUrls: 123 | e.accept() 124 | else: 125 | e.ignore() 126 | 127 | def wheelEvent(self, e: QtGui.QWheelEvent): 128 | modifiers = QtGui.QGuiApplication.keyboardModifiers() 129 | if modifiers == QtCore.Qt.ControlModifier: 130 | pen_size = self.penSizeSlider.value() 131 | delta = e.angleDelta().y() 132 | if delta > 0: 133 | pen_size += 1 134 | if delta < 0: 135 | pen_size -= 1 136 | self.penSizeSlider.setValue(pen_size) 137 | 138 | def dropEvent(self, e): 139 | if e.mimeData().hasUrls: 140 | e.setDropAction(QtCore.Qt.CopyAction) 141 | e.accept() 142 | for url in e.mimeData().urls(): 143 | file_name = str(url.toLocalFile()) 144 | self.__file_open(file_name) 145 | else: 146 | e.ignore() 147 | 148 | def __file_open(self, file_path=None): 149 | options = QtWidgets.QFileDialog.Options() 150 | options |= QtWidgets.QFileDialog.DontUseNativeDialog 151 | fileters = self.tr( 152 | "Image Files (*.png *.jpg *.bmp *.jpeg *.JPG *.PNG *.JPEG)") 153 | if file_path is None: 154 | file_path, _ = QtWidgets.QFileDialog.getOpenFileName( 155 | self, 'Open File', None, fileters, options=options) 156 | if file_path == "": 157 | return -1 158 | 159 | try: 160 | img = self.__imread(file_path) 161 | except Exception: 162 | self.__file_open() 163 | return -1 164 | 165 | if 'png' in file_path.lower(): 166 | img = img.convert('LA') 167 | 168 | self.origin_line = img.copy() 169 | 170 | width = float(img.size[0]) 171 | height = float(img.size[1]) 172 | 173 | if width > height: 174 | rate = width / height 175 | new_height = 512 176 | new_width = int(512 * rate) 177 | else: 178 | rate = height / width 179 | new_width = 512 180 | new_height = int(rate * 512) 181 | 182 | img = img.resize((new_width, new_height), Image.BICUBIC) 183 | self.painter.chosen_point.clear() 184 | self.painter.set_line(img, self) 185 | self.setFixedSize(new_width + 14, new_height + 170) 186 | self.update() 187 | self.__liner_flag = False 188 | 189 | def _imwrite(self, image: Image): 190 | try: 191 | fileters = self.tr( 192 | "Image Files (*.png *.jpg *.bmp *.jpeg *.JPG *.PNG *.JPEG)") 193 | options = QtWidgets.QFileDialog.Options() 194 | options |= QtWidgets.QFileDialog.DontUseNativeDialog 195 | 196 | file_path, _ = QtWidgets.QFileDialog.getSaveFileName( 197 | self, 'Save File', None, fileters, 198 | options=options) 199 | 200 | if file_path == "": 201 | return -1 202 | image.save(file_path) 203 | except Exception: 204 | self._imwrite(image) 205 | 206 | def __file_save(self): 207 | image = self.__get_pred_image() 208 | if image == -1: 209 | return -1 210 | self.__status_update(8) 211 | self.__status_update(10) 212 | self._imwrite(image) 213 | 
self.__status_update("Done") 214 | 215 | def __status_update(self, message): 216 | self.status.setText(str(message)) 217 | QtGui.QGuiApplication.processEvents() 218 | self.update() 219 | -------------------------------------------------------------------------------- /trainer/colorization_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | import torch 6 | from torch import nn, optim 7 | from torchvision.utils import save_image, make_grid 8 | 9 | from models import Generator, SketchColorizationModel 10 | from trainer import opt 11 | from data import create_data_loader, DraftArgumentation 12 | 13 | 14 | class ColorizationModelTrainer(opt.TrainerBase): 15 | def __init__(self, hp: dict, model_name='ColorizationModel'): 16 | super(ColorizationModelTrainer, self).__init__(hp, model_name) 17 | 18 | hyperparameters = hp['colorization'] 19 | 20 | self.__generator = Generator(hyperparameters['in_dim'], 21 | hyperparameters['gf_dim']) 22 | self.__generator = self.__generator.to(self.device) 23 | 24 | self.__draft_model = torch.jit.load( 25 | hyperparameters['draft_model_path']).to(self.device) 26 | self.__draft_model.eval() 27 | 28 | self.__l1_loss = nn.L1Loss().to(self.device) 29 | 30 | self.__optimizer = optim.Adam( 31 | self.__generator.parameters(), 32 | lr=hyperparameters['lr'], 33 | betas=(hyperparameters['beta1'], hyperparameters['beta2'])) 34 | 35 | self.__draft_augmentation = DraftArgumentation(self.device) 36 | 37 | self.__scheduler = optim.lr_scheduler.MultiStepLR( 38 | self.__optimizer, 39 | milestones=[hyperparameters['lr_milestones']], 40 | gamma=0.1) 41 | 42 | try: 43 | ckpt = torch.load(hyperparameters['ckpt']) 44 | self.epoch = ckpt['epoch'] 45 | self.itr = ckpt['itr'] 46 | self.__generator.load_state_dict(ckpt['Generator']) 47 | self.__optimizer.load_state_dict(ckpt['adma']) 48 | self.__scheduler.load_state_dict(ckpt['scheduler']) 49 | except Exception: 50 | pass 51 | finally: 52 | print("Colorization Trainer Init Done") 53 | 54 | def train(self): 55 | """ train methods """ 56 | 57 | train_set, test_set = create_data_loader(self.hp, 'colorization') 58 | batch = next(iter(test_set)) 59 | target, hint, line, line_draft = [ 60 | data.to(self.device) for data in batch] 61 | sample_batch = (target, hint, line, line_draft) 62 | hyperparametsers = self.hp['colorization'] 63 | 64 | while self.epoch < hyperparametsers['epoch']: 65 | p_bar = tqdm(train_set, total=len(train_set)) 66 | for batch in p_bar: 67 | loss = self._train_step(batch) 68 | 69 | if self.itr % hyperparametsers['sampling_interval'] == 0: 70 | self._test_step(sample_batch) 71 | 72 | msg = 'E:%d, Itr:%d, Loss:%0.4f' % ( 73 | self.epoch + 1, self.itr, loss) 74 | p_bar.set_description(msg) 75 | self.itr += 1 76 | 77 | self._check_point() 78 | self.__scheduler.step() 79 | self.epoch += 1 80 | 81 | """ Model save as torch script """ 82 | file_name = os.path.join(self.tb.log_dir, 'torch_script') 83 | deployment_model_file_name = os.path.join( 84 | file_name, 'SketchColorizationModel.zip') 85 | file_name = os.path.join(file_name, 'Colorization_ts.zip') 86 | ts_model = torch.jit.script(self.__generator.cpu(), 87 | torch.rand([1, 3, 128, 128])) 88 | ts_model.save(file_name) 89 | 90 | deployment_model = SketchColorizationModel( 91 | hyperparametsers['gf_dim']) 92 | deployment_model_ts = torch.jit.script( 93 | deployment_model) 94 | deployment_model_ts.draft_model.load_state_dict( 95 | self.__draft_model.cpu().state_dict()) 96 | 
deployment_model_ts.colorization_model.load_state_dict( 97 | ts_model.state_dict()) 98 | deployment_model_ts.save(deployment_model_file_name) 99 | 100 | def _train_step(self, batch: tuple) -> float: 101 | target, hint, line, line_draft = [ 102 | data.to(self.device) for data in batch] 103 | 104 | ############# 105 | # Draft # 106 | ############# 107 | with torch.no_grad(): 108 | draft = self.__draft_model.forward(line_draft, hint) 109 | draft = self.__draft_augmentation(draft) 110 | draft = nn.functional.interpolate(draft, size=512) 111 | 112 | #################### 113 | # Generator # 114 | #################### 115 | 116 | fake_image = self.__generator(line, draft) 117 | generator_loss = self.__l1_loss(fake_image, target) 118 | 119 | self.__generator.zero_grad() 120 | generator_loss.backward() 121 | self.__optimizer.step() 122 | 123 | #################### 124 | # Logging # 125 | #################### 126 | if self.itr % self.hp['colorization']['log_interval'] == 0: 127 | self.tb.add_scalar('TRAINING/Generator.loss', 128 | generator_loss.item(), self.itr) 129 | self.tb.add_scalar('Learning Rate', 130 | self.__scheduler.get_last_lr()[0], 131 | self.itr) 132 | 133 | if self.itr % self.hp['colorization']['sampling_interval'] == 0: 134 | r, g, b, a = torch.chunk(hint, 4, 1) 135 | hint = torch.cat([r, g, b], 1) 136 | batch_size = target.size(0) 137 | hint = nn.functional.interpolate(hint, size=512) 138 | 139 | log_image = [make_grid(target, batch_size, 0, 140 | range=(-1, 1), normalize=True), 141 | make_grid(fake_image, batch_size, 0, 142 | range=(-1, 1), normalize=True), 143 | make_grid(line, batch_size, 0, 144 | range=(-1, 1), normalize=True), 145 | make_grid(hint, batch_size, 0, 146 | range=(-1, 1), normalize=True)] 147 | 148 | self.tb.add_image('TRAINING/SampleImage', 149 | make_grid(log_image, 1, 0), 150 | self.itr) 151 | return generator_loss.item() 152 | 153 | @ torch.no_grad() 154 | def _test_step(self, batch: tuple): 155 | """ Test step 156 | this section's tensor not need to trace gradient 157 | 158 | Args: 159 | batch (tuple): batch data tuple (target, hint, line) """ 160 | 161 | self.__generator.eval() 162 | target, hint, line, line_draft = [ 163 | data.to(self.device) for data in batch] 164 | zero_hint = torch.zeros_like(hint) 165 | 166 | draft = self.__draft_model.forward(line_draft, hint) 167 | draft = nn.functional.interpolate(draft, size=512) 168 | fake_image = self.__generator.forward(line, draft) 169 | 170 | draft_zero = self.__draft_model.forward(line_draft, zero_hint) 171 | draft_zero = nn.functional.interpolate(draft_zero, size=512) 172 | fake_zero_hint_image = self.__generator.forward(line, draft_zero) 173 | 174 | loss = self.__l1_loss(fake_image, target) 175 | 176 | self.tb.add_scalar('TESTING/Generator.loss', 177 | loss.item(), self.itr) 178 | 179 | r, g, b, a = torch.chunk(hint, 4, 1) 180 | hint = torch.cat([r, g, b], 1) 181 | hint = nn.functional.interpolate(hint, size=512) 182 | 183 | batch_size = target.size(0) 184 | log_image = [make_grid(target, batch_size, 0, 185 | range=(-1, 1), normalize=True), 186 | make_grid(fake_image, batch_size, 0, 187 | range=(-1, 1), normalize=True), 188 | make_grid(fake_zero_hint_image, batch_size, 0, 189 | range=(-1, 1), normalize=True), 190 | make_grid(draft, batch_size, 0, 191 | range=(-1, 1), normalize=True), 192 | make_grid(line, batch_size, 0, 193 | range=(-1, 1), normalize=True), 194 | make_grid(hint, batch_size, 0, 195 | range=(-1, 1), normalize=True)] 196 | 197 | sample_image = make_grid(log_image, 198 | 1, 0, range=(0, 1)) 199 | 
200 | self.tb.add_image('TESTING/SampleImage', 201 | sample_image, 202 | self.itr) 203 | 204 | file_name = 'sample_image_GS:%d.jpg' % self.itr 205 | file_name = os.path.join(self.tb.log_dir, 'image', 206 | file_name) 207 | save_image(sample_image, file_name) 208 | self.__generator.train(True) 209 | 210 | def _check_point(self): 211 | """ Save Checkpoint objects 212 | checkpoint contains epoch, itr, and model / optimizer / scheduler state_dicts """ 213 | file_name = os.path.join(self.tb.log_dir, 214 | 'ckpt') 215 | file_name = os.path.join( 216 | file_name, 217 | 'ColorizationModel_E:%d_GS:%d.pth' % (self.epoch, self.itr)) 218 | 219 | ckpt = {'epoch': self.epoch, 220 | 'itr': self.itr, 221 | 'Generator': self.__generator.state_dict(), 222 | 'adma': self.__optimizer.state_dict(), 223 | 'scheduler': self.__scheduler.state_dict()} 224 | 225 | torch.save(ckpt, file_name) 226 | -------------------------------------------------------------------------------- /data/imageProcessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | import colorsys 4 | import math 5 | import numbers 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | import torch 10 | from torchvision.transforms import Resize, CenterCrop, ToPILImage, ToTensor, Normalize 11 | 12 | 13 | class DraftArgumentation: 14 | """ Draft Model Output Image Augmentation """ 15 | 16 | def __init__(self, device): 17 | self.denomal = Denormalize() 18 | self.to_tensor = ToTensor() 19 | self.norm = Normalize((0.5, 0.5, 0.5), 20 | (0.5, 0.5, 0.5)) 21 | self.spay = RandomSpray() 22 | self.t2i = ToPILImage() 23 | self.device = device 24 | 25 | def __call__(self, images: torch.Tensor) -> torch.Tensor: 26 | if random.random() < 0.6: 27 | images = self.denomal(images) 28 | images = images.cpu() 29 | new_images = [] 30 | for image in images: 31 | image = self.t2i(image) 32 | image = self.spay(image) 33 | image = self.to_tensor(image) 34 | image = self.norm(image) 35 | image = torch.unsqueeze(image, 0) 36 | new_images.append(image) 37 | images = torch.cat(new_images, 0) 38 | return images.to(self.device) 39 | 40 | 41 | class RandomSpray: 42 | """ Color Spray Image Augmentation """ 43 | 44 | def __call__(self, image: Image.Image) -> Image.Image: 45 | ori_img = image 46 | color = self.get_dominant_color(ori_img) 47 | # Random Color Spray 48 | img = np.array(ori_img) 49 | 50 | h = int(img.shape[0] / 30) 51 | w = int(img.shape[1] / 30) 52 | a_x = np.random.randint(0, h) 53 | a_y = np.random.randint(0, w) 54 | b_x = np.random.randint(0, h) 55 | b_y = np.random.randint(0, w) 56 | begin_point = np.array([min(a_x, b_x), a_y]) 57 | end_point = np.array([max(a_x, b_x), b_y]) 58 | tan = (begin_point[1] - end_point[1]) / \ 59 | (begin_point[0] - end_point[0] + 0.001) 60 | 61 | center_point_list = [] 62 | for i in range(begin_point[0], end_point[0] + 1): 63 | a = i 64 | b = (i - begin_point[0]) * tan + begin_point[1] 65 | center_point_list.append(np.array([int(a), int(b)])) 66 | center_point_list = np.array(center_point_list) 67 | 68 | lamda = random.uniform(0.01, 10) 69 | paper = np.zeros((h, w, 3)) 70 | mask = np.zeros((h, w)) 71 | for i in range(h): 72 | for j in range(w): 73 | dis = self.min_dis([i, j], center_point_list) 74 | paper[i, j, :] = np.array(color) / np.exp(lamda * dis) 75 | mask[i, j] = np.array([255]) / np.exp(lamda * dis) 76 | 77 | paper = (paper).astype('uint8') 78 | mask = (mask).astype('uint8') 79 | 80 | mask = cv2.resize( 81 | mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) 82 | im = 
cv2.resize( 83 | paper, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) 84 | imq = Image.fromarray(im) 85 | imp = ori_img.copy() 86 | imp.paste( 87 | imq, (0, 0, imp.size[0], imp.size[1]), mask=Image.fromarray(mask)) 88 | return imp 89 | 90 | def min_dis(self, point, point_list: list): 91 | dis = [] 92 | for p in point_list: 93 | dis.append( 94 | np.sqrt(np.sum(np.square(np.array(point) - np.array(p))))) 95 | return min(dis) 96 | 97 | def get_dominant_color(self, image: Image.Image): 98 | image = image.convert('RGBA') 99 | image.thumbnail((200, 200)) 100 | max_score = 0 101 | dominant_color = 0 102 | 103 | for count, (r, g, b, a) in image.getcolors(image.size[0] * image.size[1]): 104 | 105 | if a == 0: 106 | continue 107 | 108 | saturation = colorsys.rgb_to_hsv( 109 | r / 255.0, g / 255.0, b / 255.0)[1] 110 | y = min(abs(r * 2104 + g * 4130 + b * 111 | 802 + 4096 + 131072) >> 13, 235) 112 | y = (y - 16.0) / (235 - 16) 113 | if y > 0.9: 114 | continue 115 | 116 | if ((r > 230) & (g > 230) & (b > 230)): 117 | continue 118 | 119 | score = (saturation + 0.1) * count 120 | 121 | if score > max_score: 122 | max_score = score 123 | dominant_color = (r, g, b) 124 | 125 | return dominant_color 126 | 127 | 128 | class RandomCrop: 129 | """ Image Pairs Randomly Crop 130 | """ 131 | 132 | def __init__(self, size: int): 133 | """ 134 | Args: 135 | size (int): Crop Size 136 | """ 137 | if isinstance(size, numbers.Number): 138 | self.size = (int(size), int(size)) 139 | else: 140 | self.size = size 141 | 142 | def __call__(self, img1: Image, img2: Image): 143 | w, h = img1.size 144 | th, tw = self.size 145 | 146 | if w == tw and h == th: 147 | return img1, img2 148 | 149 | if w == tw: 150 | x1 = 0 151 | y1 = random.randint(0, h - th) 152 | return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th)) 153 | elif h == th: 154 | x1 = random.randint(0, w - tw) 155 | y1 = 0 156 | return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th)) 157 | else: 158 | x1 = random.randint(0, w - tw) 159 | y1 = random.randint(0, h - th) 160 | return img1.crop((x1, y1, x1 + tw, y1 + th)), img2.crop((x1, y1, x1 + tw, y1 + th)) 161 | 162 | 163 | class Tensor2Image(): 164 | def __init__(self): 165 | self.__t2p = ToPILImage() 166 | 167 | def __call__(self, images): 168 | new_images = [] 169 | for image in images: 170 | img = self.__t2p(image).convert("RGB") 171 | new_images.append(img) 172 | return new_images 173 | 174 | 175 | class Denormalize: 176 | 177 | def __call__(self, tensor: torch.Tensor) -> torch.Tensor: 178 | ''' 179 | :param tesnsor: tensor range -1 to 1 180 | :return: tensor range 0 to 1 181 | ''' 182 | tensor = tensor.cpu() 183 | return (tensor + 1.0) / 2.0 184 | 185 | 186 | class Kmeans: 187 | def __init__(self, k: int): 188 | self.__k = k 189 | 190 | def __call__(self, image: Image): 191 | image = np.array(image) 192 | z = image.reshape(-1, 3) 193 | z = np.float32(z) 194 | criteria = (cv2.TERM_CRITERIA_EPS + 195 | cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) 196 | ret, label, center = cv2.kmeans( 197 | z, self.__k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) 198 | 199 | # Now convert back into uint8, and make original image 200 | center = np.uint8(center) 201 | res = center[label.flatten()] 202 | image = res.reshape(image.shape) 203 | return Image.fromarray(image) 204 | 205 | 206 | class RandomSizedCrop(object): 207 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 208 | and and a random aspect ratio of 3/4 to 4/3 of 
the original aspect ratio 209 | This is popularly used to train the Inception networks 210 | size: size of the smaller edge 211 | interpolation: Default: PIL.Image.BILINEAR 212 | """ 213 | 214 | def __init__(self, size, interpolation=Image.BICUBIC): 215 | self.size = size 216 | self.interpolation = interpolation 217 | 218 | def __call__(self, img): 219 | for attempt in range(10): 220 | area = img.size[0] * img.size[1] 221 | target_area = random.uniform(0.9, 1.) * area 222 | aspect_ratio = random.uniform(7. / 8, 8. / 7) 223 | 224 | w = int(round(math.sqrt(target_area * aspect_ratio))) 225 | h = int(round(math.sqrt(target_area / aspect_ratio))) 226 | 227 | if random.random() < 0.5: 228 | w, h = h, w 229 | 230 | if w <= img.size[0] and h <= img.size[1]: 231 | x1 = random.randint(0, img.size[0] - w) 232 | y1 = random.randint(0, img.size[1] - h) 233 | img = img.crop((x1, y1, x1 + w, y1 + h)) 234 | assert (img.size == (w, h)) 235 | 236 | return img.resize((self.size, self.size), self.interpolation) 237 | 238 | # Fallback 239 | scale = Resize(self.size, interpolation=self.interpolation) 240 | crop = CenterCrop(self.size) 241 | return crop(scale(img)) 242 | 243 | 244 | def dilate_abs_line(image: Image) -> Image: 245 | image = np.asarray(image) 246 | k = 5 if random.random() >= 0.5 else 4 247 | kernel = np.ones([k, k], dtype=np.uint8) 248 | dilated = cv2.dilate(image, kernel=kernel) 249 | diff = cv2.absdiff(dilated, image) 250 | line = 255 - diff 251 | return Image.fromarray(line).convert("L") 252 | 253 | 254 | def xdog(img: Image, k_sigma: float = 4.5, 255 | p: float = 0.95, epsilon: float = -0.1, 256 | phi: float = 200) -> Image: 257 | 258 | img = np.asarray(img) 259 | sigma = random.choice([0.3, 0.4, 0.5]) 260 | 261 | def soft_threshold(si, epsilon, phi): 262 | t = np.zeros(si.shape) 263 | si_bright = si >= epsilon 264 | si_dark = si < epsilon 265 | t[si_dark] = 1.0 266 | t[si_bright] = 1.0 + np.tanh(phi * (si[si_bright])) 267 | return t 268 | 269 | def _xdog(img_1, sigma, k_sigma, p, epsilon, phi): 270 | s = dog(img_1, sigma, k_sigma, p) 271 | t = soft_threshold(s, epsilon, phi) 272 | return (t * 127.5).astype(np.uint8) 273 | 274 | def dog(img_2, sigma, k_sigma, p): 275 | sigma_large = sigma * k_sigma 276 | g_small = cv2.GaussianBlur(img_2, (0, 0), sigma) 277 | g_large = cv2.GaussianBlur(img_2, (0, 0), sigma_large) 278 | s = g_small - p * g_large 279 | return s 280 | 281 | line = _xdog(img, sigma=sigma, k_sigma=k_sigma, 282 | p=p, epsilon=epsilon, phi=phi) 283 | return Image.fromarray(line).convert("L") 284 | 285 | 286 | def random_flip(color: Image, 287 | line: Image) -> (Image, Image): 288 | if random.random() > 0.5: 289 | color = color.transpose(Image.FLIP_LEFT_RIGHT) 290 | line = line.transpose(Image.FLIP_LEFT_RIGHT) 291 | 292 | if random.random() > 0.5: 293 | color = color.transpose(Image.FLIP_TOP_BOTTOM) 294 | line = line.transpose(Image.FLIP_TOP_BOTTOM) 295 | 296 | return (color, line) 297 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image, ImageOps 7 | from torch.utils.data import Dataset 8 | from torchvision.transforms import Lambda, Normalize, Resize, ToTensor, Compose, ColorJitter 9 | from torchvision.transforms.functional import to_tensor 10 | from abc import ABCMeta 11 | from data import xdog, dilate_abs_line, RandomCrop, 
random_flip 12 | 13 | 14 | class DatasetBase: 15 | """ Image Data Loader Base Class """ 16 | 17 | @staticmethod 18 | def _jitter(line: torch.Tensor) -> torch.Tensor: 19 | """ Line Arts Image Luminance Data-Augmentation 20 | 21 | Args: 22 | line (Tensor): line-arts image Tensor 23 | 24 | Returns: 25 | Tensor: line-arts data 26 | """ 27 | ran = random.uniform(0.7, 1) 28 | line = line * ran + (1 - ran) 29 | return torch.clamp( 30 | line, min=0, max=1) 31 | 32 | @staticmethod 33 | def _create_mask(zero_p: float = 0.4) -> torch.Tensor: 34 | """ Create Random Hint Mask (0 or 1 binary 2D Mask) 35 | 36 | Args: 37 | zero_p (float): 38 | Probability flag for zero hint count Defaults to 0.4. 39 | 40 | Returns: 41 | Tensor: Binary 2D Hint Mask 42 | """ 43 | hint_count = 0 44 | 45 | if random.random() < zero_p: 46 | if random.random() < 0.4: 47 | hint_count = random.randint(1, 5) 48 | else: 49 | hint_count = random.randint( 50 | random.randint(5, 32), 51 | random.randint(42, 65)) 52 | 53 | area = 128 * 128 54 | 55 | zero = np.zeros(shape=[area - hint_count], dtype=np.uint8) 56 | one = np.ones(shape=[hint_count], dtype=np.uint8) 57 | mask = np.concatenate([zero, one], -1) 58 | np.random.shuffle(mask) 59 | mask = np.reshape(mask, newshape=[128, 128]) * 255 60 | _, mask = cv2.threshold(mask, 61 | 127, 255, 62 | cv2.THRESH_BINARY) 63 | return to_tensor(mask) 64 | 65 | @staticmethod 66 | def _create_line(image: Image) -> Image: 67 | """ Create Line-arts Image 68 | 69 | Args: 70 | image (Image): Color PIL Image (target Image) 71 | 72 | Returns: 73 | Image: Greyscale PIL Image (Line-Arts) 74 | """ 75 | if random.random() > 0.5: 76 | return xdog(image) 77 | else: 78 | return dilate_abs_line(image) 79 | 80 | 81 | class DraftModelDataset(Dataset, DatasetBase): 82 | def __init__(self, 83 | image_paths: list, 84 | training: bool, 85 | size: int = 128): 86 | 87 | self._image_paths = image_paths 88 | self._training = training 89 | self._random_crop = RandomCrop(512) 90 | self._color_compose = Compose([ 91 | Resize(size), 92 | ToTensor(), 93 | Normalize((0.5, 0.5, 0.5), 94 | (0.5, 0.5, 0.5)) 95 | ]) 96 | 97 | compose_processing = [Resize(size), ToTensor()] 98 | 99 | if training: 100 | compose_processing.append(Lambda(self._jitter)) 101 | 102 | compose_processing.append(Normalize([0.5], [0.5])) 103 | 104 | self._line_compose = Compose(compose_processing) 105 | 106 | self._color_jitter = ColorJitter(brightness=0, 107 | contrast=0.1, 108 | saturation=0.1, 109 | hue=0.03) 110 | 111 | def __len__(self): 112 | return len(self._image_paths) 113 | 114 | def __getitem__(self, item) -> (torch.Tensor, 115 | torch.Tensor, 116 | torch.Tensor): 117 | 118 | paths = self._image_paths[item] 119 | target_image = Image.open(paths[0]).convert('RGB') 120 | if random.random() > 0.0001: 121 | if random.random() > 0.5: 122 | line_image = dilate_abs_line(target_image) 123 | else: 124 | line_image = Image.open(paths[1]).convert('L') 125 | else: 126 | # Data argumentation color image to greyscale image 127 | line_image = target_image.convert('L') 128 | 129 | target_image, line_image = self._random_crop(target_image, line_image) 130 | 131 | # Data argumentation 132 | if self._training is True: 133 | target_image, line_image = self._argumentation( 134 | target_image, line_image) 135 | mask = self._create_mask() 136 | else: 137 | mask = self._create_mask(0) 138 | 139 | # Preprocessing 140 | target_image = self._color_compose(target_image) 141 | line_image = self._line_compose(line_image) 142 | 143 | # Build Hint 144 | hint = target_image.clone() 
145 | hint = hint * mask 146 | hint_image = torch.cat([hint, mask], 0) 147 | 148 | return target_image, hint_image, line_image 149 | 150 | def _argumentation(self, 151 | target: Image, 152 | line: Image) -> (Image, Image): 153 | """ Data Argumentataion """ 154 | line = ImageOps.equalize(line) if random.random() >= 0.5 else line 155 | 156 | target, line = random_flip(target, line) 157 | target = self._color_jitter(target) 158 | return target, line 159 | 160 | 161 | class ColorizationModelDataset(DraftModelDataset): 162 | def __init__(self, image_paths: list, training: bool): 163 | super(ColorizationModelDataset, self).__init__( 164 | image_paths, training) 165 | 166 | self._hint_compos = Compose([ 167 | Resize(128), 168 | ToTensor(), 169 | Normalize((0.5, 0.5, 0.5), 170 | (0.5, 0.5, 0.5)) 171 | ]) 172 | 173 | self._color_compose = Compose([ 174 | ToTensor(), 175 | Normalize((0.5, 0.5, 0.5), 176 | (0.5, 0.5, 0.5)) 177 | ]) 178 | 179 | self._line_draft_compose = Compose([ 180 | Resize(128), 181 | ToTensor(), 182 | Normalize([0.5], [0.5]) 183 | ]) 184 | 185 | self._line_compose = Compose([ 186 | ToTensor(), 187 | Normalize([0.5], [0.5]) 188 | ]) 189 | 190 | def __len__(self): 191 | return len(self._image_paths) 192 | 193 | def __getitem__(self, item) -> (torch.Tensor, torch.Tensor, 194 | torch.Tensor, torch.Tensor): 195 | paths = self._image_paths[item] 196 | target_image = Image.open(paths[0]).convert('RGB') 197 | 198 | if random.random() > 0.0001: 199 | if random.random() > 0.5: 200 | line_image = dilate_abs_line(target_image) 201 | else: 202 | line_image = Image.open(paths[1]).convert('L') 203 | else: 204 | # Data argumentation color image to greyscale image 205 | line_image = target_image.convert('L') 206 | 207 | target_image, line_image = self._random_crop(target_image, 208 | line_image) 209 | 210 | if self._training is True: 211 | target_image, line_image = self._argumentation(target_image, 212 | line_image) 213 | 214 | hint_image = self._hint_compos(target_image) 215 | target_image = self._color_compose(target_image) 216 | line_draft = self._line_draft_compose(line_image) 217 | line_image = self._line_compose(line_image) 218 | 219 | mask = self._create_mask() 220 | hint_image = torch.cat([hint_image * mask, mask], 0) 221 | 222 | return target_image, hint_image, line_image, line_draft 223 | 224 | @ staticmethod 225 | def _create_mask() -> torch.Tensor: 226 | area = 128 * 128 227 | hint_count = random.randint(28, 128) 228 | 229 | zero = np.zeros(shape=[area - hint_count], dtype=np.uint8) 230 | one = np.ones(shape=[hint_count], dtype=np.uint8) 231 | mask = np.concatenate([zero, one], -1) 232 | np.random.shuffle(mask) 233 | mask = np.reshape(mask, newshape=[128, 128]) * 255 234 | _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY) 235 | 236 | return to_tensor(mask) 237 | 238 | def _argumentation(self, 239 | target: Image, 240 | line: Image) -> (Image, Image): 241 | """ Data Argumentataion """ 242 | target, line = random_flip(target, line) 243 | target = self._color_jitter(target) 244 | return target, line 245 | 246 | 247 | class AutoEncoderDataset(Dataset, DatasetBase): 248 | def __init__(self, image_paths: list, training: bool): 249 | """ 250 | :param image_paths:color_image_path list 251 | :param training: boolean flag for DataArgumentation 252 | :param size: resize size 253 | """ 254 | self._image_paths = image_paths 255 | self._random_crop = RandomCrop(512) 256 | self._image_compose = Compose([ 257 | Resize(128), 258 | ToTensor(), 259 | Normalize((0.5, 0.5, 0.5), 260 | (0.5, 0.5, 
0.5)) 261 | ]) 262 | self._line_compose = Compose([ 263 | Resize(128), 264 | ToTensor(), 265 | Normalize([0.5], [0.5]) 266 | ]) 267 | 268 | self._color_jitter = ColorJitter(brightness=0, 269 | contrast=0.1, 270 | saturation=0.1, 271 | hue=0.03) 272 | 273 | def __getitem__(self, item) -> (Image, Image): 274 | paths = self._image_paths[item] 275 | target = Image.open( 276 | paths[0]).convert('RGB') 277 | 278 | if random.random() > 0.5: 279 | line = dilate_abs_line(target) 280 | else: 281 | line = Image.open(paths[1]).convert('L') 282 | 283 | target, line = self._random_crop(target, line) 284 | target, line = self._argumentation(target, line) 285 | target = self._image_compose(target) 286 | line = self._line_compose(line) 287 | return target, line 288 | 289 | def __len__(self): 290 | return len(self._image_paths) 291 | 292 | def _argumentation(self, 293 | target: Image, 294 | line: Image) -> (Image, Image): 295 | """ Data Argumentataion """ 296 | target, line = random_flip(target, line) 297 | target = self._color_jitter(target) 298 | return target, line 299 | -------------------------------------------------------------------------------- /trainer/draftmodel_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from tqdm import tqdm 5 | import yaml 6 | 7 | import torch 8 | from torch import nn, optim 9 | from torchvision.utils import save_image, make_grid 10 | from torch.utils import tensorboard 11 | 12 | from models import Generator, Discriminator 13 | from trainer import opt 14 | from data import create_data_loader 15 | 16 | 17 | class DraftModelTrainer(opt.TrainerBase): 18 | def __init__(self, hp: dict, model_name='DraftModel'): 19 | super(DraftModelTrainer, self).__init__(hp, model_name) 20 | 21 | hyperparameters = hp['draft'] 22 | 23 | self.__generator = Generator(hyperparameters['in_dim'], 24 | hyperparameters['gf_dim']) 25 | self.__generator = self.__generator.to(self.device) 26 | 27 | self.__discriminator = Discriminator(hyperparameters['gf_dim']) 28 | self.__discriminator = self.__discriminator.to(self.device) 29 | 30 | self.__autoencoder = torch.jit.load( 31 | hyperparameters['autoencoder_path']) 32 | self.__autoencoder = self.__autoencoder.to(self.device) 33 | self.__autoencoder.eval() 34 | 35 | self.__gan_loss = opt.GANLoss().to(self.device) 36 | self.__content_loss = opt.ContentLoss().to(self.device) 37 | self.__l1_loss = nn.L1Loss().to(self.device) 38 | 39 | self.__w_gan = hyperparameters['w_gan'] 40 | self.__w_recon = hyperparameters['w_recon'] 41 | self.__w_cont = hyperparameters['w_cont'] 42 | self.__w_line = hyperparameters['w_line'] 43 | 44 | self.__optimizer_generator = optim.Adam( 45 | self.__generator.parameters(), 46 | lr=hyperparameters['lr'], 47 | betas=(hyperparameters['beta1'], hyperparameters['beta2'])) 48 | 49 | self.__optimizer_discriminator = optim.Adam( 50 | self.__discriminator.parameters(), 51 | lr=hyperparameters['lr'], 52 | betas=(hyperparameters['beta1'], hyperparameters['beta2'])) 53 | 54 | self.__opti_gen_scheduler = optim.lr_scheduler.MultiStepLR( 55 | self.__optimizer_generator, 56 | milestones=[hyperparameters['lr_milestones']], 57 | gamma=0.1) 58 | 59 | self.__opti_dis_scheduler = optim.lr_scheduler.MultiStepLR( 60 | self.__optimizer_discriminator, 61 | milestones=[hyperparameters['lr_milestones']], 62 | gamma=0.1) 63 | 64 | try: 65 | ckpt = torch.load(hyperparameters['ckpt']) 66 | self.epoch = ckpt['epoch'] 67 | self.itr = ckpt['itr'] 68 | 
self.__generator.load_state_dict(ckpt['Generator']) 69 | self.__discriminator.load_state_dict(ckpt['Discriminator']) 70 | self.__optimizer_generator.load_state_dict( 71 | ckpt['adma_generator']) 72 | self.__optimizer_discriminator.load_state_dict( 73 | ckpt['adma_discriminator']) 74 | self.__opti_gen_scheduler.load_state_dict( 75 | ckpt['scheduler_generator']) 76 | self.__opti_dis_scheduler.load_state_dict( 77 | ckpt['scheduler_discriminator']) 78 | except Exception as e: 79 | pass 80 | finally: 81 | print("DratfModel Trainer Init Done") 82 | 83 | def train(self): 84 | """ train methods """ 85 | 86 | train_set, test_set = create_data_loader(self.hp, 'draft') 87 | batch = next(iter(test_set)) 88 | target, hint, line = [data.to(self.device) for data in batch] 89 | sample_batch = (target, hint, line) 90 | hyperparametsers = self.hp['draft'] 91 | 92 | while self.epoch < hyperparametsers['epoch']: 93 | p_bar = tqdm(train_set, total=len(train_set)) 94 | for batch in p_bar: 95 | loss = self._train_step(batch) 96 | 97 | if self.itr % hyperparametsers['sampling_interval'] == 0: 98 | self._test_step(sample_batch) 99 | 100 | msg = 'E:%d, Itr:%d, Loss:%0.4f' % ( 101 | self.epoch + 1, self.itr, loss) 102 | p_bar.set_description(msg) 103 | self.itr += 1 104 | 105 | self._check_point() 106 | self.__opti_dis_scheduler.step() 107 | self.__opti_gen_scheduler.step() 108 | self.epoch += 1 109 | 110 | """ Model save as torch script """ 111 | file_name = os.path.join(self.tb.log_dir, 'torch_script') 112 | file_name = os.path.join(file_name, 'DraftModel_ts.zip') 113 | ts_model = torch.jit.script(self.__generator.cpu(), 114 | torch.rand([1, 3, 128, 128])) 115 | ts_model.save(file_name) 116 | 117 | def _train_step(self, batch: tuple) -> float: 118 | 119 | target, hint, line = [data.to(self.device) for data in batch] 120 | 121 | #################### 122 | # Discriminator # 123 | #################### 124 | fake_image = self.__generator(line, hint) # G(l,h) 125 | fake_dis = self.__discriminator(fake_image.detach()) # D(G(l,h)) 126 | real_dis = self.__discriminator(target) # D(c) 127 | 128 | fake_loss = self.__gan_loss(fake_dis, False) 129 | real_loss = self.__gan_loss(real_dis, True) 130 | discriminator_loss = fake_loss + real_loss 131 | 132 | self.__discriminator.zero_grad() 133 | discriminator_loss.backward() 134 | self.__optimizer_discriminator.step() 135 | 136 | #################### 137 | # Generator # 138 | #################### 139 | _fake_dis = self.__discriminator(fake_image) 140 | 141 | with torch.no_grad(): 142 | fake_line = self.__autoencoder(fake_image) 143 | real_line = self.__autoencoder(target) 144 | 145 | adv_loss = self.__gan_loss(_fake_dis, True) 146 | recon_loss = self.__l1_loss(fake_image, target) 147 | content_loss = self.__content_loss(fake_image, target) 148 | line_loss = self.__l1_loss(fake_line, real_line) 149 | 150 | generator_loss = (adv_loss * self.__w_gan) \ 151 | + (recon_loss * self.__w_recon) \ 152 | + (content_loss * self.__w_cont) \ 153 | + (line_loss * self.__w_line) 154 | 155 | self.__generator.zero_grad() 156 | generator_loss.backward() 157 | self.__optimizer_generator.step() 158 | 159 | #################### 160 | # Logging # 161 | #################### 162 | if self.itr % self.hp['draft']['log_interval'] == 0: 163 | self.tb.add_scalar('TRAINING/Discriminator.loss', 164 | discriminator_loss.item(), self.itr) 165 | self.tb.add_scalar('TRAINING/Discriminator.loss.fake', 166 | fake_loss.item(), self.itr) 167 | self.tb.add_scalar('TRAINING/Discriminator.loss.real', 168 | 
real_loss.item(), self.itr) 169 | 170 | self.tb.add_scalar('TRAINING/Generator.loss', 171 | generator_loss.item(), self.itr) 172 | self.tb.add_scalar('TRAINING/Generator.loss.adv', 173 | adv_loss.item(), self.itr) 174 | self.tb.add_scalar('TRAINING/Generator.loss.recon', 175 | recon_loss.item(), self.itr) 176 | self.tb.add_scalar('TRAINING/Generator.loss.content', 177 | content_loss.item(), self.itr) 178 | self.tb.add_scalar('TRAINING/Generator.line', 179 | line_loss.item(), self.itr) 180 | 181 | self.tb.add_scalar('Learning Rate', 182 | self.__opti_dis_scheduler.get_last_lr()[0], 183 | self.itr) 184 | 185 | if self.itr % self.hp['draft']['sampling_interval'] == 0: 186 | r, g, b, a = torch.chunk(hint, 4, 1) 187 | hint = torch.cat([r, g, b], 1) 188 | batch_size = target.size(0) 189 | 190 | log_image = [make_grid(target, batch_size, 0, 191 | range=(-1, 1), normalize=True), 192 | make_grid(fake_image, batch_size, 0, 193 | range=(-1, 1), normalize=True), 194 | make_grid(line, batch_size, 0, 195 | range=(-1, 1), normalize=True), 196 | make_grid(hint, batch_size, 0, 197 | range=(-1, 1), normalize=True)] 198 | 199 | self.tb.add_image('TRAINING/SampleImage', 200 | make_grid(log_image, 1, 0), 201 | self.itr) 202 | return generator_loss.item() 203 | 204 | @ torch.no_grad() 205 | def _test_step(self, batch: tuple): 206 | """ Test step 207 | this section's tensor not need to trace gradient 208 | 209 | Args: 210 | batch (tuple): batch data tuple (target, hint, line) """ 211 | 212 | self.__generator.eval() 213 | target, hint, line = [data for data in batch] 214 | zero_hint = torch.zeros_like(hint) 215 | fake_image = self.__generator.forward(line, hint) 216 | fake_zero_hint_image = \ 217 | self.__generator.forward(line, zero_hint) 218 | fake_line = self.__autoencoder(fake_image) 219 | real_line = self.__autoencoder(target) 220 | 221 | recon_loss = self.__l1_loss(fake_image, target) 222 | content_loss = self.__content_loss(fake_image, target) 223 | line_loss = self.__l1_loss(fake_line, real_line) 224 | 225 | self.tb.add_scalar('TESTING/Generator.loss.recon', 226 | recon_loss.item(), self.itr) 227 | self.tb.add_scalar('TESTING/Generator.loss.content', 228 | content_loss.item(), self.itr) 229 | self.tb.add_scalar('TESTING/Generator.line', 230 | line_loss.item(), self.itr) 231 | 232 | r, g, b, a = torch.chunk(hint, 4, 1) 233 | hint = torch.cat([r, g, b], 1) 234 | batch_size = target.size(0) 235 | log_image = [make_grid(target, batch_size, 0, 236 | range=(-1, 1), normalize=True), 237 | make_grid(fake_image, batch_size, 0, 238 | range=(-1, 1), normalize=True), 239 | make_grid(fake_zero_hint_image, batch_size, 0, 240 | range=(-1, 1), normalize=True), 241 | make_grid(line, batch_size, 0, 242 | range=(-1, 1), normalize=True), 243 | make_grid(hint, batch_size, 0, 244 | range=(-1, 1), normalize=True)] 245 | 246 | sample_image = make_grid(log_image, 247 | 1, 0, range=(0, 1)) 248 | 249 | self.tb.add_image('TESTING/SampleImage', 250 | sample_image, 251 | self.itr) 252 | 253 | file_name = 'sample_image_GS:%d.jpg' % self.itr 254 | file_name = os.path.join(self.tb.log_dir, 'image', 255 | file_name) 256 | save_image(sample_image, file_name) 257 | self.__generator.train(True) 258 | 259 | def _check_point(self): 260 | """ Save Checkpoint objects 261 | checkpoint objects contain epoch, itr, model, optimizer, scheduler """ 262 | file_name = os.path.join(self.tb.log_dir, 263 | 'ckpt') 264 | file_name = os.path.join( 265 | file_name, 266 | 'AutoEncoder_E:%d_GS:%d.pth' % (self.epoch, self.itr)) 267 | 268 | ckpt = {'epoch': 
self.epoch, 269 | 'itr': self.itr, 270 | 'Generator': self.__generator.state_dict(), 271 | 'Discriminator': self.__discriminator.state_dict(), 272 | 'adma_generator': self.__optimizer_generator.state_dict(), 273 | 'adma_discriminator': self.__optimizer_discriminator.state_dict(), 274 | 'scheduler_generator': self.__opti_gen_scheduler.state_dict(), 275 | 'scheduler_discriminator': self.__opti_dis_scheduler.state_dict()} 276 | torch.save(ckpt, file_name)  # state_dicts here match the load_state_dict() calls in __init__ 277 | -------------------------------------------------------------------------------- /app/src/main/resources/base/MAINUI.ui: -------------------------------------------------------------------------------- (Qt Designer XML for the 684x717 main window; the markup was lost during extraction, so only the widget text survives. The layout it describes: a status label, the color-pick button (colorBtn), the eraser button (eBtn), the RUN button (Ctrl+R), a pen-size slider with its value label, the canvas area, and a File menu with &open (Ctrl+O), &save (Ctrl+S), and &connection (Ctrl+.) actions.) --------------------------------------------------------------------------------
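For reference, the desktop app drives colorization through inference.predict: the RUN button in winodw.py calls predict(line, hint, None) with the original line-art PIL image and the hint image painted on the canvas, and treats any non-PIL return value as an error message. The sketch below replays that call outside the GUI; it assumes it is run from app/src/main/python so that inference is importable, and the file names and the RGB conversion are illustrative assumptions, not part of the repository.

from PIL import Image

from inference import predict  # app/src/main/python/inference.py

# line-art input; the GUI keeps the full-resolution image around as `origin_line`
line = Image.open('line_art.png').convert('RGB')
# color-hint strokes; in the GUI this image comes from Painter.get_image()
hint = Image.open('hint.png')

result = predict(line, hint, None)  # same arguments the RUN button passes

if isinstance(result, Image.Image):
    result.save('colorized.png')
else:
    # predict() returns an error message instead of an image on failure
    print('inference error:', result)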