├── .gitignore ├── LICENSE ├── MangaLineExtraction.ipynb ├── README.md ├── assets ├── color.png ├── comparison.png ├── gallery1.png ├── gallery2.png └── teaser.png ├── model_torch.py ├── pytorchResults └── PrismHeart_079.png ├── pytorchTestCases └── PrismHeart_079.jpg └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # pytorch related 133 | *.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Miaomiao Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MangaLineExtraction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "MangaLineExtraction.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm" 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "Rismr8AnFlqB" 25 | }, 26 | "source": [ 27 | "## MangaLineExtraction_PyTorch\n", 28 | "\n", 29 | "_This is an interactive demo of the paper [\"Deep Extraction of Manga Structural Lines\"](https://www.cse.cuhk.edu.hk/~ttwong/papers/linelearn/linelearn.html)_\n", 30 | "\n", 31 | "First, run the following cell to set up the environment. Please make sure the GPU runtime setting is set to \"on\"." 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "metadata": { 37 | "cellView": "form", 38 | "id": "UxDL2nO2-_Wq" 39 | }, 40 | "source": [ 41 | "#@title Environment setup\n", 42 | "\n", 43 | "%cd ~\n", 44 | "! git clone https://github.com/ljsabc/MangaLineExtraction_PyTorch.git\n", 45 | "%cd MangaLineExtraction_PyTorch\n", 46 | "! 
wget -O erika.pth https://github.com/ljsabc/MangaLineExtraction_PyTorch/releases/download/v1/erika.pth\n", 47 | "\n", 48 | "\n", 49 | "import torch\n", 50 | "import cv2\n", 51 | "\n", 52 | "from google.colab import files\n", 53 | "import os\n", 54 | "import numpy as np\n", 55 | "from google.colab.patches import cv2_imshow\n", 56 | "\n", 57 | "from model_torch import res_skip\n", 58 | "\n", 59 | "model = res_skip()\n", 60 | "model.load_state_dict(torch.load('erika.pth'))\n", 61 | "\n", 62 | "model.cuda();\n", 63 | "model.eval();\n", 64 | "\n", 65 | "print(\"Setup Complete\")" 66 | ], 67 | "execution_count": null, 68 | "outputs": [] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "mZnrHmxFGN4k" 74 | }, 75 | "source": [ 76 | "### Test with your own image\n", 77 | "\n", 78 | "After the environment setup, run this cell to test with your own image. When the file upload button appears in the output, select any picture from your local device and wait for the code to run. The result will be shown at the bottom.\n", 79 | "\n", 80 | "Right-click on the result to save the output. Re-run this cell to upload and process another image." 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "cellView": "form", 87 | "id": "xiMxM8ctCxIT" 88 | }, 89 | "source": [ 90 | "#@title File upload and processing\n", 91 | "\n", 92 | "uploaded = files.upload()\n", 93 | "outputLoc = None\n", 94 | "with torch.no_grad():\n", 95 | " for imname in uploaded.keys():\n", 96 | " srcc = cv2.imread(imname)\n", 97 | " print(\"Original Image:\")\n", 98 | " cv2_imshow(srcc)\n", 99 | "\n", 100 | " src = cv2.imread(imname,cv2.IMREAD_GRAYSCALE)\n", 101 | " \n", 102 | " rows = int(np.ceil(src.shape[0]/16))*16\n", 103 | " cols = int(np.ceil(src.shape[1]/16))*16\n", 104 | " \n", 105 | " # manually construct a batch. You can change it based on your usecases. \n", 106 | " patch = np.ones((1,1,rows,cols),dtype=\"float32\")\n", 107 | " patch[0,0,0:src.shape[0],0:src.shape[1]] = src\n", 108 | "\n", 109 | " tensor = torch.from_numpy(patch).cuda()\n", 110 | " y = model(tensor)\n", 111 | " print(imname, torch.max(y), torch.min(y))\n", 112 | "\n", 113 | " yc = y.cpu().numpy()[0,0,:,:]\n", 114 | " yc[yc>255] = 255\n", 115 | " yc[yc<0] = 0\n", 116 | "\n", 117 | " head, tail = os.path.split(imname)\n", 118 | " if not os.path.exists(\"output\"):\n", 119 | " os.mkdir(\"output\")\n", 120 | "\n", 121 | " print(\"Output Image:\")\n", 122 | " output = yc[0:src.shape[0],0:src.shape[1]]\n", 123 | " cv2_imshow(output)\n", 124 | "\n", 125 | " outputLoc = \"output/\"+tail.replace(\".jpg\",\".png\")\n", 126 | " cv2.imwrite(outputLoc,output)" 127 | ], 128 | "execution_count": null, 129 | "outputs": [] 130 | } 131 | ] 132 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MangaLineExtraction_PyTorch 2 | The (Official) PyTorch Implementation of the paper _[Deep Extraction of Manga Structural Lines](https://www.cse.cuhk.edu.hk/~ttwong/papers/linelearn/linelearn.html)_. This project aims to extract the structural lines from 2D manga, cartoons, and illustrations. 3 | 4 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ljsabc/MangaLineExtraction_PyTorch/blob/main/MangaLineExtraction.ipynb) 5 | 6 | Besides Colab, there is also a Gradio-like [web demo](https://moeka.me/mangaLine/). (i18n help needed!) 
7 | 8 | ![teaser](./assets/teaser.png) 9 | 10 | 11 | ### Usage 12 | 13 | model_torch.py [source folder] [output folder] 14 | 15 | Example: 16 | 17 | model_torch.py ./pytorchTestCases/ ./pytorchResults/ 18 | 19 | ### The model weights (erika.pth) 20 | 21 | You can always refer to the standardized [HuggingFace pipeline](https://huggingface.co/p1atdev/MangaLineExtraction-hf) to perform simple inferences. Thanks for the port and model quantization. 22 | 23 | For manual weight download, please refer to the **[release](https://github.com/ljsabc/MangaLineExtraction_PyTorch/releases)** section of this repo. Alternatively, you may use this link: 24 | 25 | https://www.dropbox.com/s/y8pulix3zs73y62/erika.pth?dl=0 26 | 27 | ### Requirements 28 | 29 | + Python 3 30 | + PyTorch (tested on version 1.9) 31 | + opencv-python 32 | 33 | ### How the model is prepared 34 | 35 | The PyTorch weights are exactly the same as those of the original Theano(!) model. I made some effort to convert the original weights to the new model and ensured that the overall error is less than 1e-3 over the 0-255 image range. 36 | 37 | Moreover, the functional PyTorch interface allows easier fine-tuning of this model. You can also take the whole model as a sub-module for your own work (e.g., use the on-the-fly extraction of lines as a structural constraint). 38 | 39 | ### About model training 40 | 41 | I really don't want to admit it, but the legacy code looks like artwork by a two-year-old. Please refer to #5 if you are interested. 42 | 43 | ### Go beyond manga 44 | 45 | Surprisingly, this model works quite well on color cartoons and other nijigen-style images, as long as they have clear hand-drawn lines. Simply load the image as grayscale (the default) and check out the results! 46 | 47 | ![Visual comparison](./assets/comparison.png) 48 | From left to right: input, [sketchKeras](https://github.com/lllyasviel/sketchKeras), [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch) (considered SOTA), and ours. 49 | 50 | ### Gallery 51 | 52 | I'm glad to share some model results. Some of the images are copyrighted, and I will list the original sources below. Feel free to share your creations with me in the issues section. 53 | 54 | ![](./assets/gallery1.png) 55 | [©IWAYUU](http://iwayu2.blog.fc2.com/blog-entry-9.html), from the fc2 blog. 
56 | 57 | ![](./assets/gallery2.png) 58 | 59 | ### BibTeX: 60 | 61 | @article{li-2017-deep, 62 | author = {Chengze Li and Xueting Liu and Tien-Tsin Wong}, 63 | title = {Deep Extraction of Manga Structural Lines}, 64 | journal = {ACM Transactions on Graphics (SIGGRAPH 2017 issue)}, 65 | month = {July}, 66 | year = {2017}, 67 | volume = {36}, 68 | number = {4}, 69 | pages = {117:1--117:12}, 70 | } 71 | 72 | ### Credit: 73 | 74 | + Xueting Liu and Tien-Tsin Wong, who contributed this work 75 | + Wenliang Wu and Ziheng Ma, who inspired me to port this great thing to PyTorch 76 | + Toda Erika, where the project name comes from 77 | 78 | -------------------------------------------------------------------------------- /assets/color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/assets/color.png -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/assets/comparison.png -------------------------------------------------------------------------------- /assets/gallery1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/assets/gallery1.png -------------------------------------------------------------------------------- /assets/gallery2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/assets/gallery2.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/assets/teaser.png -------------------------------------------------------------------------------- /model_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data.dataset import Dataset 5 | from PIL import Image 6 | import fnmatch 7 | import cv2 8 | 9 | import sys 10 | 11 | import numpy as np 12 | 13 | #torch.set_printoptions(precision=10) 14 | 15 | 16 | class _bn_relu_conv(nn.Module): 17 | def __init__(self, in_filters, nb_filters, fw, fh, subsample=1): 18 | super(_bn_relu_conv, self).__init__() 19 | self.model = nn.Sequential( 20 | nn.BatchNorm2d(in_filters, eps=1e-3), 21 | nn.LeakyReLU(0.2), 22 | nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros') 23 | ) 24 | 25 | def forward(self, x): 26 | return self.model(x) 27 | 28 | # the following are for debugs 29 | print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape) 30 | for i,layer in enumerate(self.model): 31 | if i != 2: 32 | x = layer(x) 33 | else: 34 | x = layer(x) 35 | #x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0) 36 | print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), 
x.shape) 37 | print(x[0]) 38 | return x 39 | 40 | 41 | class _u_bn_relu_conv(nn.Module): 42 | def __init__(self, in_filters, nb_filters, fw, fh, subsample=1): 43 | super(_u_bn_relu_conv, self).__init__() 44 | self.model = nn.Sequential( 45 | nn.BatchNorm2d(in_filters, eps=1e-3), 46 | nn.LeakyReLU(0.2), 47 | nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)), 48 | nn.Upsample(scale_factor=2, mode='nearest') 49 | ) 50 | 51 | def forward(self, x): 52 | return self.model(x) 53 | 54 | 55 | 56 | class _shortcut(nn.Module): 57 | def __init__(self, in_filters, nb_filters, subsample=1): 58 | super(_shortcut, self).__init__() 59 | self.process = False 60 | self.model = None 61 | if in_filters != nb_filters or subsample != 1: 62 | self.process = True 63 | self.model = nn.Sequential( 64 | nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample) 65 | ) 66 | 67 | def forward(self, x, y): 68 | #print(x.size(), y.size(), self.process) 69 | if self.process: 70 | y0 = self.model(x) 71 | #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape) 72 | return y0 + y 73 | else: 74 | #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape) 75 | return x + y 76 | 77 | class _u_shortcut(nn.Module): 78 | def __init__(self, in_filters, nb_filters, subsample): 79 | super(_u_shortcut, self).__init__() 80 | self.process = False 81 | self.model = None 82 | if in_filters != nb_filters: 83 | self.process = True 84 | self.model = nn.Sequential( 85 | nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'), 86 | nn.Upsample(scale_factor=2, mode='nearest') 87 | ) 88 | 89 | def forward(self, x, y): 90 | if self.process: 91 | return self.model(x) + y 92 | else: 93 | return x + y 94 | 95 | 96 | class basic_block(nn.Module): 97 | def __init__(self, in_filters, nb_filters, init_subsample=1): 98 | super(basic_block, self).__init__() 99 | self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample) 100 | self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3) 101 | self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample) 102 | 103 | def forward(self, x): 104 | x1 = self.conv1(x) 105 | x2 = self.residual(x1) 106 | return self.shortcut(x, x2) 107 | 108 | class _u_basic_block(nn.Module): 109 | def __init__(self, in_filters, nb_filters, init_subsample=1): 110 | super(_u_basic_block, self).__init__() 111 | self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample) 112 | self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3) 113 | self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample) 114 | 115 | def forward(self, x): 116 | y = self.residual(self.conv1(x)) 117 | return self.shortcut(x, y) 118 | 119 | 120 | class _residual_block(nn.Module): 121 | def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False): 122 | super(_residual_block, self).__init__() 123 | layers = [] 124 | for i in range(repetitions): 125 | init_subsample = 1 126 | if i == repetitions - 1 and not is_first_layer: 127 | init_subsample = 2 128 | if i == 0: 129 | l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample) 130 | else: 131 | l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample) 132 | layers.append(l) 133 | 134 | self.model = nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | return self.model(x) 138 | 139 | 140 | class 
_upsampling_residual_block(nn.Module): 141 | def __init__(self, in_filters, nb_filters, repetitions): 142 | super(_upsampling_residual_block, self).__init__() 143 | layers = [] 144 | for i in range(repetitions): 145 | l = None 146 | if i == 0: 147 | l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input) 148 | else: 149 | l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input) 150 | layers.append(l) 151 | 152 | self.model = nn.Sequential(*layers) 153 | 154 | def forward(self, x): 155 | return self.model(x) 156 | 157 | 158 | class res_skip(nn.Module): 159 | 160 | def __init__(self): 161 | super(res_skip, self).__init__() 162 | self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input) 163 | self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0) 164 | self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1) 165 | self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2) 166 | self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3) 167 | 168 | self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4) 169 | self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1)) 170 | 171 | self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1) 172 | self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1)) 173 | 174 | self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2) 175 | self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1)) 176 | 177 | self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3) 178 | self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1)) 179 | 180 | self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4) 181 | self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block7) 182 | 183 | def forward(self, x): 184 | x0 = self.block0(x) 185 | x1 = self.block1(x0) 186 | x2 = self.block2(x1) 187 | x3 = self.block3(x2) 188 | x4 = self.block4(x3) 189 | 190 | x5 = self.block5(x4) 191 | res1 = self.res1(x3, x5) 192 | 193 | x6 = self.block6(res1) 194 | res2 = self.res2(x2, x6) 195 | 196 | x7 = self.block7(res2) 197 | res3 = self.res3(x1, x7) 198 | 199 | x8 = self.block8(res3) 200 | res4 = self.res4(x0, x8) 201 | 202 | x9 = self.block9(res4) 203 | y = self.conv15(x9) 204 | 205 | return y 206 | 207 | class MyDataset(Dataset): 208 | def __init__(self, image_paths, transform=None): 209 | self.image_paths = image_paths 210 | self.transform = transform 211 | 212 | def get_class_label(self, image_name): 213 | # your method here 214 | head, tail = os.path.split(image_name) 215 | #print(tail) 216 | return tail 217 | 218 | def __getitem__(self, index): 219 | image_path = self.image_paths[index] 220 | x = Image.open(image_path) 221 | y = self.get_class_label(image_path.split('/')[-1]) 222 | if self.transform is not None: 223 | x = self.transform(x) 224 | return x, y 225 | 226 | def __len__(self): 227 | return len(self.image_paths) 228 | 229 | def loadImages(folder): 230 | imgs = [] 231 | matches = [] 232 | for root, dirnames, filenames in os.walk(folder): 233 | for filename in fnmatch.filter(filenames, '*'): 234 | matches.append(os.path.join(root, filename)) 235 | 236 | return matches 
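# Illustrative sketch (not part of the original model_torch.py): as the README notes,
# res_skip can also be dropped into other projects as a sub-module, e.g. to turn the
# on-the-fly line extraction into a structural constraint. The class name
# `StructuralLineLoss` and the 0-255 input convention below are assumptions made for
# illustration only.
class StructuralLineLoss(nn.Module):
    def __init__(self, weight_path='erika.pth'):
        super(StructuralLineLoss, self).__init__()
        self.extractor = res_skip()
        self.extractor.load_state_dict(torch.load(weight_path, map_location='cpu'))
        self.extractor.eval()
        # freeze the line extractor so only the surrounding model is trained
        for p in self.extractor.parameters():
            p.requires_grad = False

    def forward(self, pred, target):
        # pred, target: (N, 1, H, W) grayscale tensors in the 0-255 range,
        # with H and W already padded to multiples of 16
        return nn.functional.l1_loss(self.extractor(pred), self.extractor(target))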
237 | 238 | if __name__ == "__main__": 239 | model = res_skip() 240 | model.load_state_dict(torch.load('erika.pth')) 241 | is_cuda = torch.cuda.is_available() 242 | if is_cuda: 243 | model.cuda() 244 | else: 245 | model.cpu() 246 | model.eval() 247 | 248 | filelists = loadImages(sys.argv[1]) 249 | 250 | with torch.no_grad(): 251 | for imname in filelists: 252 | src = cv2.imread(imname,cv2.IMREAD_GRAYSCALE) 253 | 254 | rows = int(np.ceil(src.shape[0]/16))*16 255 | cols = int(np.ceil(src.shape[1]/16))*16 256 | 257 | # manually construct a batch. You can change it based on your usecases. 258 | patch = np.ones((1,1,rows,cols),dtype="float32") 259 | patch[0,0,0:src.shape[0],0:src.shape[1]] = src 260 | 261 | if is_cuda: 262 | tensor = torch.from_numpy(patch).cuda() 263 | else: 264 | tensor = torch.from_numpy(patch).cpu() 265 | y = model(tensor) 266 | print(imname, torch.max(y), torch.min(y)) 267 | 268 | yc = y.cpu().numpy()[0,0,:,:] 269 | yc[yc>255] = 255 270 | yc[yc<0] = 0 271 | 272 | head, tail = os.path.split(imname) 273 | cv2.imwrite(sys.argv[2]+"/"+tail.replace(".jpg",".png"),yc[0:src.shape[0],0:src.shape[1]]) 274 | -------------------------------------------------------------------------------- /pytorchResults/PrismHeart_079.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/pytorchResults/PrismHeart_079.png -------------------------------------------------------------------------------- /pytorchTestCases/PrismHeart_079.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/MangaLineExtraction_PyTorch/6ec136d5332180b65476e62c2558f2873d5d936a/pytorchTestCases/PrismHeart_079.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.1 2 | opencv-python --------------------------------------------------------------------------------
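As the README notes, the functional PyTorch interface also makes fine-tuning easier. The repository ships no training code, so the following is only a rough sketch under stated assumptions: paired grayscale pages and target line maps are already available as float32 tensors of shape (N, 1, H, W), with values in the 0-255 range and H, W padded to multiples of 16 as in the inference scripts above; the hyperparameters are placeholders.

    import torch
    import torch.nn as nn
    from model_torch import res_skip

    # load the released weights and switch to training mode
    model = res_skip()
    model.load_state_dict(torch.load('erika.pth', map_location='cpu'))
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    criterion = nn.L1Loss()

    def training_step(pages, line_maps):
        # pages, line_maps: (N, 1, H, W) float32 tensors as described above
        optimizer.zero_grad()
        loss = criterion(model(pages), line_maps)
        loss.backward()
        optimizer.step()
        return loss.item()

A real fine-tuning run would additionally need a data loader (the MyDataset class in model_torch.py is a possible starting point) and a learning-rate schedule.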