├── LICENSE
├── README.md
├── code
    ├── RectangleGenerator.py
    ├── TextGenerator.py
    ├── config.py
    ├── danbooru.py
    ├── dataset.py
    ├── experiments
    │   ├── BDUNET.ipynb
    │   ├── BDUNET_Train.ipynb
    │   ├── View results.ipynb
    │   ├── craft.ipynb
    │   ├── datasets.ipynb
    │   ├── experiments.py
    │   ├── extra.ipynb
    │   ├── hrnet.ipynb
    │   ├── hrnet
    │   │   ├── hrnet train.ipynb
    │   │   ├── manga.py
    │   │   └── manga.yaml
    │   ├── loss.ipynb
    │   ├── model.ipynb
    │   ├── refine.ipynb
    │   ├── sickzil-machine.ipynb
    │   ├── synthetic.ipynb
    │   ├── v1_2.ipynb
    │   └── yu45020.ipynb
    ├── losses.py
    ├── manga109.py
    ├── metrics.py
    ├── ssim.py
    └── transforms.py
├── examples
    └── Manga_Text_Segmentation_Predict.ipynb
├── images
    ├── AisazuNihaIrarenai-009.jpg
    ├── AisazuNihaIrarenaipre-prediction-009.png
    └── AisazuNihaIrarenaipre-processed-truth-009.png
└── requirements.txt


/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 julian
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Manga-Text-Segmentation
2 | Official repository of our paper [Unconstrained Text Detection in Manga: A New Dataset and Baseline](https://link.springer.com/chapter/10.1007%2F978-3-030-67070-2_38).
3 | 
4 | ## Example
5 | Input:
6 | ![AisazuNihaIrarenai 009](images/AisazuNihaIrarenai-009.jpg)
7 | Output:
8 | ![AisazuNihaIrarenai 009](images/AisazuNihaIrarenaipre-prediction-009.png)
9 | (source: [manga109](http://www.manga109.org/en/), © Yoshi Masako)
10 | 
11 | ## Dataset
12 | Our label masks are available on [Zenodo](https://zenodo.org/record/4511796). For the original Manga109 images, please refer to the [Manga109 website](http://www.manga109.org/en/).
13 | 
14 | ## Code
15 | The notebooks that produced our best models are **model** (resnet34 from modelDict) and **refine** (model/resnet34 path). The rest are either experiments or visualizations. See the **extra** notebook for how predictions are loaded.
16 | 
17 | ## Try it out
18 | You can try prediction yourself by opening the [Predict notebook](/examples/Manga_Text_Segmentation_Predict.ipynb) in Colab.
19 | 
20 | ## Additional Resources
21 | A more detailed account of our research is available on [arXiv](https://arxiv.org/abs/2010.03997). It includes more examples and other approaches we tried before arriving at the final result described in our paper.
22 | 
23 | ## Citation
24 | If you find this work or code helpful in your research, please cite:
25 | ````
26 | @InProceedings{10.1007/978-3-030-67070-2_38,
27 | author="Del Gobbo, Juli{\'a}n
28 | and Matuk Herrera, Rosana",
29 | editor="Bartoli, Adrien
30 | and Fusiello, Andrea",
31 | title="Unconstrained Text Detection in Manga: A New Dataset and Baseline",
32 | booktitle="Computer Vision -- ECCV 2020 Workshops",
33 | year="2020",
34 | publisher="Springer International Publishing",
35 | address="Cham",
36 | pages="629--646",
37 | abstract="The detection and recognition of unconstrained text is an open problem in research. Text in comic books has unusual styles that raise many challenges for text detection. This work aims to binarize text in a comic genre with highly sophisticated text styles: Japanese manga. To overcome the lack of a manga dataset with text annotations at a pixel level, we create our own. To improve the evaluation and search of an optimal model, in addition to standard metrics in binarization, we implement other special metrics. Using these resources, we designed and evaluated a deep network model, outperforming current methods for text binarization in manga in most metrics.",
38 | isbn="978-3-030-67070-2"
39 | }
40 | 
41 | @dataset{segmentation_manga_dataset,
42 |   author       = {julian del gobbo and
43 |                   Rosana Matuk Herrera},
44 |   title        = {{Mask Dataset for: Unconstrained Text Detection in
45 |                    Manga: a New Dataset and Baseline}},
46 |   month        = feb,
47 |   year         = 2021,
48 |   publisher    = {Zenodo},
49 |   version      = {1.0},
50 |   doi          = {10.5281/zenodo.4511796},
51 |   url          = {https://doi.org/10.5281/zenodo.4511796}
52 | }
53 | ````

--------------------------------------------------------------------------------
/code/RectangleGenerator.py:
--------------------------------------------------------------------------------
1 | import numpy.random as random
2 | 
3 | class Rectangle:
4 |     def __init__(self, x, y, width, height):
5 |         self.x = x
6 |         self.y = y
7 |         self.width = width
8 |         self.height = height
9 | 
10 |     def intersects(self, rect):
11 |         return (self.x < rect.x + rect.width and self.x + self.width > rect.x and
12 |                 self.y < rect.y + rect.height and self.y + self.height > rect.y)
13 | 
14 |     def area(self):
15 |         return self.width * self.height
16 | 
17 |     def __repr__(self):
18 |         return str((self.x, self.y, self.width, self.height))
19 | 
20 | class RectangleGenerator:
21 |     @staticmethod
22 |     def generate(width, height, limit):
23 |         rects = []
24 | 
25 |         for i in range(0, min(limit * 2, 15)):
26 |             x, y = random.randint(0, int(width * 0.93)), random.randint(0, int(height * 0.9))
27 | 
28 |             if random.random_sample() < 0.8:
29 |                 w = random.randint(7, 15)
30 |                 h = random.randint(10, 35)
31 |             else:
32 |                 w = random.randint(15, 100)
33 |                 h = random.randint(10, 50)
34 | 
35 |             r = Rectangle(x, y, min(int(w * width / 100), width), min(int(h * height / 100), height))
36 |             add = True
37 | 
38 |             for rect in rects:
39 |                 if rect.intersects(r) and random.random_sample() < 0.5:
40 |                     r = Rectangle(x, y, int(r.width / 2), r.height)
41 |                 if rect.intersects(r) and random.random_sample() < 0.5:
42 |                     r = Rectangle(x, y, r.width, int(r.height / 2))
43 |                 if rect.intersects(r):
44 |                     add = False
45 |                     break
46 | 
47 |             if add:
48 |                 rects.append(r)
49 |                 if len(rects) == limit:
50 |                     break
51 |         return rects

--------------------------------------------------------------------------------
/code/TextGenerator.py:
--------------------------------------------------------------------------------
1 | import numpy.random as random
2 | from fontTools.ttLib import TTFont
3 | from fontTools.unicode import Unicode
4 | 
5 | #http://www.rikai.com/library/kanjitables/kanji_codes.unicode.shtml
6 | class TextGenerator:
7 |     ranges = [(33, 122), (0x3040, 0x309f), (0x30a0, 0x30ff), (0xff60, 0xffb0), (0x4e00, 0x9faf)]
8 |     total = sum([x[1] - x[0] + 1 for x in ranges])
9 |     cache = dict()
10 | 
11 |     @staticmethod
12 |     def isChar(c):
13 |         idx = ord(c)
14 |         return any(map(lambda x: idx >= x[0] and idx <= x[1], TextGenerator.ranges))
15 | 
16 |     @staticmethod
17 |     def char(num):
18 |         total = 0
19 |         for x in TextGenerator.ranges:
20 |             total += x[1] - x[0] + 1
21 |             if total > num:
22 |                 return chr(x[1] - (total - num) + 1)
23 | 
24 |     @staticmethod
25 |     def generate(l):
26 |         return "".join([TextGenerator.char(random.randint(0, TextGenerator.total)) for x in range(l)])
27 | 
28 |     @staticmethod
29 |     def text_wrap(text, font, max_width, max_height):
30 |         if font.size in TextGenerator.cache:
31 |             char_width = TextGenerator.cache[font.size]
32 |         else:
33 |             char_width = TextGenerator.cache[font.size] = font.getsize('亮')[0]
34 |         estimate = (max_width // char_width)
35 |         lines = []
36 |         i, j, hei = 0, 0, 0
37 |         # append characters to a line while its width stays within max_width
38 |         while i < len(text) and estimate > 0:
39 |             i = j
40 |             j = min(len(text), i + estimate)
41 |             width = font.getsize(text[i:j])[0]
42 |             while j < len(text) and width <= max_width:
43 |                 width += font.getsize(text[j])[0]
44 |                 j += 1
45 |             while width > max_width and j > i:
46 |                 j -= 1
47 |                 width -= font.getsize(text[j])[0]
48 |             hei += font.getsize(text[i:j])[1]
49 |             if hei > max_height or i == j:
50 |                 break
51 |             if len(text[i:j]):
52 |                 lines.append(text[i:j])
53 |         return lines
54 | 
55 | 
56 | from PIL import ImageFont
57 | import json
58 | import os
59 | 
60 | class Font:
61 |     def __init__(self, path):
62 |         self.cache = dict()
63 |         self.path = path
64 |         self.initAvailableChars()
65 | 
66 |     def getFont(self, size):
67 |         if size in self.cache:
68 |             font = self.cache[size]
69 |         else:
70 |             font = self.cache[size] = ImageFont.truetype(str(self.path), size)
71 | 
72 |         return font
73 | 
74 |     def initAvailableChars(self):
75 |         self.chars = []
76 |         key = str(self.path.name)
77 |         filename = 'cache/fonts/'+key+'.json'
78 | 
79 |         try:
80 |             with open(filename, 'r') as f:
81 |                 cache = json.load(f)
82 |         except (OSError, ValueError):  # missing or corrupt cache file
83 |             cache = dict()
84 | 
85 |         if 'ranges' in cache and str(TextGenerator.ranges) == cache['ranges']:
86 |             self.chars = cache['chars']
87 |         else:
88 |             font = TTFont(str(self.path))
89 |             glyphset = font.getGlyphSet()
90 |             table = font.getBestCmap()
91 | 
92 |             for r in TextGenerator.ranges:
93 |                 #ugly check to know if char is supported by font
94 |                 self.chars += [chr(x) for x in range(r[0], r[1] + 1) if x in table.keys() and (glyphset[table[x]]._glyph.bytecode != b' \x1d' if hasattr(glyphset[table[x]]._glyph, 'bytecode') else glyphset[table[x]]._glyph.numberOfContours > 0)]
95 | 
96 |             cache = {'ranges': str(TextGenerator.ranges), 'chars': self.chars}
97 | 
98 |             with open(filename, 'w') as f:
99 |                 json.dump(cache, f)
100 | 
101 |         Fonts.total += len(self.chars)
102 | 
103 | 
104 |     def generateText(self, length):
105 |         return "".join(random.choice(self.chars, length))
106 | 
107 | class Fonts:
108 |     total = 0
109 |     @staticmethod
110 |     def load(fontFolder):
111 |         fonts = []
112 |         os.makedirs('cache/fonts', exist_ok=True)
113 |         for extension in ['ttf', 'otf']:
114 |             fonts += [Font(p) for p in fontFolder.glob("**/*." + extension)]
115 | 
116 |         return fonts
117 | 
118 |     def __init__(self, fonts):
119 |         self.fonts = fonts
120 |         self.updateWeights()
121 | 
122 |     def randomFont(self):
123 |         if len(self.weights) != len(self.fonts):
124 |             self.updateWeights()
125 | 
126 |         return random.choice(self.fonts, p=self.weights) # let fonts with more chars be more likely
127 | 
128 |     def updateWeights(self):
129 |         self.weights = []
130 |         self.total = 0
131 | 
132 |         for font in self.fonts:
133 |             self.total += len(font.chars)
134 | 
135 |         for font in self.fonts:
136 |             self.weights.append(len(font.chars) / self.total)
137 | 
138 | 
139 | 

--------------------------------------------------------------------------------
/code/config.py:
--------------------------------------------------------------------------------
1 | EXPERIMENTS_PATH = "/data/anime2/experiments"
2 | MANGA109_PATH = "/data/anime"
3 | MASKS_PATH = "/data/anime2/manga"
4 | YU45020_PATH = '/data/anime2/yu45020'
5 | SICKZIL_PATH = '/data/anime2/sickzil'
6 | DANBOORU_PATH = '/data/anime2/danbooru2019/512px'
7 | CRAFT_PATH = '/data/anime2/craft/predictions'
8 | ICDAR2013_WEB_PATH = '/data/anime2/icdar2013-web'
9 | ICDAR2013_SCENE_PATH = '/data/anime2/icdar2013-scene'
10 | TOTAL_TEXT_PATH = '/data/anime2/total-text'
11 | DIBCO_PATH = '/data/anime2/dibco'
12 | KAIST_PATH = '/data/anime2/kaist'
13 | BDUNET_PATH = '/data/anime2/BDUnet_DIBCO_predictions/'

--------------------------------------------------------------------------------
/code/danbooru.py:
--------------------------------------------------------------------------------
1 | from PIL import Image as pilImage
2 | from fastai.vision import Image, ImageImageList, to_float, SegmentationItemList
3 | from fastai.basics import ItemBase
4 | from typing import *
5 | import torch
6 | import matplotlib.pyplot as plt
7 | from fastai.core import subplots
8 | 
9 | class DanbooruImage(ItemBase):
10 |     def __init__(self, image, fileDir, idx):
11 |         self.data = idx
12 |         self.image = image
13 |         self.fileDir = fileDir
14 | 
15 |     def __str__(self): return str(self.image)
16 | 
17 |     def apply_tfms(self, tfms, **kwargs):
18 |         for tfm in tfms:
19 |             tfm(self, **kwargs)
20 | 
21 |         return self
22 | 
23 | 
24 | class DanbooruImageList(ImageImageList):
25 |     @classmethod
26 |     def from_textInfo(cls, textInfo: dict, maxItems=10, maxArea=0, **kwargs):
27 |         gen = filter(lambda k: textInfo[k] <= maxArea, textInfo.keys())
28 |         items = [x for _, x in zip(range(maxItems), gen)]
29 |         return DanbooruImageList(items, **kwargs)
30 | 
31 |     @classmethod
32 |     def with_text(cls, top, textInfo: dict, **kwargs):
33 |         return DanbooruImageList(sorted(textInfo.keys(), key=lambda k: textInfo[k])[0:top], **kwargs)
34 | 
35 |     def get(self, i):
36 |         fileDir = self.items[i]
37 |         image = self.open(fileDir.replace("/danbooru/", "/danbooru_corrupt/"))
38 |         return DanbooruImage(image, fileDir, i)
39 | 
40 |     def open(self, fileDir):
41 |         return pilImage.open(fileDir).convert('RGB')
42 | 
43 |     def show_xyzs(self, xs, ys, zs, logger=False, **kwargs):
44 |         if logger:
45 |             logger.show_xyzs(xs, ys, zs, **kwargs)
46 |         else:
47 |             return super().show_xyzs(xs, ys, zs, **kwargs)
48 | 
49 | class DanbooruSegmentationList(SegmentationItemList):
50 |     @classmethod
51 |     def from_textInfo(cls, textInfo: dict, maxItems=10, maxArea=0, **kwargs):
52 |         gen = filter(lambda k: textInfo[k] <= maxArea, textInfo.keys())
53 |         items = [x for _, x in zip(range(maxItems), gen)]
54 |         return DanbooruSegmentationList(items, **kwargs)
55 | 
56 |     def get(self, i):
57 |         fileDir = self.items[i]
58 |         image = self.open(fileDir.replace("/danbooru/", "/danbooru_corrupt/"))
59 |         return DanbooruImage(image, fileDir, i)
60 | 
61 |     def open(self, fileDir):
62 |         return pilImage.open(fileDir).convert('RGB')

--------------------------------------------------------------------------------
/code/dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image as Im
2 | from skimage.morphology import skeletonize, binary_erosion, binary_dilation, remove_small_objects
3 | from fastai.vision import *
4 | from fastai.vision.data import SegmentationProcessor
5 | from config import *
6 | from skimage.measure import label, regionprops
7 | from skimage.draw import polygon
8 | 
9 | def get_segmentation(x):
10 |     folder = x.parent.name
11 |     return MASKS_PATH + '/' + folder + '/' + x.name.replace('.jpg', '.png')
12 | 
13 | def divround_up(value, step):
14 |     return (value+step-1)//step*step
15 | 
16 | def pad_tensor(t, multiple = 8):
17 |     padded = torch.zeros(t.shape[0], divround_up(t.shape[1], multiple), divround_up(t.shape[2], multiple))
18 |     padded[:, 0:t.shape[1], 0:t.shape[2]] = t
19 |     return padded
20 | 
21 | def unpad_tensor(t, shape):
22 |     return t[:, 0:shape[1], 0:shape[2]]
23 | 
24 | def _random_crop(px, size, randx:uniform=0.5, randy:uniform=0.5):
25 |     if isinstance(size, int):
26 |         size = (size, size)
27 |     y, x = int((px.shape[1] - size[1]) * randx), int((px.shape[2] - size[0]) * randy)
28 |     return px[:, y:y+size[1], x:x+size[0]]
29 | 
30 | def cut_tensor(t, left=True):
31 |     return t[:, :, 0:t.size(2)//2].contiguous() if left else t[:, :, t.size(2)//2:].contiguous()
32 | 
33 | random_crop = TfmPixel(_random_crop)
34 | 
35 | mapping = dict()
36 | mapping[(1, 1, 1)] = 1 #black (easy)
37 | mapping[(255, 1, 255)] = 2 #pink (hard)
38 | mapping[(255, 255, 255)] = 0 #white (background)
39 | 
40 | 
41 | #replaces part of the ground truth with new "ignore" categories to account for labeling errors
42 | def addIgnore(mask, ignore = True):
43 |     if ignore:
44 |         skeleton = torch.zeros(mask.shape).bool()
45 | 
46 |         #skeletonize each label value separately; otherwise 2 nearby pink/black components would merge into one skeleton component
47 |         for val in mask[0].unique():
48 |             if val == 0: continue
49 | 
50 |             sk = tensor(skeletonize((mask[0] == val).numpy())) #get skeleton of text
51 |             skeleton |= sk.bool()
52 | 
53 |         newMask = (mask[0] != 0).numpy() #separate into text/non-text
54 |         eroded = newMask.copy()
55 |         dilated = newMask.copy()
56 | 
57 |         for x in range(0, 3):
58 |             eroded = binary_erosion(eroded)
59 |             dilated = binary_dilation(dilated)
60 | 
61 |         eroded = tensor(np.expand_dims(eroded, 0))
62 |         dilated = tensor(np.expand_dims(dilated, 0))
63 | 
64 |         mask[((dilated != 0) * ((eroded == 0) & (skeleton == 0)))] += len(mapping)
65 | 
66 |     return mask
67 | 
68 | class CustomImageSegment(ImageSegment):
69 |     @property
70 |     def eroded(self):
71 |         px = self.px.clone()
72 |         px[px >= 3] = 0
73 |         return type(self)(px)
74 | 
75 |     @property
76 |     def original(self):
77 |         px = self.px.clone()
78 |         px[px >= 3] -= 3
79 |         return type(self)(px)
80 | 
81 |     @property
82 |     def boxed(self):
83 |         mask = self.px[0]
84 |         labels = label(mask != 0, connectivity = 2)
85 |         for region in regionprops(labels):
86 |             sub = mask[region.slice[0], region.slice[1]]
87 |             sub[sub == 0] = 3
88 |         return type(self)(mask.unsqueeze(0).float())
89 | 
90 |     @property
91 |     def rgb(self):
92 |         rgb = torch.zeros(3, self.shape[1], self.shape[2])
93 |         px = self.px[0]
94 |         r, g, b = rgb
95 |         blacks = (px == 1)
96 |         pinks = (px == 2)
97 |         dilated = px == 3
98 |         eroded = (px == 4) | (px == 5)
99 |         bg = (px == 0)
100 | 
101 |         r[bg] = g[bg] = b[bg] = 255
102 |         r[blacks] = g[blacks] = b[blacks] = 1
103 |         g[pinks] = 1
104 |         r[pinks] = b[pinks] = 255
105 |         g[dilated] = 255
106 |         b[eroded] = 255
107 | 
108 |         return Image(rgb.div(255))
109 | 
110 | #caches the image processing that happens on open
111 | class CacheProcessor(SegmentationProcessor):
112 |     def process(self, ds):
113 |         super().process(ds)
114 |         for idx, item in enumerate(ds.items):
115 |             path = Path(str(item).replace(".png", ".cache"))
116 |             if not path.exists():
117 |                 im = Im.fromarray(image2np(ds.open(item, False).px).astype(np.uint8))
118 |                 im.save(path, format='PNG')
119 | 
120 | #used to double the dataloader size by cutting each image in half
121 | class CutInHalf():
122 |     cutInHalf = True
123 |     padding = 8
124 | 
125 |     def get(self, i, cutInHalf = None):
126 |         if cutInHalf is None:
127 |             cutInHalf = self.cutInHalf
128 |         x = super().get(i % len(self.items)).px
129 | 
130 |         if cutInHalf:
131 |             x = cut_tensor(x, i < len(self.items))
132 | 
133 |         return self.reconstruct(x)
134 | 
135 |     def __len__(self):
136 |         return max(1, len(self.items) * (2 if self.cutInHalf else 1))
137 | 
138 | class SegLabelListCustom(CutInHalf, SegmentationLabelList):
139 |     ignore = True
140 |     areaThreshold = 3
141 |     _processor = CacheProcessor
142 |     useCached = True
143 | 
144 |     def reconstruct(self, t):
145 |         return CustomImageSegment(pad_tensor(t, self.padding))
146 | 
147 |     def analyze_pred(self, pred, thresh:float=0.5):
148 |         return torch.sigmoid(pred) > thresh
149 | 
150 |     def open(self, fn, useCached=None):
151 |         if useCached is None:
152 |             useCached = self.useCached
153 | 
154 |         if useCached:
155 |             fn = type(fn)(str(fn).replace(".png", ".cache"))
156 | 
157 |         im = Im.open(fn)
158 |         tensor = pil2tensor(im, np.uint8)
159 | 
160 |         if ".cache" in str(fn):
161 |             return self.reconstruct(tensor)
162 | 
163 |         palette = im.getpalette()
164 | 
165 |         assert(tensor.ge(len(mapping)).sum().item() == 0) #assert we only have 3 colors
166 | 
167 |         newTensor = tensor.clone()
168 | 
169 |         for x in range(tensor.max().item() + 1): #replace the palette index (black may be 0 in one image and 1 in another) with the standard value defined in mapping
170 |             color = (palette[x * 3], palette[x * 3 + 1], palette[x * 3 + 2])
171 |             assert color in mapping, str(color) + " not in mapping " + str(fn) #assert the label only uses the expected colors (black, pink, white)
172 |             newTensor[tensor == x] = mapping[color]
173 | 
174 |         return self.reconstruct(addIgnore(self.threshold(newTensor), self.ignore))
175 | 
176 |     #removes small components and fills in small holes
177 |     def threshold(self, mask):
178 |         if self.areaThreshold is not None:
179 |             #fill in small holes
180 |             labels = label(mask[0].cpu() == 0, connectivity=2)
181 |             for region in regionprops(labels):
182 |                 if region.area <= self.areaThreshold:
183 |                     values = mask[0, region.slice[0], region.slice[1]]
184 |                     val = values.max()
185 |                     if val == 0:
186 |                         val = mask[0, tensor(binary_dilation(labels == region.label))].max()
187 |                     values[values == 0] = val
188 | 
189 |             #remove small components
190 |             labeles = label(mask[0].cpu(), connectivity=2)
191 |             remove_small_objects(labeles, self.areaThreshold + 1, in_place=True)
192 |             mask[0][labeles == 0] = 0
193 | 
194 |             #make sure it all worked
195 |             labels = label(mask[0].cpu(), connectivity=2, background = 2)
196 |             for region in regionprops(labels):
197 |                 if region.area <= self.areaThreshold:
198 |                     assert False, "didn't work"
199 | 
200 | 
201 |         return mask
202 | 
203 | 
204 | 
205 |     def loadPrediction(self, path):
206 |         mask = open_mask(path)
207 |         mask.px = (mask.px != 0).float()
208 |         mask.px = pad_tensor(mask.px, self.padding)
209 |         return mask
210 | 
211 |     def sickzilImage(self, idx):
212 |         p = Path(self.items[idx])
213 |         return self.loadPrediction(Path(SICKZIL_PATH) / (p.parent.name + "_mproj") / "masks" / p.name)
214 | 
215 |     def yu45020Image(self, idx):
216 |         p = Path(self.items[idx])
217 |         return self.loadPrediction(Path(YU45020_PATH) / (p.parent.name) / p.name.replace(p.suffix, ".jpg"))
218 | 
219 |     def xceptionImage(self, idx):
220 |         p = Path(self.items[idx])
221 |         pred = torch.load(Path(EXPERIMENTS_PATH) / 'model' / 'xception' / 'predictions' / p.parent.name / p.name.replace(p.suffix, ".pt"))
222 |         mask = torch.sigmoid(pred) > 0.5
223 |         mask = pad_tensor(mask, self.padding)
224 |         return ImageSegment(mask)
225 | 
226 |     #build a mask from CRAFT's polygon predictions
227 |     def craftImage(self, idx):
228 |         mask = torch.zeros(self.get(idx, False).px.shape)[0].transpose(1, 0)
229 |         path = Path(self.items[idx])
230 |         folder, name = path.parent.name, path.stem
231 |         with open(Path(CRAFT_PATH) / folder / ('res_' + name + '.txt'), 'r') as f:
232 |             polygons = list(map(lambda x: x.split(','), filter(len, f.read().split('\n'))))
233 |         for poly in polygons:
234 |             rr, cc = polygon(np.array(poly[::2]).astype(np.uint32), np.array(poly[1:][::2]).astype(np.uint32), shape = mask.shape)
235 |             mask[rr, cc] = 1
236 |         return self.reconstruct(tensor(mask.transpose(1, 0)).unsqueeze(0))
237 | 
238 |     def cachedPrediction(self, datasetIdx, imageIdx, folder = Path(EXPERIMENTS_PATH) / 'model' / 'resnet34'):
239 |         TENSOR_PATH = folder / str(datasetIdx) / 'predictions' / (str(imageIdx) + '.pt')
240 |         return torch.load(TENSOR_PATH)
241 | 
242 | 
243 | class SegItemListCustom(CutInHalf, SegmentationItemList):
244 |     _label_cls = SegLabelListCustom
245 | 
246 |     def __init__(self, items, *args, **kwargs):
247 |         items = sorted(items)
248 |         super().__init__(items, *args, **kwargs)
249 | 
250 |     def reconstruct(self, t):
251 |         return Image(pad_tensor(t, self.padding))
252 | 

--------------------------------------------------------------------------------
/code/experiments/BDUNET.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import glob\n",
10 |     "from pathlib import Path\n",
11 |     "from experiments import *\n",
12 |     "from metrics import *\n",
13 |     "from config import *\n",
14 |     "from fastai.vision import *\n",
15 |     "from dataset import pad_tensor, addIgnore\n",
16 |     "\n",
17 |     "%load_ext autoreload\n",
18 |     "%autoreload 2"
19 |    ]
20 |   },
21 |   {
22 |    "cell_type": "code",
23 |    "execution_count": 32,
24 |    "metadata": {},
25 |    "outputs": [],
26 |    "source": [
27 |     "random_seed(42)\n",
28 |     "CSV_PATH = EXPERIMENTS_PATH + '/BDUNET/predictions.csv'\n",
29 |     "\n",
30 |     "folds = list(KFold(n_splits = 5, shuffle = True, random_state=42).split(trainFolders, trainFolders))\n",
31 |     "\n",
32 |     "for idx, (_, valid_indexes) in enumerate(folds):\n",
33 |     "    if idx not in [1, 2]:\n",
34 |     "        m = MetricsCallback(None)\n",
35 |     "        m.on_train_begin()\n",
36 |     "        images = []\n",
37 |     "        for f in glob.glob(MASKS_PATH + \"/**/*.cache\"):\n",
38 |     "            index = trainFolders.index(Path(f).parent.name)\n",
39 |     "            if index in valid_indexes:\n",
40 |     "                pred = open_mask(BDUNET_PATH + '/' + Path(f).parent.name + '/' + Path(f).stem + '.png')\n",
41 |     "                pred.px = pred.px != 0\n",
42 |     "                seg = open_mask(f)\n",
43 |     "                m.on_batch_end(False, pad_tensor(pred.px, 8), seg.px[:, 0:pred.px.shape[1], 0:pred.px.shape[2]])\n",
44 |     "        m.calculateMetrics()\n",
45 |     "        m.save(CSV_PATH, idx > 0)"
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": null,
51 |    "metadata": {},
52 |    "outputs": [],
53 |    "source": []
54 |   }
55 |  ],
56 |  "metadata": {
57 |   "kernelspec": {
58 |    "display_name": "Python 3",
59 |    "language": "python",
60 |    "name": "python3"
61 |   },
62 |   "language_info": {
63 |    "codemirror_mode": {
64 |     "name": "ipython",
65 |     "version": 3
66 |    },
67 |    "file_extension": ".py",
68 |    "mimetype": "text/x-python",
69 |    "name": "python",
70 |    "nbconvert_exporter": "python",
71 |    "pygments_lexer": "ipython3",
72 |    "version": "3.7.3"
73 |   }
74 |  },
75 |  "nbformat": 4,
76 |  "nbformat_minor": 4
77 | }
78 | 

--------------------------------------------------------------------------------
/code/experiments/BDUNET_Train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "trainFolders = [\n",
10 |     "    'ARMS',\n",
11 |     "    'AisazuNihaIrarenai',\n",
12 |     "    'AkkeraKanjinchou',\n",
13 |     "    'Akuhamu',\n",
14 |     "    'AosugiruHaru',\n",
15 |     "    'AppareKappore',\n",
16 |     "    'Arisa',\n",
17 |     "    'BEMADER_P',\n",
18 |     "    'BakuretsuKungFuGirl',\n",
19 |     "    'Belmondo',\n",
20 |     "    'BokuHaSitatakaKun',\n",
21 |     "    'BurariTessenTorimonocho',\n",
22 |     "    'ByebyeC-BOY',\n",
23 |     "    'Count3DeKimeteAgeru',\n",
24 |     "    'DollGun',\n",
25 |     "    'Donburakokko',\n",
26 |     "    'DualJustice',\n",
27 |     "    'EienNoWith',\n",
28 |     "    'EvaLady',\n",
29 |     "    'EverydayOsakanaChan',\n",
30 |     "    'GOOD_KISS_Ver2',\n",
31 |     "    'GakuenNoise',\n",
32 |     "    'GarakutayaManta',\n",
33 |     "    'GinNoChimera',\n",
34 |     "    'Hamlet',\n",
35 |     "    'HanzaiKousyouninMinegishiEitarou',\n",
36 |     "    'HaruichibanNoFukukoro',\n",
37 |     "    'HarukaRefrain',\n",
38 |     "    'HealingPlanet',\n",
39 |     "    \"UchiNoNyan'sDiary\",\n",
40 |     "    'UchuKigekiM774',\n",
41 |     "    'UltraEleven',\n",
42 |     "    'UnbalanceTokyo',\n",
43 |     "    'WarewareHaOniDearu',\n",
44 |     "    'YamatoNoHane',\n",
45 |     "    'YasasiiAkuma',\n",
46 |     "    'YouchienBoueigumi',\n",
47 |     "    'YoumaKourin',\n",
48 |     "    'YukiNoFuruMachi',\n",
49 |     "    'YumeNoKayoiji',\n",
50 |     "    'YumeiroCooking',\n",
51 |     "    'TotteokiNoABC',\n",
52 |     "    'ToutaMairimasu',\n",
53 |     "    'TouyouKidan',\n",
54 |     "    'TsubasaNoKioku'\n",
55 |     "]\n"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": 2,
61 |    "metadata": {},
62 |    "outputs": [],
63 |    "source": [
64 |     "import numpy as np\n",
65 |     "import utils as U\n",
66 |     "import glob\n",
67 |     "from PIL import Image\n",
68 |     "from pathlib import Path\n",
69 |     "from sklearn.model_selection import KFold\n",
70 |     "\n",
71 |     "import os\n",
72 |     "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
73 |     "import models as M\n",
74 |     "import numpy as np\n",
75 |     "from keras.callbacks import ModelCheckpoint,ReduceLROnPlateau\n",
76 |     "from keras import callbacks\n",
77 |     "import keras\n",
78 |     "import pickle"
79 |    ]
80 |   },
81 |   {
82 |    "cell_type": "code",
83 |    "execution_count": 3,
84 |    "metadata": {},
85 |    "outputs": [],
86 |    "source": [
87 |     "class DataGenerator(keras.utils.Sequence):\n",
88 |     "    def __init__(self, fold, patches = 40, train = False, dims=(128,128), shuffle=True):\n",
89 |     "        self.shuffle = shuffle\n",
90 |     "        self.patches = patches\n",
91 |     "        self.train = train\n",
92 |     "        self.dims = dims\n",
93 |     "        self.fold = list(KFold(n_splits = 5, shuffle = True, random_state=42).split(trainFolders, trainFolders))[fold]\n",
94 |     "        self.init_images()\n",
95 |     "        self.on_epoch_end()\n",
96 |     "\n",
97 |     "    def init_images(self):\n",
98 |     "        self.images = []\n",
99 |     "        for f in glob.glob(\"/data/anime/masks/**/*.cache\"):\n",
100 |     "            index = trainFolders.index(Path(f).parent.name)\n",
101 |     "            if index in self.fold[0] and self.train:\n",
102 |     "                self.images.append(f)\n",
103 |     "            elif index in self.fold[1] and not self.train:\n",
104 |     "                self.images.append(f)\n",
105 |     "        #self.images = self.images[0:3]\n",
106 |     "    def __len__(self):\n",
107 |     "        'Denotes the number of batches per epoch'\n",
108 |     "        return int(np.floor(len(self.images)))\n",
109 |     "\n",
110 |     "    def __getitem__(self, index):\n",
111 |     "        'Generate one batch of data'\n",
112 |     "        # Generate indexes of the batch\n",
113 |     "\n",
114 |     "        # Generate data (use the shuffled order from on_epoch_end)\n",
115 |     "        X, y = self.__data_generation([self.images[self.indexes[index]]])\n",
116 |     "\n",
117 |     "        return X, y\n",
118 |     "\n",
119 |     "    def on_epoch_end(self):\n",
120 |     "        'Updates indexes after each epoch'\n",
121 |     "        self.indexes = np.arange(len(self.images))\n",
122 |     "        if self.shuffle == True:\n",
123 |     "            np.random.shuffle(self.indexes)\n",
124 |     "\n",
125 |     "    def __data_generation(self, files):\n",
126 |     "        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)\n",
127 |     "        images = []\n",
128 |     "        masks = []\n",
129 |     "\n",
130 |     "        for f in files:\n",
131 |     "            mask = np.array(Image.open(f))\n",
132 |     "            mask[mask >= 3] -= 3\n",
133 |     "            mask = (mask != 0) * 255\n",
134 |     "            img = np.asarray(Image.open('/data/anime/manga/Manga109_2017_09_28/images/' + Path(f).parent.name + '/' + Path(f).stem + '.jpg').convert('L'))\n",
135 |     "            images.append(img)\n",
136 |     "            masks.append(mask)\n",
137 |     "\n",
138 |     "        patches_image, patches_masks = U.extract_random(images, masks, *self.dims, self.patches)\n",
139 |     "\n",
140 |     "        patches_image /= 255.\n",
141 |     "        patches_masks /= 255.\n",
142 |     "\n",
143 |     "        patches_image = np.expand_dims(patches_image, axis = 3)\n",
144 |     "        patches_masks = np.expand_dims(patches_masks, axis = 3)\n",
145 |     "        return patches_image, patches_masks\n",
146 |     "    "
147 |    ]
148 |   },
149 |   {
150 |    "cell_type": "code",
151 |    "execution_count": null,
152 |    "metadata": {},
153 |    "outputs": [],
154 |    "source": [
155 |     "def train(index):\n",
156 |     "    train = DataGenerator(index, train=True)\n",
157 |     "    test = DataGenerator(index, train=False)\n",
158 |     "    print('Dataset Prepared')\n",
159 |     "\n",
160 |     "    model = M.BCDU_net_D3(input_size = (128, 128, 1))\n",
161 |     "    model.summary()\n",
162 |     "\n",
163 |     "    print('Training')\n",
164 |     "\n",
165 |     "    nb_epoch = 20\n",
166 |     "\n",
167 |     "    mcp_save = ModelCheckpoint('weight_text'+str(index)+'.hdf5', save_best_only=True, monitor='val_loss', mode='min')\n",
168 |     "    reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, epsilon=1e-4, mode='min')\n",
169 |     "\n",
170 |     "    history = model.fit(x=train,\n",
171 |     "              epochs=nb_epoch,\n",
172 |     "              verbose=1,\n",
173 |     "              validation_data=test, callbacks=[mcp_save, reduce_lr_loss])\n",
174 |     "\n",
175 |     "    print('Trained model saved')\n",
176 |     "    with open('fold' + str(index), 'wb') as file_pi:\n",
177 |     "        pickle.dump(history.history, file_pi)"
178 |    ]
179 |   },
180 |   {
181 |    "cell_type": "code",
182 |    "execution_count": null,
183 |    "metadata": {},
184 |    "outputs": [
185 |     {
186 |      "name": "stdout",
187 |      "output_type": "stream",
188 |      "text": [
189 |       "Dataset Prepared\n"
190 |      ]
191 |     },
192 |     {
193 |      "name": "stderr",
194 |      "output_type": "stream",
195 |      "text": [
196 |       "WARNING: Logging before flag parsing goes to stderr.\n",
197 |       "W0904 03:00:48.262387 139725218301696 callbacks.py:2323] `epsilon` argument is deprecated and will be removed, use `min_delta` instead.\n"
198 |      ]
199 |     },
200 |     {
201 |      "name": "stdout",
202 |      "output_type": "stream",
203 |      "text": [
204 |       "Model: \"functional_1\"\n",
205 |       "__________________________________________________________________________________________________\n",
206 |       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
207 |       "==================================================================================================\n",
208 |       "input_1 (InputLayer)            [(None, 128, 128, 1) 0                                            \n",
209 |       "__________________________________________________________________________________________________\n",
210 |       "conv2d (Conv2D)                 (None, 128, 128, 64) 640         input_1[0][0]                    \n",
211 |       "__________________________________________________________________________________________________\n",
212 |       "conv2d_1 (Conv2D)               (None, 128, 128, 64) 36928       conv2d[0][0]                     \n",
213 |       "__________________________________________________________________________________________________\n",
214 |       "max_pooling2d (MaxPooling2D)    (None, 64, 64, 64)   0           conv2d_1[0][0]                   \n",
215 |       "__________________________________________________________________________________________________\n",
216 |       "conv2d_2 (Conv2D)               (None, 64, 64, 128)  73856       max_pooling2d[0][0]              \n",
217 |       "__________________________________________________________________________________________________\n",
218 |       "conv2d_3 (Conv2D)               (None, 64, 64, 128)  147584      conv2d_2[0][0]                   \n",
219 |       "__________________________________________________________________________________________________\n",
220 |       "max_pooling2d_1 (MaxPooling2D)  (None, 32, 32, 128)  0           conv2d_3[0][0]                   \n",
221 |       "__________________________________________________________________________________________________\n",
222 |       "conv2d_4 (Conv2D)               (None, 32, 32, 256)  295168      max_pooling2d_1[0][0]            \n",
223 |       "__________________________________________________________________________________________________\n",
224 |       "conv2d_5 (Conv2D)               (None, 32, 32, 256)  590080      conv2d_4[0][0]                   \n",
225 |       "__________________________________________________________________________________________________\n",
226 |       "max_pooling2d_2 (MaxPooling2D)  (None, 16, 16, 256)  0           conv2d_5[0][0]                   \n",
227 |       "__________________________________________________________________________________________________\n",
228 |       "conv2d_6 (Conv2D)               (None, 16, 16, 512)  1180160     max_pooling2d_2[0][0]            \n",
229 |       "__________________________________________________________________________________________________\n",
230 |       "conv2d_7 (Conv2D)               (None, 16, 16, 512)  2359808     conv2d_6[0][0]                   \n",
231 |       "__________________________________________________________________________________________________\n",
232 |       "dropout_1 (Dropout)             (None, 16, 16, 512)  0           conv2d_7[0][0]                   \n",
233 |       "__________________________________________________________________________________________________\n",
234 |       "conv2d_8 (Conv2D)               (None, 16, 16, 512)  2359808     dropout_1[0][0]                  \n",
235 |       "__________________________________________________________________________________________________\n",
236 |       "conv2d_9 (Conv2D)               (None, 16, 16, 512)  2359808     conv2d_8[0][0]                   \n",
237 |       "__________________________________________________________________________________________________\n",
238 |       "dropout_2 (Dropout)             (None, 16, 16, 512)  0           conv2d_9[0][0]                   \n",
239 |       "__________________________________________________________________________________________________\n",
240 |       "concatenate (Concatenate)       (None, 16, 16, 1024) 0           dropout_2[0][0]                  \n",
241 |       "                                                                 dropout_1[0][0]                  \n",
242 |       "__________________________________________________________________________________________________\n",
243 |       "conv2d_10 (Conv2D)              (None, 16, 16, 512)  4719104     concatenate[0][0]                \n",
244 |       "__________________________________________________________________________________________________\n",
245 |       "conv2d_11
(Conv2D) (None, 16, 16, 512) 2359808 conv2d_10[0][0] \n", 246 | "__________________________________________________________________________________________________\n", 247 | "dropout_3 (Dropout) (None, 16, 16, 512) 0 conv2d_11[0][0] \n", 248 | "__________________________________________________________________________________________________\n", 249 | "conv2d_transpose (Conv2DTranspo (None, 32, 32, 256) 524544 dropout_3[0][0] \n", 250 | "__________________________________________________________________________________________________\n", 251 | "batch_normalization (BatchNorma (None, 32, 32, 256) 1024 conv2d_transpose[0][0] \n", 252 | "__________________________________________________________________________________________________\n", 253 | "dropout (Dropout) (None, 32, 32, 256) 0 conv2d_5[0][0] \n", 254 | "__________________________________________________________________________________________________\n", 255 | "activation (Activation) (None, 32, 32, 256) 0 batch_normalization[0][0] \n", 256 | "__________________________________________________________________________________________________\n", 257 | "reshape (Reshape) (None, 1, 32, 32, 25 0 dropout[0][0] \n", 258 | "__________________________________________________________________________________________________\n", 259 | "reshape_1 (Reshape) (None, 1, 32, 32, 25 0 activation[0][0] \n", 260 | "__________________________________________________________________________________________________\n", 261 | "concatenate_1 (Concatenate) (None, 2, 32, 32, 25 0 reshape[0][0] \n", 262 | " reshape_1[0][0] \n", 263 | "__________________________________________________________________________________________________\n", 264 | "conv_lst_m2d (ConvLSTM2D) (None, 32, 32, 128) 1769984 concatenate_1[0][0] \n", 265 | "__________________________________________________________________________________________________\n", 266 | "conv2d_12 (Conv2D) (None, 32, 32, 256) 295168 conv_lst_m2d[0][0] \n", 267 | 
"__________________________________________________________________________________________________\n", 268 | "conv2d_13 (Conv2D) (None, 32, 32, 256) 590080 conv2d_12[0][0] \n", 269 | "__________________________________________________________________________________________________\n", 270 | "conv2d_transpose_1 (Conv2DTrans (None, 64, 64, 128) 131200 conv2d_13[0][0] \n", 271 | "__________________________________________________________________________________________________\n", 272 | "batch_normalization_1 (BatchNor (None, 64, 64, 128) 512 conv2d_transpose_1[0][0] \n", 273 | "__________________________________________________________________________________________________\n", 274 | "activation_1 (Activation) (None, 64, 64, 128) 0 batch_normalization_1[0][0] \n", 275 | "__________________________________________________________________________________________________\n", 276 | "reshape_2 (Reshape) (None, 1, 64, 64, 12 0 conv2d_3[0][0] \n", 277 | "__________________________________________________________________________________________________\n", 278 | "reshape_3 (Reshape) (None, 1, 64, 64, 12 0 activation_1[0][0] \n", 279 | "__________________________________________________________________________________________________\n", 280 | "concatenate_2 (Concatenate) (None, 2, 64, 64, 12 0 reshape_2[0][0] \n", 281 | " reshape_3[0][0] \n", 282 | "__________________________________________________________________________________________________\n", 283 | "conv_lst_m2d_1 (ConvLSTM2D) (None, 64, 64, 64) 442624 concatenate_2[0][0] \n", 284 | "__________________________________________________________________________________________________\n", 285 | "conv2d_14 (Conv2D) (None, 64, 64, 128) 73856 conv_lst_m2d_1[0][0] \n", 286 | "__________________________________________________________________________________________________\n", 287 | "conv2d_15 (Conv2D) (None, 64, 64, 128) 147584 conv2d_14[0][0] \n", 288 | 
"__________________________________________________________________________________________________\n", 289 | "conv2d_transpose_2 (Conv2DTrans (None, 128, 128, 64) 32832 conv2d_15[0][0] \n", 290 | "__________________________________________________________________________________________________\n", 291 | "batch_normalization_2 (BatchNor (None, 128, 128, 64) 256 conv2d_transpose_2[0][0] \n", 292 | "__________________________________________________________________________________________________\n", 293 | "activation_2 (Activation) (None, 128, 128, 64) 0 batch_normalization_2[0][0] \n", 294 | "__________________________________________________________________________________________________\n", 295 | "reshape_4 (Reshape) (None, 1, 128, 128, 0 conv2d_1[0][0] \n", 296 | "__________________________________________________________________________________________________\n", 297 | "reshape_5 (Reshape) (None, 1, 128, 128, 0 activation_2[0][0] \n", 298 | "__________________________________________________________________________________________________\n", 299 | "concatenate_3 (Concatenate) (None, 2, 128, 128, 0 reshape_4[0][0] \n", 300 | " reshape_5[0][0] \n", 301 | "__________________________________________________________________________________________________\n", 302 | "conv_lst_m2d_2 (ConvLSTM2D) (None, 128, 128, 32) 110720 concatenate_3[0][0] \n", 303 | "__________________________________________________________________________________________________\n", 304 | "conv2d_16 (Conv2D) (None, 128, 128, 64) 18496 conv_lst_m2d_2[0][0] \n", 305 | "__________________________________________________________________________________________________\n", 306 | "conv2d_17 (Conv2D) (None, 128, 128, 64) 36928 conv2d_16[0][0] \n", 307 | "__________________________________________________________________________________________________\n", 308 | "conv2d_18 (Conv2D) (None, 128, 128, 2) 1154 conv2d_17[0][0] \n", 309 | 
"__________________________________________________________________________________________________\n", 310 | "conv2d_19 (Conv2D) (None, 128, 128, 1) 3 conv2d_18[0][0] \n", 311 | "==================================================================================================\n", 312 | "Total params: 20,659,717\n", 313 | "Trainable params: 20,658,821\n", 314 | "Non-trainable params: 896\n", 315 | "__________________________________________________________________________________________________\n", 316 | "Training\n" 317 | ] 318 | }, 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "Epoch 1/20\n", 324 | "360/360 [==============================] - 292s 811ms/step - loss: 0.1528 - accuracy: 0.9642 - val_loss: 0.1258 - val_accuracy: 0.9681\n", 325 | "Epoch 2/20\n", 326 | "360/360 [==============================] - 289s 803ms/step - loss: 0.1127 - accuracy: 0.9664 - val_loss: 0.1294 - val_accuracy: 0.9683\n", 327 | "Epoch 3/20\n", 328 | "360/360 [==============================] - 290s 807ms/step - loss: 0.1005 - accuracy: 0.9699 - val_loss: 0.1030 - val_accuracy: 0.9703\n", 329 | "Epoch 4/20\n", 330 | "360/360 [==============================] - 289s 802ms/step - loss: 0.0884 - accuracy: 0.9733 - val_loss: 0.1331 - val_accuracy: 0.9682\n", 331 | "Epoch 5/20\n", 332 | "278/360 [======================>.......] 
- ETA: 1:00 - loss: 0.0852 - accuracy: 0.9747" 333 | ] 334 | } 335 | ], 336 | "source": [ 337 | "train(0)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "train(1)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "train(2)" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "train(3)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "train(4)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "for fold in range(0, 5):\n", 383 | " model = M.BCDU_net_D3(input_size = (128, 128,1))\n", 384 | " model.load_weights('weight_text'+str(fold)+'.hdf5')\n", 385 | " test = DataGenerator(fold, train=False)\n", 386 | "\n", 387 | " for f in test.images:\n", 388 | " img = np.asarray(Image.open('/data/anime/manga/Manga109_2017_09_28/images/' + Path(f).parent.name + '/' + Path(f).stem + '.jpg').convert('L'))\n", 389 | " patches , new_h, new_w = U.extract_ordered_overlap(img, 128, 128, 64, 64)\n", 390 | " patches = np.expand_dims(patches, axis = 3)\n", 391 | " predictions = model.predict(patches, batch_size= 40, verbose=1)\n", 392 | " estimated = U.recompone_overlap(predictions[:,:,:,0], new_h, new_w, 64, 64)\n", 393 | " estimated = np.where(estimated >= 0.7, 1, 0)\n", 394 | " save_path = '/data/anime/predictions/' + Path(f).parent.name\n", 395 | " Path(save_path).mkdir(parents=True, exist_ok=True)\n", 396 | " Image.fromarray(estimated.astype(np.uint8)).save(save_path + '/' + Path(f).stem + '.png')" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 104, 402 | 
"metadata": {}, 403 | "outputs": [ 404 | { 405 | "name": "stdout", 406 | "output_type": "stream", 407 | "text": [ 408 | "0 tensor(90.)\n", 409 | "1 tensor(0.)\n", 410 | "2 tensor(0.)\n", 411 | "3 tensor(90.)\n", 412 | "4 tensor(90.)\n" 413 | ] 414 | } 415 | ], 416 | "source": [ 417 | "for fold in range(0, 5):\n", 418 | " test = DataGenerator(fold, train=False)\n", 419 | " count = 0\n", 420 | " for f in test.images:\n", 421 | " pred = open_mask('/data/anime/predictions/' + Path(f).parent.name + '/' + Path(f).stem + '.png').px\n", 422 | " count += pred.max()\n", 423 | " print(fold, count)" 424 | ] 425 | } 426 | ], 427 | "metadata": { 428 | "kernelspec": { 429 | "display_name": "Python 3", 430 | "language": "python", 431 | "name": "python3" 432 | }, 433 | "language_info": { 434 | "codemirror_mode": { 435 | "name": "ipython", 436 | "version": 3 437 | }, 438 | "file_extension": ".py", 439 | "mimetype": "text/x-python", 440 | "name": "python", 441 | "nbconvert_exporter": "python", 442 | "pygments_lexer": "ipython3", 443 | "version": "3.7.3" 444 | } 445 | }, 446 | "nbformat": 4, 447 | "nbformat_minor": 2 448 | } 449 | -------------------------------------------------------------------------------- /code/experiments/datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models, SegmentationLabelList, SegmentationItemList, TfmPixel, pil2tensor, Image, ImageSegment, partial\n", 10 | "from experiments import getDatasets, getDataset, getData, random_seed\n", 11 | "from losses import MixedLoss\n", 12 | "from metrics import MetricsCallback, getDatasetMetrics\n", 13 | "from fastai.callbacks import CSVLogger, SaveModelCallback\n", 14 | "from config import *\n", 15 | "from dataset import pad_tensor\n", 16 | "import PIL\n", 17 | 
"import numpy as np\n", 18 | "\n", 19 | "%load_ext autoreload\n", 20 | "%autoreload 2\n", 21 | "\n", 22 | "torch.cuda.set_device(0)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "EXPERIMENT_PATH = Path(EXPERIMENTS_PATH) / 'datasets'\n", 32 | "MODELS_PATH = EXPERIMENT_PATH / \"models\"\n", 33 | "os.makedirs(MODELS_PATH, exist_ok=True)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def icdar_web_segmentation(x):\n", 43 | " return str(x).replace('/images', '/gt').replace(x.name, 'gt_' + x.stem + '.png')\n", 44 | "\n", 45 | "def icdar_scene_segmentation(x):\n", 46 | " if \"/train/\" in str(x):\n", 47 | " return str(x).replace('/images', '/gt').replace(x.name, x.stem + '_GT' + '.bmp')\n", 48 | " return icdar_web_segmentation(x)\n", 49 | "\n", 50 | "def total_text_segmentation(x):\n", 51 | " return str(x).replace('/images', '/gt')\n", 52 | "\n", 53 | "def dibco_segmentation(x):\n", 54 | " suffix = '.tiff'\n", 55 | " name = x.stem\n", 56 | " if any(map(lambda y: y in str(x), [\"2012\"])):\n", 57 | " suffix = '.tif'\n", 58 | " if any(map(lambda y: y in str(x), [\"2016\", \"2017\", \"2018\", \"2019\"])):\n", 59 | " suffix = '.bmp' \n", 60 | " if any(map(lambda y: y in str(x), [\"2010\", \"2013\", \"2014\"])):\n", 61 | " name = x.stem + '_estGT'\n", 62 | " if any(map(lambda y: y in str(x), [\"2011\", \"2012\"])):\n", 63 | " name = x.stem + '_GT' \n", 64 | " if any(map(lambda y: y in str(x), [\"2016\", \"2017\", \"2018\"])):\n", 65 | " name = x.stem + '_gt' \n", 66 | " return str(x).replace('/images', '/gt').replace(x.name, name + suffix)\n", 67 | "\n", 68 | "def kaist_segmentation(x):\n", 69 | " return str(x).replace(x.suffix, '.bmp')\n", 70 | "\n", 71 | "class CustomSegLabel2(SegmentationLabelList):\n", 72 | " def open(self, fn):\n", 73 | " im = PIL.Image.open(fn).convert('L')\n", 74 | 
" im.thumbnail((500, 500),PIL.Image.NEAREST)\n", 75 | " im = pil2tensor(im,np.float32)\n", 76 | " \n", 77 | " if str(self.path) != KAIST_PATH:\n", 78 | " im = im // 255\n", 79 | " else:\n", 80 | " im = (im != 0).float()\n", 81 | " \n", 82 | " if str(self.path) in [ICDAR2013_WEB_PATH, ICDAR2013_SCENE_PATH, DIBCO_PATH]:\n", 83 | " im = 1 - im\n", 84 | " return ImageSegment(im)\n", 85 | " \n", 86 | "class SItemListCustom2(SegmentationItemList):\n", 87 | " _label_cls = CustomSegLabel2\n", 88 | " def open(self, fn):\n", 89 | " im = PIL.Image.open(fn).convert('RGB')\n", 90 | " im.thumbnail((500, 500),PIL.Image.LANCZOS)\n", 91 | " im = pil2tensor(im,np.float32) / 255\n", 92 | " return Image(im)\n", 93 | "\n", 94 | "def getDataByName(name):\n", 95 | " if name == 'icdar2013-web' or name == 'icdar2013-scene':\n", 96 | " return (SItemListCustom2.from_folder(ICDAR2013_WEB_PATH if name == 'icdar2013-web' else ICDAR2013_SCENE_PATH)\n", 97 | " .filter_by_func(lambda p: '/train/images' in str(p) or '/test/images' in str(p))\n", 98 | " .split_none()\n", 99 | " .label_from_func(icdar_web_segmentation if name == 'icdar2013-web' else icdar_scene_segmentation, classes=['text']))\n", 100 | " elif name == 'total-text':\n", 101 | " return (SItemListCustom2.from_folder(TOTAL_TEXT_PATH)\n", 102 | " .filter_by_func(lambda p: '/train/images' in str(p) or '/test/images' in str(p))\n", 103 | " .split_none()\n", 104 | " .label_from_func(total_text_segmentation, classes=['text'])) \n", 105 | " elif name == 'dibco':\n", 106 | " return (SItemListCustom2.from_folder(DIBCO_PATH)\n", 107 | " .filter_by_func(lambda p: '/images' in str(p))\n", 108 | " .split_none()\n", 109 | " .label_from_func(dibco_segmentation, classes=['text'])) \n", 110 | " elif name == 'kaist':\n", 111 | " return (SItemListCustom2.from_folder(KAIST_PATH)\n", 112 | " .filter_by_func(lambda p: p.suffix != '.bmp')\n", 113 | " .split_none()\n", 114 | " .label_from_func(kaist_segmentation, classes=['text'])) \n", 115 | " \n", 116 | 
"def getDatabunch(name):\n", 117 | " props = {'bs': 1, 'val_bs': 2, 'num_workers': 0}\n", 118 | " random_seed(42)\n", 119 | " data = getDataByName(name)\n", 120 | " tfms = [TfmPixel(pad_tensor)(multiple = 8)]\n", 121 | " data.train.transform(tfms, tfm_y=True)\n", 122 | " data.valid = getDatasets(allData)[0].valid\n", 123 | " random_seed(42)\n", 124 | " return data.databunch(**props).normalize(imagenet_stats)\n", 125 | " \n", 126 | "allData = getData()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "for name in ['icdar2013-web', 'icdar2013-scene', 'total-text', 'dibco', 'kaist']:\n", 136 | " PATH = EXPERIMENT_PATH / name\n", 137 | " print(name)\n", 138 | " if not (PATH / 'model.pkl').exists():\n", 139 | " learn = unet_learner(getDatabunch(name), models.resnet34, model_dir='models', callback_fns=[MetricsCallback, CSVLogger, partial(SaveModelCallback, monitor = 'normal pixel f1 %')], loss_func=MixedLoss(0.0, 1.0), path=PATH)\n", 140 | " random_seed(42)\n", 141 | " learn.fit_one_cycle(10, 1e-4)\n", 142 | " learn.save('model')\n", 143 | " learn.export(file='model.pkl')\n", 144 | " if not (PATH / 'predictions.csv').exists():\n", 145 | " learn = load_learner(PATH, 'model.pkl')\n", 146 | " random_seed(42)\n", 147 | " m = getDatasetMetrics(getDataset(allData), learn)\n", 148 | " m.save(PATH / 'predictions.csv', False)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "props = {'bs': 1, 'val_bs': 2, 'num_workers': 0}\n", 158 | "name = 'manga'\n", 159 | "random_seed(42)\n", 160 | "for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)):\n", 161 | " PATH = EXPERIMENT_PATH / name / str(index)\n", 162 | " if not (PATH / 'model.pkl').exists() or True:\n", 163 | " random_seed(42)\n", 164 | " data = 
dataset.databunch(**props).normalize(imagenet_stats)\n", 165 | " learn = unet_learner(data, models.resnet34, model_dir='models', callback_fns=[MetricsCallback, CSVLogger, partial(SaveModelCallback, monitor = 'normal pixel f1 %')], loss_func=MixedLoss(0.0, 1.0), path=PATH)\n", 166 | " random_seed(42)\n", 167 | " learn.fit_one_cycle(10, 1e-4)\n", 168 | " learn.save('model')\n", 169 | " learn.export(file='model.pkl')\n", 170 | " if not (PATH / 'predictions.csv').exists():\n", 171 | " learn = load_learner(PATH, 'model.pkl')\n", 172 | " random_seed(42)\n", 173 | " m = getDatasetMetrics(dataset, learn)\n", 174 | " m.save(PATH / 'predictions.csv', False)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "#unzipping KAIST dataset\n", 184 | "from contextlib import closing\n", 185 | "from zipfile import ZipFile\n", 186 | "\n", 187 | "count = dict()\n", 188 | "seen = dict()\n", 189 | "\n", 190 | "for f in sorted(glob.glob(KAIST_PATH + \"/KAIST/**/*.zip\", recursive=True)):\n", 191 | " with closing(ZipFile(f)) as archive:\n", 192 | " for info in archive.infolist():\n", 193 | " name = f + info.filename \n", 194 | " if \"Digital_Camera/(C.S)C-outdoor4.zipDSC03706\" in name:\n", 195 | " name = name.replace('zipDSC03706', 'zipDSC03707')\n", 196 | " if name.endswith('.bmp') or name.lower().endswith('.jpg'):\n", 197 | " if name[0:-4] not in count:\n", 198 | " count[name[0:-4]] = len(count)\n", 199 | " seen[name[0:-4]] = 0\n", 200 | " seen[name[0:-4]] += 1\n", 201 | " info.filename = str(count[name[0:-4]]) + Path(name).suffix\n", 202 | " archive.extract(info, path = KAIST_PATH + '/images/')\n", 203 | "for k in seen.keys():\n", 204 | " if seen[k] != 2:\n", 205 | " print(k, seen[k]) " 206 | ] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "Python 3", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | 
"codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.7.3" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 4 230 | } 231 | -------------------------------------------------------------------------------- /code/experiments/experiments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import os 5 | from pathlib import Path 6 | from sklearn.model_selection import KFold 7 | 8 | import sys 9 | sys.path.append('..') 10 | 11 | from dataset import get_segmentation, random_crop, SegItemListCustom 12 | from config import * 13 | 14 | def random_seed(seed_value, use_cuda = True): 15 | np.random.seed(seed_value) # cpu vars 16 | torch.manual_seed(seed_value) # cpu vars 17 | random.seed(seed_value) # Python 18 | os.environ['PYTHONHASHSEED'] = str(seed_value) 19 | if use_cuda: 20 | torch.cuda.manual_seed(seed_value) 21 | torch.cuda.manual_seed_all(seed_value) # gpu vars 22 | torch.backends.cudnn.deterministic = True #needed 23 | torch.backends.cudnn.benchmark = False 24 | 25 | trainFolders = [ 26 | 'ARMS', 27 | 'AisazuNihaIrarenai', 28 | 'AkkeraKanjinchou', 29 | 'Akuhamu', 30 | 'AosugiruHaru', 31 | 'AppareKappore', 32 | 'Arisa', 33 | 'BEMADER_P', 34 | 'BakuretsuKungFuGirl', 35 | 'Belmondo', 36 | 'BokuHaSitatakaKun', 37 | 'BurariTessenTorimonocho', 38 | 'ByebyeC-BOY', 39 | 'Count3DeKimeteAgeru', 40 | 'DollGun', 41 | 'Donburakokko', 42 | 'DualJustice', 43 | 'EienNoWith', 44 | 'EvaLady', 45 | 'EverydayOsakanaChan', 46 | 'GOOD_KISS_Ver2', 47 | 'GakuenNoise', 48 | 'GarakutayaManta', 49 | 'GinNoChimera', 50 | 'Hamlet', 51 | 'HanzaiKousyouninMinegishiEitarou', 52 | 'HaruichibanNoFukukoro', 53 | 'HarukaRefrain', 54 | 'HealingPlanet', 55 | "UchiNoNyan'sDiary", 56 | 
'UchuKigekiM774', 57 | 'UltraEleven', 58 | 'UnbalanceTokyo', 59 | 'WarewareHaOniDearu', 60 | 'YamatoNoHane', 61 | 'YasasiiAkuma', 62 | 'YouchienBoueigumi', 63 | 'YoumaKourin', 64 | 'YukiNoFuruMachi', 65 | 'YumeNoKayoiji', 66 | 'YumeiroCooking', 67 | 'TotteokiNoABC', 68 | 'ToutaMairimasu', 69 | 'TouyouKidan', 70 | 'TsubasaNoKioku' 71 | ] 72 | 73 | assert(len(trainFolders) == len(set(trainFolders))) 74 | 75 | for x in trainFolders: 76 | assert(os.path.isdir(MASKS_PATH + '/' + x)) 77 | 78 | def getDatasetLists(dataset): 79 | return [dataset.train.x, dataset.train.y, dataset.valid.x, dataset.valid.y] 80 | 81 | #gets Kfolded data in order to train the models, 4/5 goes to train and 1/5 to validation in each fold. 82 | def getDatasets(allData, crop=True, padding = 8, cutInHalf = True): 83 | folds = KFold(n_splits = 5, shuffle = True, random_state=42).split(trainFolders, trainFolders) 84 | 85 | datasets = [] 86 | 87 | for _, valid_indexes in folds: 88 | dataset = (allData 89 | .split_by_valid_func(lambda x: trainFolders.index(Path(x).parent.name) in valid_indexes) 90 | .label_from_func(get_segmentation, classes=['text'])) 91 | 92 | if crop: 93 | dataset.train.transform([random_crop(size=(512, 800), randx=(0.0, 1.0), randy=(0.0, 1.0))], tfm_y=True) 94 | 95 | for l in getDatasetLists(dataset): 96 | l.padding = padding 97 | l.cutInHalf = cutInHalf 98 | 99 | datasets.append(dataset) 100 | 101 | return datasets 102 | 103 | #returns single dataset to use by methods that were not trained with manga (icdar2013, total-text, synthetic) 104 | def getDataset(allData): 105 | dataset = (allData 106 | .split_by_valid_func(lambda _: True) 107 | .label_from_func(get_segmentation, classes=['text'])) 108 | for l in getDatasetLists(dataset): 109 | l.padding = 8 110 | l.cutInHalf = False 111 | return dataset 112 | 113 | #given prediction and ground truth, returns colorized tensor with true positives as green, false positives as red and false negative as white 114 | def 
colorizePrediction(prediction, truth): 115 | prediction, truth = prediction[0], truth[0] 116 | colorized = torch.zeros(4, prediction.shape[0], prediction.shape[1]).int() 117 | r, g, b, a = colorized[:] 118 | 119 | fn = (truth >= 1) & (truth <= 5) & (truth != 3) & (prediction == 0) 120 | tp = ((truth >= 1) & (truth <= 5)) & (prediction == 1) 121 | fp = (truth == 0) & (prediction == 1) 122 | 123 | r[fp] = 255 124 | r[fn] = g[fn] = b[fn] = 255 125 | g[tp] = 255 126 | 127 | a[:, :] = 128 128 | a[tp | fn | fp] = 255 129 | 130 | return colorized 131 | 132 | #given an image index from a folder (like ARMS, 0) finds which dataset has it in validation and the index inside it 133 | def findImage(datasets, folder, index): 134 | for dIndex, dataset in enumerate(datasets): 135 | for idx, item in enumerate(dataset.valid.y.items): 136 | if "/" + folder + "/" in item and str(index) + ".png" in item: 137 | return dIndex, idx 138 | 139 | 140 | def getData(): 141 | random_seed(42) 142 | return (SegItemListCustom.from_folder(MANGA109_PATH) 143 | .filter_by_func(lambda x: Path(get_segmentation(x)).exists() and Path(x).parent.name in (trainFolders))) -------------------------------------------------------------------------------- /code/experiments/hrnet/manga.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | 13 | import torch 14 | from torch.nn import functional as F 15 | 16 | from .base_dataset import BaseDataset 17 | 18 | from pathlib import Path 19 | from sklearn.model_selection import KFold 20 | import glob 21 | 22 | trainFolders = [ 23 | 'ARMS', 24 | 'AisazuNihaIrarenai', 25 | 'AkkeraKanjinchou', 26 | 'Akuhamu', 27 | 'AosugiruHaru', 28 | 'AppareKappore', 29 | 'Arisa', 30 | 'BEMADER_P', 31 | 'BakuretsuKungFuGirl', 32 | 'Belmondo', 33 | 'BokuHaSitatakaKun', 34 | 'BurariTessenTorimonocho', 35 | 'ByebyeC-BOY', 36 | 'Count3DeKimeteAgeru', 37 | 'DollGun', 38 | 'Donburakokko', 39 | 'DualJustice', 40 | 'EienNoWith', 41 | 'EvaLady', 42 | 'EverydayOsakanaChan', 43 | 'GOOD_KISS_Ver2', 44 | 'GakuenNoise', 45 | 'GarakutayaManta', 46 | 'GinNoChimera', 47 | 'Hamlet', 48 | 'HanzaiKousyouninMinegishiEitarou', 49 | 'HaruichibanNoFukukoro', 50 | 'HarukaRefrain', 51 | 'HealingPlanet', 52 | "UchiNoNyan'sDiary", 53 | 'UchuKigekiM774', 54 | 'UltraEleven', 55 | 'UnbalanceTokyo', 56 | 'WarewareHaOniDearu', 57 | 'YamatoNoHane', 58 | 'YasasiiAkuma', 59 | 'YouchienBoueigumi', 60 | 'YoumaKourin', 61 | 'YukiNoFuruMachi', 62 | 'YumeNoKayoiji', 63 | 'YumeiroCooking', 64 | 'TotteokiNoABC', 65 | 'ToutaMairimasu', 66 | 'TouyouKidan', 67 | 'TsubasaNoKioku' 68 | ] 69 | 70 | def get_segmentation(x): 71 | return '/data/anime2/manga/' + Path(x).parent.name + '/' + Path(x).name.replace('.jpg', '.cache') 72 | 73 | class Manga(BaseDataset): 74 | def __init__(self, 75 | root, 76 | list_path, 77 | num_samples=None, 78 | num_classes=19, 79 | multi_scale=True, 80 | flip=True, 81 | ignore_label=-1, 82 | base_size=2048, 83 | crop_size=(512, 1024), 84 | center_crop_test=False, 85 | downsample_rate=1, 86 | scale_factor=16, 87 | mean=[0.485, 0.456, 0.406], 88 | fold=0, 89 | std=[0.229, 0.224, 
0.225]): 90 | 91 | super(Manga, self).__init__(ignore_label, base_size, 92 | crop_size, downsample_rate, scale_factor, mean, std,) 93 | 94 | self.root = root 95 | self.list_path = list_path 96 | self.fold_index = fold 97 | self.num_classes = 2 98 | self.class_weights = torch.FloatTensor([1, 10]).cuda() 99 | 100 | self.multi_scale = multi_scale 101 | self.flip = flip 102 | self.center_crop_test = center_crop_test 103 | 104 | 105 | self.files = self.read_files() 106 | 107 | if num_samples: 108 | self.files = self.files[:num_samples] 109 | 110 | def read_files(self): 111 | folds = KFold(n_splits = 5, shuffle = True, random_state=42).split(trainFolders, trainFolders) 112 | files = [] 113 | for idx, (_, valid_indexes) in enumerate(folds): 114 | if idx == self.fold_index: 115 | for x in sorted(glob.glob('/data/anime/**/*.jpg', recursive=True)): 116 | if Path(get_segmentation(x)).exists() and Path(x).parent.name in trainFolders: 117 | folder = Path(x).parent.name 118 | idx = trainFolders.index(folder) 119 | file = { 120 | "img": x, 121 | "name": Path(x).parent.name + '/' + Path(x).stem, 122 | "weight": 1, 123 | "label": get_segmentation(x) 124 | } 125 | if 'test' == self.list_path and idx in valid_indexes: 126 | files.append(file) 127 | elif 'train' == self.list_path and idx not in valid_indexes: 128 | files.append(file) 129 | 130 | return files 131 | 132 | def convert_label(self, label, inverse=False): 133 | label[label >= 3] -= 3 134 | label = label != 0 135 | return label.astype(np.uint8) 136 | 137 | def __getitem__(self, index): 138 | item = self.files[index] 139 | name = item["name"] 140 | image = cv2.imread(item["img"], cv2.IMREAD_COLOR) 141 | size = image.shape 142 | 143 | label = cv2.imread(item["label"], cv2.IMREAD_GRAYSCALE) 144 | 145 | label = self.convert_label(label) 146 | 147 | if 'test' == self.list_path: 148 | image = self.input_transform(image) 149 | image = image.transpose((2, 0, 1)) 150 | return image.copy(), label.copy(), np.array(size), name 151 | 152 
| 153 | image, label = self.gen_sample(image, label, 154 | self.multi_scale, self.flip, 155 | self.center_crop_test) 156 | 157 | return image.copy(), label.copy(), np.array(size), name 158 | 159 | def multi_scale_inference(self, model, image, scales=[1], flip=False): 160 | batch, _, ori_height, ori_width = image.size() 161 | assert batch == 1, "only supporting batchsize 1." 162 | image = image.numpy()[0].transpose((1,2,0)).copy() 163 | stride_h = int(self.crop_size[0] * 1.0) 164 | stride_w = int(self.crop_size[1] * 1.0) 165 | final_pred = torch.zeros([1, self.num_classes, 166 | ori_height,ori_width]).cuda() 167 | for scale in scales: 168 | new_img = self.multi_scale_aug(image=image, 169 | rand_scale=scale, 170 | rand_crop=False) 171 | height, width = new_img.shape[:-1] 172 | 173 | if scale <= 1.0: 174 | new_img = new_img.transpose((2, 0, 1)) 175 | new_img = np.expand_dims(new_img, axis=0) 176 | new_img = torch.from_numpy(new_img) 177 | preds = self.inference(model, new_img, flip) 178 | preds = preds[:, :, 0:height, 0:width] 179 | else: 180 | new_h, new_w = new_img.shape[:-1] 181 | rows = int(np.ceil(1.0 * (new_h - 182 | self.crop_size[0]) / stride_h)) + 1 183 | cols = int(np.ceil(1.0 * (new_w - 184 | self.crop_size[1]) / stride_w)) + 1 185 | preds = torch.zeros([1, self.num_classes, 186 | new_h,new_w]).cuda() 187 | count = torch.zeros([1,1, new_h, new_w]).cuda() 188 | 189 | for r in range(rows): 190 | for c in range(cols): 191 | h0 = r * stride_h 192 | w0 = c * stride_w 193 | h1 = min(h0 + self.crop_size[0], new_h) 194 | w1 = min(w0 + self.crop_size[1], new_w) 195 | h0 = max(int(h1 - self.crop_size[0]), 0) 196 | w0 = max(int(w1 - self.crop_size[1]), 0) 197 | crop_img = new_img[h0:h1, w0:w1, :] 198 | crop_img = crop_img.transpose((2, 0, 1)) 199 | crop_img = np.expand_dims(crop_img, axis=0) 200 | crop_img = torch.from_numpy(crop_img) 201 | pred = self.inference(model, crop_img, flip) 202 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0] 203
| count[:,:,h0:h1,w0:w1] += 1 204 | preds = preds / count 205 | preds = preds[:,:,:height,:width] 206 | preds = F.interpolate(preds, (ori_height, ori_width), 207 | mode='bilinear', align_corners=False) 208 | final_pred += preds 209 | return final_pred 210 | 211 | def get_palette(self, n): 212 | palette = [0] * (n * 3) 213 | for j in range(0, n): 214 | lab = j 215 | palette[j * 3 + 0] = 0 216 | palette[j * 3 + 1] = 0 217 | palette[j * 3 + 2] = 0 218 | i = 0 219 | while lab: 220 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 221 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 222 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 223 | i += 1 224 | lab >>= 3 225 | return palette 226 | 227 | def save_pred(self, preds, sv_path, name): 228 | palette = self.get_palette(256) 229 | preds = preds.cpu().numpy().copy() 230 | preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8) 231 | for i in range(preds.shape[0]): 232 | pred = self.convert_label(preds[i], inverse=True) 233 | save_img = Image.fromarray(pred) 234 | save_img.putpalette(palette) 235 | fname = os.path.join(sv_path, name[i]+'.png') 236 | os.makedirs(os.path.dirname(fname), exist_ok=True) 237 | save_img.save(fname) 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /code/experiments/hrnet/manga.yaml: -------------------------------------------------------------------------------- 1 | CUDNN: 2 | BENCHMARK: true 3 | DETERMINISTIC: false 4 | ENABLED: true 5 | GPUS: (0,) 6 | OUTPUT_DIR: 'output' 7 | LOG_DIR: 'log' 8 | WORKERS: 4 9 | PRINT_FREQ: 120 10 | 11 | DATASET: 12 | DATASET: manga 13 | ROOT: 'HRNet-Semantic-Segmentation/data/' 14 | TEST_SET: 'test' 15 | TRAIN_SET: 'train' 16 | NUM_CLASSES: 2 17 | MODEL: 18 | NAME: seg_hrnet 19 | PRETRAINED: 'pretrained_models/hrnetv2_w48_imagenet_pretrained.pth' 20 | EXTRA: 21 | FINAL_CONV_KERNEL: 1 22 | STAGE1: 23 | NUM_MODULES: 1 24 | NUM_BRANCHES: 1 25 | BLOCK: BOTTLENECK 26 | NUM_BLOCKS: 27 | - 4 28 | NUM_CHANNELS: 29 | - 64
30 | FUSE_METHOD: SUM 31 | STAGE2: 32 | NUM_MODULES: 1 33 | NUM_BRANCHES: 2 34 | BLOCK: BASIC 35 | NUM_BLOCKS: 36 | - 4 37 | - 4 38 | NUM_CHANNELS: 39 | - 48 40 | - 96 41 | FUSE_METHOD: SUM 42 | STAGE3: 43 | NUM_MODULES: 4 44 | NUM_BRANCHES: 3 45 | BLOCK: BASIC 46 | NUM_BLOCKS: 47 | - 4 48 | - 4 49 | - 4 50 | NUM_CHANNELS: 51 | - 48 52 | - 96 53 | - 192 54 | FUSE_METHOD: SUM 55 | STAGE4: 56 | NUM_MODULES: 3 57 | NUM_BRANCHES: 4 58 | BLOCK: BASIC 59 | NUM_BLOCKS: 60 | - 4 61 | - 4 62 | - 4 63 | - 4 64 | NUM_CHANNELS: 65 | - 48 66 | - 96 67 | - 192 68 | - 384 69 | FUSE_METHOD: SUM 70 | LOSS: 71 | USE_OHEM: false 72 | USE_DICE: true 73 | OHEMTHRES: 0.9 74 | OHEMKEEP: 131072 75 | TRAIN: 76 | IMAGE_SIZE: 77 | - 1323 78 | - 936 79 | BASE_SIZE: 1654 80 | BATCH_SIZE_PER_GPU: 1 81 | SHUFFLE: true 82 | BEGIN_EPOCH: 0 83 | END_EPOCH: 100 84 | RESUME: true 85 | OPTIMIZER: sgd 86 | LR: 0.01 87 | WD: 0.0005 88 | MOMENTUM: 0.9 89 | NESTEROV: false 90 | FLIP: true 91 | MULTI_SCALE: true 92 | DOWNSAMPLERATE: 1 93 | IGNORE_LABEL: -1 94 | SCALE_FACTOR: 6 95 | TEST: 96 | IMAGE_SIZE: 97 | - 1654 98 | - 1170 99 | BASE_SIZE: 1654 100 | BATCH_SIZE_PER_GPU: 1 101 | FLIP_TEST: false 102 | MULTI_SCALE: false -------------------------------------------------------------------------------- /code/experiments/loss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models\n", 10 | "from experiments import getDatasets, getData, random_seed\n", 11 | "from losses import BCELoss, MixedLoss\n", 12 | "from metrics import MetricsCallback, getDatasetMetrics\n", 13 | "from fastai.callbacks import CSVLogger\n", 14 | "from config import *\n", 15 | "\n", 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "\n", 19 | 
"torch.cuda.set_device(0)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "EXPERIMENT_PATH = Path(EXPERIMENTS_PATH) / 'loss'\n", 29 | "MODELS_PATH = EXPERIMENT_PATH / \"models\"\n", 30 | "os.makedirs(MODELS_PATH, exist_ok=True)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "allData = getData()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "props = {'bs': 4, 'val_bs': 2, 'num_workers': 0}\n", 49 | "losses = {'bce0.5': BCELoss(0.5), 'bce1': BCELoss(1), 'bce5': BCELoss(5), 'bce10': BCELoss(10), 'bce30': BCELoss(30), \n", 50 | " 'mixed_10_2': MixedLoss(10.0, 2.0), 'mixed_10_1': MixedLoss(10.0, 1.0),\n", 51 | " 'mixed_5_2': MixedLoss(5.0, 2.0), 'mixed_5_1': MixedLoss(5.0, 1.0),\n", 52 | " 'mixed_5_2': MixedLoss(2.0, 2.0), 'mixed_5_1': MixedLoss(2.0, 1.0),\n", 53 | " 'dice': MixedLoss(0.0, 1.0)\n", 54 | " }" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": { 61 | "scrolled": false 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "for name, loss in losses.items():\n", 66 | " for index, dataset in enumerate(getDatasets(allData)):\n", 67 | " PATH = EXPERIMENT_PATH / name / str(index)\n", 68 | " if not (PATH / 'final model.pkl').exists():\n", 69 | " random_seed(42)\n", 70 | " data = dataset.databunch(**props).normalize(imagenet_stats)\n", 71 | " random_seed(42)\n", 72 | " learn = unet_learner(data, models.resnet18, callback_fns=[MetricsCallback, CSVLogger], model_dir='models', loss_func=loss, path=PATH)\n", 73 | " random_seed(42)\n", 74 | " learn.fit_one_cycle(10, 1e-4)\n", 75 | " learn.save('model')\n", 76 | " learn.export(file='final model.pkl')\n", 77 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)): \n", 78 
| " PATH = EXPERIMENT_PATH / name / str(index)\n", 79 | " if not (PATH / 'final predictions.csv').exists():\n", 80 | " learn = load_learner(PATH, 'final model.pkl')\n", 81 | " random_seed(42)\n", 82 | " m = getDatasetMetrics(dataset, learn)\n", 83 | " m.save(PATH / 'final predictions.csv')" 84 | ] 85 | } 86 | ], 87 | "metadata": { 88 | "kernelspec": { 89 | "display_name": "Python 3", 90 | "language": "python", 91 | "name": "python3" 92 | }, 93 | "language_info": { 94 | "codemirror_mode": { 95 | "name": "ipython", 96 | "version": 3 97 | }, 98 | "file_extension": ".py", 99 | "mimetype": "text/x-python", 100 | "name": "python", 101 | "nbconvert_exporter": "python", 102 | "pygments_lexer": "ipython3", 103 | "version": "3.7.3" 104 | } 105 | }, 106 | "nbformat": 4, 107 | "nbformat_minor": 2 108 | } 109 | -------------------------------------------------------------------------------- /code/experiments/model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models, sys, Learner, partial, flatten_model, requires_grad, bn_types, defaults\n", 10 | "from experiments import getDatasets, getData, random_seed\n", 11 | "from losses import MixedLoss\n", 12 | "from metrics import MetricsCallback, getDatasetMetrics\n", 13 | "from fastai.callbacks import CSVLogger, SaveModelCallback\n", 14 | "from config import *\n", 15 | "sys.path.append('../../text-segmentation')\n", 16 | "\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2\n", 19 | "\n", 20 | "from models.text_segmentation import TextSegament, XceptionTextSegment\n", 21 | "\n", 22 | "torch.cuda.set_device(0)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "EXPERIMENT_PATH = 
Path(EXPERIMENTS_PATH) / 'model'\n", 32 | "MODELS_PATH = EXPERIMENT_PATH / \"models\"\n", 33 | "os.makedirs(MODELS_PATH, exist_ok=True)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "allData = getData()" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "props = {'bs': 4, 'val_bs': 2, 'num_workers': 0}\n", 52 | "modelDict = {'resnet34': models.resnet34, 'xception': XceptionTextSegment(), 'segament': TextSegament()}\n", 53 | "propsOverride = {\n", 54 | " 'xception': {'bs': 2},\n", 55 | " 'segament': {'bs': 2}\n", 56 | "}" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "scrolled": false 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "for name, model in list(modelDict.items()):\n", 68 | " for index, dataset in enumerate(getDatasets(allData)):\n", 69 | " PATH = EXPERIMENT_PATH / name / str(index) \n", 70 | " if not (PATH / 'final model.pkl').exists():\n", 71 | " overrides = {} if name not in propsOverride else propsOverride[name]\n", 72 | " random_seed(42)\n", 73 | " data = dataset.databunch(**{**props, **overrides}).normalize(imagenet_stats)\n", 74 | " func = Learner if name in [\"xception\", \"segament\"] else unet_learner\n", 75 | " random_seed(42)\n", 76 | " learn = func(data, model, callback_fns=[MetricsCallback, CSVLogger], model_dir='models', loss_func=MixedLoss(0, 1), path=PATH)\n", 77 | " random_seed(42)\n", 78 | " learn.fit_one_cycle(10, 1e-4)\n", 79 | " learn.save('model')\n", 80 | " learn.export(file='final model.pkl')\n", 81 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)): \n", 82 | " PATH = EXPERIMENT_PATH / name / str(index) \n", 83 | " if not (PATH / 'final predictions.csv').exists():\n", 84 | " learn = load_learner(PATH, 'final model.pkl')\n", 85 | " random_seed(42)\n", 86 
| " m = getDatasetMetrics(dataset, learn)\n", 87 | " m.save(PATH / 'final predictions.csv')" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "scrolled": true 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "import segmentation_models_pytorch as smp\n", 99 | "\n", 100 | "props = {'bs': 4, 'val_bs': 2, 'num_workers': 0}\n", 101 | "models = ['resnet50', 'dpn68', 'vgg16', 'densenet169', 'efficientnet-b4']\n", 102 | "propsOverride = {}\n", 103 | "archs = [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.PAN]\n", 104 | "for arch in archs:\n", 105 | " for model in models:\n", 106 | " if model in ['vgg16', 'densenet169'] and smp.PAN == arch: #not supported\n", 107 | " continue\n", 108 | " for index, dataset in enumerate(getDatasets(allData, padding = 16)):\n", 109 | " PATH = EXPERIMENT_PATH / (model + ' ' + arch.__name__) / str(index) \n", 110 | " if not (PATH / 'final model.pkl').exists():\n", 111 | " overrides = {} if model not in propsOverride else propsOverride[model]\n", 112 | " random_seed(42)\n", 113 | " data = dataset.databunch(**{**props, **overrides}).normalize(imagenet_stats)\n", 114 | " random_seed(42)\n", 115 | " learn = Learner(data, arch(model, encoder_weights='imagenet'), callback_fns=[MetricsCallback, CSVLogger, partial(SaveModelCallback, monitor=\"ignore global f1 score %\")], model_dir='models', loss_func=MixedLoss(0, 1), path=PATH)\n", 116 | " random_seed(42)\n", 117 | " #freeze encoder, still not implemented in smp\n", 118 | " if hasattr(learn.model, 'reset'): learn.model.reset()\n", 119 | " for l in flatten_model(learn.model.encoder):\n", 120 | " requires_grad(l, isinstance(l, bn_types))\n", 121 | " learn.create_opt(defaults.lr)\n", 122 | " random_seed(42)\n", 123 | " learn.fit_one_cycle(10, 1e-4)\n", 124 | " learn.save('model')\n", 125 | " learn.export(file='final model.pkl')\n", 126 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False, padding = 16)):\n", 127 | " 
PATH = EXPERIMENT_PATH / (model + ' ' + arch.__name__) / str(index) \n", 128 | " if not (PATH / 'final predictions.csv').exists():\n", 129 | " learn = load_learner(PATH, 'final model.pkl')\n", 130 | " random_seed(42)\n", 131 | " m = getDatasetMetrics(dataset, learn)\n", 132 | " m.save(PATH / 'final predictions.csv') " 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "PATH = EXPERIMENT_PATH / 'xception' \n", 142 | "for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf = False)):\n", 143 | " learn = load_learner(PATH / str(index) , 'final model.pkl')\n", 144 | "\n", 145 | " for idx in range(len(dataset.valid.x.items)):\n", 146 | " img = dataset.valid.x.items[idx]\n", 147 | " TENSOR_PATH = PATH / 'predictions' / img.parent.name / img.name.replace(img.suffix, '.pt')\n", 148 | " (PATH / 'predictions' / img.parent.name).mkdir(parents=True, exist_ok=True) \n", 149 | " if not (TENSOR_PATH).exists():\n", 150 | " pred = learn.predict(dataset.valid.x.get(idx, False))[2]\n", 151 | " torch.save(pred, TENSOR_PATH) " 152 | ] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 3", 158 | "language": "python", 159 | "name": "python3" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.7.3" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } 177 | -------------------------------------------------------------------------------- /code/experiments/refine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source":
[ 9 | "import fastai\n", 10 | "from fastai.vision import *\n", 11 | "from experiments import *\n", 12 | "from losses import *\n", 13 | "from dataset import *\n", 14 | "from metrics import *\n", 15 | "from config import *\n", 16 | "from fastai.callbacks import CSVLogger\n", 17 | "\n", 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "torch.cuda.set_device(0)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "random_seed(42)\n", 31 | "allData = (SegItemListCustom.from_folder(MANGA109_PATH)\n", 32 | " .filter_by_func(lambda x: Path(get_segmentation(x)).exists() and Path(x).parent.name in (trainFolders)))" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "props = {'bs': 4, 'val_bs': 2, 'num_workers': 0}\n", 42 | "paths = ['model/resnet34', 'loss/dice']" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "def isNotCsvLogger(c):\n", 52 | " if hasattr(c, '__name__') and c.__name__ == 'CSVLogger':\n", 53 | " return False\n", 54 | " if hasattr(c, 'func'):\n", 55 | " return isNotCsvLogger(c.func)\n", 56 | " return True" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 5, 62 | "metadata": { 63 | "scrolled": false 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "for path in paths:\n", 68 | " for index, dataset in enumerate(getDatasets(allData)):\n", 69 | " MODEL_PATH = Path(EXPERIMENTS_PATH) / path / str(index)\n", 70 | " if not (MODEL_PATH / 'final refined model.pkl').exists():\n", 71 | " random_seed(42)\n", 72 | " data = dataset.databunch(**props).normalize(imagenet_stats)\n", 73 | " random_seed(42)\n", 74 | " learn = load_learner(MODEL_PATH, 'final model.pkl')\n", 75 | " random_seed(42)\n", 76 | " learn.callback_fns = list(filter(isNotCsvLogger, learn.callback_fns)) + 
[partial(CSVLogger, filename = 'refined history')]\n", 77 | " learn.data = data\n", 78 | " learn.unfreeze()\n", 79 | " random_seed(42)\n", 80 | " learn.fit_one_cycle(5, 1e-4)\n", 81 | " learn.save('refined model')\n", 82 | " learn.export(file='final refined model.pkl')\n", 83 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)):\n", 84 | " MODEL_PATH = Path(EXPERIMENTS_PATH) / path / str(index) \n", 85 | " if not (MODEL_PATH / 'final refined model 2.pkl').exists(): \n", 86 | " random_seed(42)\n", 87 | " data = dataset.databunch(bs=1, val_bs = 1, num_workers=0).normalize(imagenet_stats)\n", 88 | " learn = load_learner(MODEL_PATH, 'final refined model.pkl')\n", 89 | " learn.callback_fns = list(filter(isNotCsvLogger, learn.callback_fns)) + [partial(CSVLogger, filename = 'refined history 2')]\n", 90 | " learn.data = data\n", 91 | " learn.unfreeze()\n", 92 | " random_seed(42)\n", 93 | " learn.fit_one_cycle(3, 1e-4)\n", 94 | " learn.save('refined model 2')\n", 95 | " learn.export(file='final refined model 2.pkl')\n", 96 | " \n", 97 | " if not (MODEL_PATH / 'final refined predictions.csv').exists():\n", 98 | " learn = load_learner(MODEL_PATH, 'final refined model.pkl')\n", 99 | " random_seed(42)\n", 100 | " m = getDatasetMetrics(dataset, learn)\n", 101 | " m.save(MODEL_PATH / 'final refined predictions.csv')\n", 102 | " \n", 103 | " if not (MODEL_PATH / 'final refined predictions 2.csv').exists():\n", 104 | " learn = load_learner(MODEL_PATH, 'final refined model 2.pkl')\n", 105 | " random_seed(42)\n", 106 | " m = getDatasetMetrics(dataset, learn) \n", 107 | " m.save(MODEL_PATH / 'final refined predictions 2.csv')" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | 
"codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.7.3" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 2 139 | } 140 | -------------------------------------------------------------------------------- /code/experiments/sickzil-machine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from experiments import getDatasets, getData, random_seed\n", 10 | "from metrics import MetricsCallback, getDatasetMetrics\n", 11 | "from config import *\n", 12 | "from pathlib import Path\n", 13 | "\n", 14 | "%load_ext autoreload\n", 15 | "%autoreload 2" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "allData = getData()" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "CSV_PATH = Path(EXPERIMENTS_PATH) / 'sickzil' / 'metrics.csv'\n", 34 | "\n", 35 | "if not (CSV_PATH).exists():\n", 36 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf = False)):\n", 37 | " random_seed(42)\n", 38 | " m = MetricsCallback(None)\n", 39 | " m.on_train_begin()\n", 40 | " for idx in range(len(dataset.valid.x.items)):\n", 41 | " pred = dataset.valid.y.sickzilImage(idx) \n", 42 | " m.on_batch_end(False, pred.px, dataset.valid.y[idx].px)\n", 43 | " m.calculateMetrics()\n", 44 | " m.save(CSV_PATH, index > 0) " 45 | ] 46 | } 47 | ], 48 | "metadata": { 49 | "kernelspec": { 50 | "display_name": "Python 3", 51 | "language": "python", 52 | "name": "python3" 53 | }, 54 | "language_info": { 55 | "codemirror_mode": 
{ 56 | "name": "ipython", 57 | "version": 3 58 | }, 59 | "file_extension": ".py", 60 | "mimetype": "text/x-python", 61 | "name": "python", 62 | "nbconvert_exporter": "python", 63 | "pygments_lexer": "ipython3", 64 | "version": "3.7.3" 65 | } 66 | }, 67 | "nbformat": 4, 68 | "nbformat_minor": 2 69 | } 70 | -------------------------------------------------------------------------------- /code/experiments/synthetic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models, ItemBase, SegmentationLabelList, SegmentationItemList, partial, ImageSegment\n", 11 | "from experiments import getDatasets, getData, random_seed\n", 12 | "from losses import MixedLoss\n", 13 | "from metrics import MetricsCallback, getDatasetMetrics\n", 14 | "from fastai.callbacks import CSVLogger, SaveModelCallback\n", 15 | "from config import *\n", 16 | "import glob\n", 17 | "from TextGenerator import Fonts\n", 18 | "from transforms import textify, tensorize\n", 19 | "from PIL import Image as pilImage\n", 20 | "\n", 21 | "%load_ext autoreload\n", 22 | "%autoreload 2\n", 23 | "torch.cuda.set_device(0)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "EXPERIMENT_PATH = Path(EXPERIMENTS_PATH) / 'synthetic'\n", 33 | "MODELS_PATH = EXPERIMENT_PATH / \"models\"\n", 34 | "os.makedirs(MODELS_PATH, exist_ok=True)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "learn = load_learner(Path(EXPERIMENTS_PATH) / 'model/resnet34/0', 'final refined model 2.pkl')\n", 44 | "\n", 45 | "if (EXPERIMENT_PATH / 'text_info.json').exists():\n", 46 | " with 
open(EXPERIMENT_PATH / 'text_info.json', 'r') as f:\n", 47 | " info = json.load(f)\n", 48 | "else:\n", 49 | " info = dict()\n", 50 | "\n", 51 | "for file in glob.glob(DANBOORU_PATH + '/**/*.jpg', recursive=True):\n", 52 | " file = Path(file)\n", 53 | " if file.name not in info:\n", 54 | " pred = learn.predict(open_image(file))[2]\n", 55 | " pred = torch.sigmoid(pred) > 0.5\n", 56 | " info[file.name] = (pred == 1).sum().item()\n", 57 | " with open(EXPERIMENT_PATH / 'text_info.json', 'w') as f:\n", 58 | " json.dump(info, f) " 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "allData = getData()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "def custom_collate(batch):\n", 77 | " if hasattr(batch[0][0], \"x_tensor\"):\n", 78 | " return torch.stack(list(map(lambda x: x[0].x_tensor, batch))), torch.stack(list(map(getSegmentationMask, batch))).long()\n", 79 | " else:\n", 80 | " return torch.stack(list(map(lambda x: x[0].px, batch))), torch.stack(list(map(lambda x: x[1].px, batch))).long()\n", 81 | "\n", 82 | "def getSegmentationMask(dan):\n", 83 | " return ((dan[0].x_tensor - dan[0].y_tensor).abs().sum(axis=0) > 0.1).unsqueeze(0)\n", 84 | " \n", 85 | "def folder(p):\n", 86 | " folder = (\"0000\" + p[-7:-4])[-4:]\n", 87 | " return '/' + folder + \"/\" + p \n", 88 | "\n", 89 | "\n", 90 | "class CustomItem(ItemBase):\n", 91 | " def __init__(self, image):\n", 92 | " self.image = image\n", 93 | " self.data = 0\n", 94 | " \n", 95 | " def __str__(self): return str(self.image)\n", 96 | " \n", 97 | " def apply_tfms(self, tfms, **kwargs):\n", 98 | " for tfm in tfms:\n", 99 | " tfm(self, **kwargs)\n", 100 | " return self \n", 101 | "\n", 102 | "class CustomLabel(SegmentationLabelList):\n", 103 | " def open(self, fn):\n", 104 | " return ImageSegment(torch.zeros(1, 64, 64)) \n", 105 | " \n", 106 | "class 
CustomItemList(SegmentationItemList): \n", 107 | " _label_cls = CustomLabel\n", 108 | " def get(self, i):\n", 109 | " return self.reconstruct(pilImage.open(self.items[i]).convert('RGB'))\n", 110 | " \n", 111 | " def reconstruct(self, t):\n", 112 | " return CustomItem(t)\n", 113 | "\n", 114 | " \n", 115 | "fonts = Fonts(Fonts.load(Path('../fonts')))\n", 116 | "\n", 117 | "items = list(map(lambda p: DANBOORU_PATH + folder(p), filter(lambda k: info[k] == 0, info.keys())))\n", 118 | "\n", 119 | "data = CustomItemList(items).split_none().label_const('a', classes=['text'])\n", 120 | "data.valid = getDatasets(allData)[0].valid\n", 121 | "\n", 122 | "data.train.transform([partial(textify, fonts=fonts), tensorize])\n", 123 | "\n", 124 | "data = data.databunch(bs=8, val_bs = 2, collate_fn = custom_collate).normalize(imagenet_stats)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "learn = unet_learner(data, models.resnet34, callback_fns=[MetricsCallback, CSVLogger, partial(SaveModelCallback, monitor = 'ignore global f1 score %')], model_dir='models', loss_func=MixedLoss(0, 1), path=EXPERIMENT_PATH)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "if not (EXPERIMENT_PATH / 'final model.pkl').exists():\n", 143 | " random_seed(42)\n", 144 | " learn.fit_one_cycle(5, 1e-4)\n", 145 | " learn.save('model')\n", 146 | " learn.export('final model.pkl') " 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "if not (EXPERIMENT_PATH / 'final predictions.csv').exists():\n", 156 | " learn = load_learner(EXPERIMENT_PATH, 'final model.pkl')\n", 157 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf = False)):\n", 158 | " random_seed(42)\n", 159 | " m = 
getDatasetMetrics(dataset, learn)\n", 160 | " m.save(EXPERIMENT_PATH / 'final predictions.csv', index > 0) " 161 | ] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.7.3" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | -------------------------------------------------------------------------------- /code/experiments/v1_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import fastai\n", 11 | "from experiments import *\n", 12 | "from fastai.vision import *\n", 13 | "from fastai.callbacks import *\n", 14 | "from losses import MixedLoss\n", 15 | "from dataset import *\n", 16 | "from transforms import *\n", 17 | "from config import *\n", 18 | "import glob\n", 19 | "from PIL import Image as pilImage\n", 20 | "from metrics import *\n", 21 | "\n", 22 | "%load_ext autoreload\n", 23 | "%autoreload 2\n", 24 | "\n", 25 | "torch.cuda.set_device(0)" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 4, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "EXPERIMENT_PATH = Path(EXPERIMENTS_PATH) / 'synthetic'\n", 35 | "\n", 36 | "def custom_loss(pred, truth):\n", 37 | " truth = truth.float()\n", 38 | " return F.binary_cross_entropy(pred, truth)\n", 39 | "\n", 40 | "def custom_collate(batch):\n", 41 | " if isinstance(batch[0][1], int):\n", 42 | " return torch.stack(list(map(lambda x: x[0].data, batch))), torch.stack(list(map(lambda x: 
tensor(x[1]), batch)))\n", 43 | " if hasattr(batch[0][0], \"x_tensor\"):\n", 44 | " return torch.stack(list(map(lambda x: x[0].x_tensor, batch))), torch.stack(list(map(getSegmentationMask, batch))).long()\n", 45 | " else:\n", 46 | " return torch.stack(list(map(lambda x: x[0].px, batch))), torch.stack(list(map(lambda x: x[1].px, batch))).long()\n", 47 | "\n", 48 | "def getSegmentationMask(dan):\n", 49 | " y, x = dan[0].y_tensor, dan[0].x_tensor\n", 50 | " res = ((y[0] == x[0]).int() + (y[1] == x[1]).int() + (x[2] == y[2]).int()) != 3\n", 51 | " res = res.unsqueeze(0)\n", 52 | " return res \n", 53 | " \n", 54 | "def folder(p):\n", 55 | " folder = (\"0000\" + p[-7:-4])[-4:]\n", 56 | " return '/' + folder + \"/\" + p \n", 57 | "\n", 58 | "\n", 59 | "class CustomItem(ItemBase):\n", 60 | " def __init__(self, image):\n", 61 | " self.image = image\n", 62 | " self.data = 0\n", 63 | " \n", 64 | " def __str__(self): return str(self.image)\n", 65 | " \n", 66 | " def apply_tfms(self, tfms, **kwargs):\n", 67 | " for tfm in tfms:\n", 68 | " tfm(self, **kwargs)\n", 69 | " return self \n", 70 | "\n", 71 | "class CustomLabel(SegmentationLabelList):\n", 72 | " def open(self, fn):\n", 73 | " return ImageSegment(torch.zeros(1, 64, 64)) \n", 74 | " \n", 75 | "class CustomItemList(SegmentationItemList): \n", 76 | " _label_cls = CustomLabel\n", 77 | " def get(self, i):\n", 78 | " return self.reconstruct(pilImage.open(self.items[i]).convert('RGB'))\n", 79 | " \n", 80 | " def reconstruct(self, t):\n", 81 | " return CustomItem(t)\n", 82 | "\n", 83 | "\n", 84 | "fonts = Fonts(Fonts.load(Path('../fonts')))\n", 85 | "with open(EXPERIMENT_PATH / 'text_info.json', 'r') as f:\n", 86 | " info = json.load(f)\n", 87 | "\n", 88 | "random_seed(42)\n", 89 | "allData = getData() \n", 90 | " \n", 91 | "items = list(map(lambda p: DANBOORU_PATH + folder(p), filter(lambda k: info[k] == 0, info.keys())))\n", 92 | "\n", 93 | "data = CustomItemList(items[0:10]).split_none().label_const('a', 
classes=['text'])\n", 94 | "\n", 95 | "data.valid = getDatasets(allData)[0].valid\n", 96 | "\n", 97 | "data.train.transform([partial(textify, fonts=fonts), tensorize])\n", 98 | "\n", 99 | "data = data.databunch(bs=8, val_bs = 2, collate_fn = custom_collate).normalize(imagenet_stats)\n", 100 | "\n", 101 | "learn = unet_learner(data, models.resnet18, metrics=[accuracy_thresh, partial(accuracy_thresh, thresh=0.95, sigmoid=False)], loss_func=custom_loss, y_range=(0,1))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 18, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "cntIndex = 0\n", 111 | "if cv2.__version__.startswith(\"3\"):\n", 112 | " cntIndex = 1\n", 113 | "\n", 114 | "def expand(mask, img):\n", 115 | " mask = mask.astype('uint8')\n", 116 | " gray = cv2.cvtColor(img.data.mul(255).permute(1,2,0).numpy().astype('uint8'),cv2.COLOR_RGB2GRAY)\n", 117 | " thresh = cv2.adaptiveThreshold(gray,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,15,30)\n", 118 | " cnts = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[cntIndex]\n", 119 | " im3 = np.zeros(thresh.shape, np.uint8)\n", 120 | "\n", 121 | " for c in cnts:\n", 122 | " x,y,w,h = cv2.boundingRect(c)\n", 123 | " thresh = cv2.adaptiveThreshold(gray[y:y+h, x:x+w],255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,15,30)\n", 124 | " ret, markers = cv2.connectedComponents(cv2.bitwise_not(thresh), connectivity=8)\n", 125 | " if ret < 10:\n", 126 | " for label in range(1,ret):\n", 127 | " m = markers == label\n", 128 | " if m.sum() > 3:\n", 129 | " if (m & mask[y:y+h, x:x+w] > 0).sum() > m.sum() * 0.1:\n", 130 | " im3[y:y+h, x:x+w][m] = 255\n", 131 | " \n", 132 | " return im3\n", 133 | "def removeNoise(mask):\n", 134 | " cnts = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[cntIndex]\n", 135 | " goods = [cv2.contourArea(c) >= 50 for c in cnts]\n", 136 | " rects = [cv2.boundingRect(c) for c in cnts]\n", 137 | " circles = 
[cv2.minEnclosingCircle(c) for c in cnts]\n", 138 | " banned = [False] * len(cnts)\n", 139 | " \n", 140 | " m = cv2.dilate(mask,(5, 5),iterations = 7)\n", 141 | " cc = cv2.findContours(m, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[cntIndex] \n", 142 | " rr = [cv2.boundingRect(c) for c in cc]\n", 143 | "\n", 144 | " for c, good, idx, rect in zip(cnts, goods, range(0, len(cnts)), rects):\n", 145 | " x,y,w,h = rect\n", 146 | " \n", 147 | " if max(w,h) / min(w,h) > 5:\n", 148 | " goods[idx] = False\n", 149 | " if max(w,h) / min(w,h) > 8:\n", 150 | " banned[idx] = True\n", 151 | " continue\n", 152 | " \n", 153 | " for r2, c2 in zip(rr, cc):\n", 154 | " x2, y2, w2, h2 = r2\n", 155 | " if cv2.contourArea(c2) >= 50 and x >= x2 and x + w <= x2 + w2 and y >= y2 and y + h <= y2 + h2 and (mask[y2:y2+h2, x2:x2+w2] > 0).sum() > len(c2) * 0.5:\n", 156 | " goods[idx] = True \n", 157 | "\n", 158 | " \n", 159 | " changed = True\n", 160 | " while changed:\n", 161 | " changed = False\n", 162 | " for c, good, idx, rect in zip(cnts, goods, range(0, len(cnts)), rects):\n", 163 | " if banned[idx]:\n", 164 | " continue\n", 165 | " \n", 166 | " x,y,w,h = rect\n", 167 | " x, y = x + w / 2, y + h / 2 \n", 168 | " if not good:\n", 169 | " for a in range(max(idx - 50, 0), len(cnts)):\n", 170 | " if a != idx and goods[a]:\n", 171 | " x2, y2, w2, h2 = rects[a]\n", 172 | " x2, y2, = x2 + w2 / 2, y2 + h2 / 2 \n", 173 | " \n", 174 | " if abs(y2 - y) > 100 + h:\n", 175 | " break\n", 176 | " \n", 177 | " if abs (cv2.contourArea(cnts[idx]) - circles[idx][1]**2) > 20 and abs(y2 - y) < (h + h2) / 2 + 20 and abs(x2 - x) < (w + w2) / 2 + 20:\n", 178 | " good = goods[idx] = True\n", 179 | " changed = True\n", 180 | " break\n", 181 | "\n", 182 | " \n", 183 | "\n", 184 | " for c, good, idx, rect in zip(cnts, goods, range(0, len(cnts)), rects): \n", 185 | " if not good:\n", 186 | " cv2.drawContours(mask, [c], 0, (0, 0, 0), -1)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 5, 192 | 
"metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "learn.load(EXPERIMENT_PATH / 'models' / 'v1_2');" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 19, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "if not (EXPERIMENT_PATH / 'v1_2 predictions.csv').exists() or True:\n", 205 | " for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf = False)):\n", 206 | " random_seed(42)\n", 207 | " m = MetricsCallback(None)\n", 208 | " m.on_train_begin()\n", 209 | " for idx in range(len(dataset.valid.x.items)):\n", 210 | " x = dataset.valid.x.get(idx, False)\n", 211 | " y = learn.predict(x)[2] > 0.95\n", 212 | " y = y.permute(1,2,0).numpy() * 255\n", 213 | " y = expand(y[:,:,0], x)\n", 214 | " removeNoise(y)\n", 215 | " y = tensor(y).unsqueeze(0).div_(255).bool()\n", 216 | " m.on_batch_end(False, y, dataset.valid.y.get(idx, False).px)\n", 217 | " m.calculateMetrics() \n", 218 | " m.save(EXPERIMENT_PATH / 'v1_2 predictions.csv', index > 0) " 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.7.3" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 4 243 | } 244 | -------------------------------------------------------------------------------- /code/losses.py: -------------------------------------------------------------------------------- 1 | from fastai.callbacks import hook_outputs 2 | import torch.nn.functional as F 3 | from torch import nn,tensor 4 | from ssim import SSIM 5 | from numpy import log10 6 | import torch 7 | 8 | def gram_matrix(x): 9 | n,c,h,w = x.size() 10 | x = x.view(n, c, -1) 11 | 
return (x @ x.transpose(1,2))/(c*h*w) 12 | from functools import reduce 13 | 14 | class FeatureLoss(nn.Module): 15 | def __init__(self, model, layer_ids, layer_wgts, base_loss=F.l1_loss): 16 | super().__init__() 17 | self.model = nn.Sequential(*list(model.children())[:layer_ids[-1] + 1]) 18 | self.loss_features = [self.model[i] for i in layer_ids] 19 | self.hooks = hook_outputs(self.loss_features, detach=False) 20 | self.wgts = layer_wgts 21 | self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids)) 22 | ] + [f'gram_{i}' for i in range(len(layer_ids))] + ['PSNR', 'SSIM', 'pEPs'] 23 | self.base_loss = base_loss 24 | self.mse = nn.MSELoss() 25 | self.ssim = SSIM(window_size=11) 26 | 27 | def make_features(self, x, clone=False): 28 | self.model(x.cuda()) 29 | return [(o.clone() if clone else o) for o in self.hooks.stored] 30 | 31 | def forward(self, input, target, *args, reduction = 'sum', **kwargs): 32 | #print(input.shape, target.shape) 33 | psnr = 10 * log10(1 / self.mse(input, target)) 34 | ssim = self.ssim(input, target) 35 | pEPs = (input - target).mul(255).abs().le(0.1).sum(dim=1).eq(3).sum().div(input.numel()/input.shape[1]).mul(100) 36 | 37 | out_feat = self.make_features(target, clone=True) 38 | in_feat = self.make_features(input) 39 | 40 | feat_losses = [self.base_loss(input,target)] 41 | feat_losses += [self.base_loss(f_in, f_out)*w 42 | for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)] 43 | feat_losses += [self.base_loss(gram_matrix(f_in), gram_matrix(f_out))*w* 5e3 44 | for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)] 45 | 46 | self.metrics = dict(zip(self.metric_names, feat_losses + [psnr, ssim, pEPs])) 47 | 48 | if reduction == 'mean': 49 | return torch.mean(torch.stack(feat_losses)) 50 | elif reduction == 'sum': 51 | return sum(feat_losses) 52 | else: 53 | raise ValueError(f"unknown reduction: {reduction!r}") 54 | 55 | def __del__(self): self.hooks.remove() 56 | 57 | 58 | #https://www.kaggle.com/iafoss/unet34-dice-0-87 59 | def dice_loss(input, target): # returns the soft Dice coefficient (higher is better), not a loss; MixedLoss minimizes -log of it 
60 | input = torch.sigmoid(input) 61 | smooth = 1.0 62 | 63 | iflat = input.view(-1).float() 64 | tflat = target.view(-1).float() 65 | intersection = (iflat * tflat).sum() 66 | 67 | return ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)) 68 | 69 | 70 | class FocalLoss(nn.Module): 71 | def __init__(self, gamma, average = True): 72 | super().__init__() 73 | self.gamma = gamma 74 | self.average = average 75 | 76 | def forward(self, input, target): 77 | if not (target.size() == input.size()): 78 | raise ValueError("Target size ({}) must be the same as input size ({})" 79 | .format(target.size(), input.size())) 80 | 81 | max_val = (-input).clamp(min=0) 82 | loss = input - input * target + max_val + \ 83 | ((-max_val).exp() + (-input - max_val).exp()).log() 84 | 85 | invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0)) 86 | loss = (invprobs * self.gamma).exp() * loss 87 | 88 | if self.average: 89 | return loss.mean() 90 | else: 91 | return loss.sum() 92 | 93 | class MixedLoss(nn.Module): 94 | def __init__(self, alpha, gamma, normalize = True): 95 | super().__init__() 96 | self.alpha = alpha 97 | self.focal = FocalLoss(gamma) 98 | self.normalize = normalize 99 | 100 | def forward(self, input, target): 101 | if self.normalize: 102 | target = target.clone() 103 | target[target >= 3] -= 3 104 | target = target != 0 105 | loss = self.alpha*self.focal(input, target.float()) - torch.log(dice_loss(input, target.float())) 106 | return loss.mean() 107 | 108 | class BCELoss(nn.Module): 109 | def __init__(self, text_weight = 1, normalize = True): 110 | super().__init__() 111 | self.normalize = normalize 112 | self.text_weight = text_weight 113 | 114 | def forward(self, input, target): 115 | if self.normalize: 116 | target = target.clone() 117 | target[target >= 3] -= 3 118 | target = target != 0 119 | 120 | weight = tensor([self.text_weight]).to(input.device) 121 | 122 | return torch.nn.BCEWithLogitsLoss(pos_weight = weight)(input, target.float()) 
-------------------------------------------------------------------------------- /code/manga109.py: -------------------------------------------------------------------------------- 1 | import manga109api 2 | from fastai.vision import ImageImageList 3 | import os 4 | import json 5 | from fastai.core import listify 6 | from fastai.basics import ItemBase 7 | from fastai.vision import Image, ImageImageList 8 | from PIL import Image, ImageDraw 9 | from IPython.display import display 10 | 11 | def draw_rectangle(img, x0, y0, x1, y1, annotation_type): 12 | assert annotation_type in ["body", "face", "frame", "text"] 13 | color = {"body": "#258039", "face": "#f5be41", 14 | "frame": "#31a9b8", "text": "#cf3721"}[annotation_type] 15 | width = 1 16 | draw = ImageDraw.Draw(img) 17 | draw.line([x0 - width/2, y0, x1 + width/2, y0], fill=color, width=width) 18 | draw.line([x1, y0, x1, y1], fill=color, width=width) 19 | draw.line([x1 + width/2, y1, x0 - width/2, y1], fill=color, width=width) 20 | draw.line([x0, y1, x0, y0], fill=color, width=width) 21 | 22 | class MangaImage(ItemBase): 23 | def __init__(self, image, info, idx): 24 | self.image = image 25 | self.data = idx 26 | self.info = info 27 | 28 | def __str__(self): return str(self.image) 29 | 30 | def apply_tfms(self, tfms, **kwargs): 31 | for tfm in tfms: 32 | tfm(self, **kwargs) 33 | 34 | return self 35 | 36 | def show(self): 37 | img = Image.open(self.info['path']) 38 | for text in self.info['text']: 39 | draw_rectangle(img, text['xmin'], text['ymin'], text['xmax'], text['ymax'], 'text') 40 | 41 | display(img) 42 | 43 | 44 | class Manga109ImageList(ImageImageList): 45 | @classmethod 46 | def load(cls, path, **kwargs): 47 | os.makedirs('cache/manga', exist_ok=True) 48 | 49 | try: 50 | with open('cache/manga/data.json', 'r') as f: 51 | data = json.load(f) 52 | except: 53 | data = manga109api.Parser(root_dir=str(path)) 54 | data = {'books': data.books, 'annotations': data.annotations} 55 | 56 | with 
open('cache/manga/data.json', 'w') as f: 57 | json.dump(data, f) 58 | 59 | try: 60 | with open('cache/manga/pages.json', 'r') as f: 61 | pages = json.load(f) 62 | except: 63 | pages = [] 64 | 65 | for book in data['books']: 66 | for page in data['annotations'][book]['book']['pages']['page']: 67 | pagedata = {'text': [], 'path': str(path / 'images' / book / (str(page['@index']).zfill(3) + ".jpg"))} 68 | if 'text' in page: 69 | for txt in page['text'] if isinstance(page['text'], list) else [page['text']]: 70 | pagedata['text'].append({'xmin': txt['@xmin'], 'xmax': txt['@xmax'], 'ymin': txt['@ymin'], 'ymax': txt['@ymax'], 'text': txt['#text']}) 71 | pages.append(pagedata) 72 | 73 | with open('cache/manga/pages.json', 'w') as f: 74 | json.dump(pages, f) 75 | 76 | return Manga109ImageList(list(filter(lambda x: len(x['text']) > 0, pages)), **kwargs) 77 | 78 | def get(self, i): 79 | info = self.items[i] 80 | image = self.open(info['path']) 81 | return MangaImage(image, info, i) 82 | 83 | def show_xyzs(self, xs, ys, zs, logger=False, **kwargs): 84 | if logger: 85 | logger.show_xyzs(xs, ys, zs, **kwargs) 86 | else: 87 | return super().show_xyzs(xs, ys, zs, **kwargs) -------------------------------------------------------------------------------- /code/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from fastai.vision import * 3 | from skimage.measure import label, regionprops 4 | from skimage.morphology import watershed, square 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from scipy import ndimage as nd 8 | from itertools import cycle 9 | 10 | class NullRecorder(): 11 | def record(self, *args): 12 | pass 13 | 14 | class MetricsCallback(LearnerCallback): 15 | _order=-20 # Needs to run before the recorder 16 | def __init__(self, learn, thresh = 0.5): 17 | self.init = learn is not None 18 | 19 | if self.init: 20 | super().__init__(learn) 21 | 22 | self.thresh = thresh 23 
| 24 | self.last_metrics = dict() 25 | 26 | self.datasets = ["easy", "hard"] 27 | self.modes = ["normal", "ignore"] 28 | self.start, self.end, self.step = 0, 1, 0.1 29 | self.bins = np.arange(self.start, self.end, self.step) 30 | 31 | def on_train_begin(self, **kwargs): 32 | self.clearMetrics() 33 | self.calculateMetrics() 34 | if self.init: 35 | self.learn.recorder.add_metric_names(self.last_metrics.keys()) 36 | 37 | def clearMetrics(self): 38 | self.metrics = defaultdict(float) 39 | 40 | for dataset in self.datasets: 41 | self.metrics[dataset] = defaultdict(float) 42 | for mode in self.modes: 43 | self.metrics[dataset][mode] = defaultdict(float) 44 | 45 | def bin(self, value): 46 | return min(int((value - self.start) / self.step), len(self.bins) - 1) 47 | 48 | def on_batch_end(self, train, last_output, last_target, **kwargs): 49 | if train: return 50 | 51 | assert last_output.shape == last_target.shape, f'output shape is {last_output.shape} while target is {last_target.shape}' 52 | 53 | #testing against an image 54 | if len(last_output.shape) == 3: 55 | last_output = last_output.unsqueeze(0) 56 | last_target = last_target.unsqueeze(0) 57 | 58 | pred = last_output if last_output.min() >= 0 else torch.sigmoid(last_output) > self.thresh 59 | 60 | expanded = last_target.clone().detach() 61 | original = expanded.clone() 62 | original[last_target >= 3] -= 3 63 | 64 | eroded = expanded.clone() 65 | eroded[last_target >= 3] = 0 66 | 67 | expanded, original, eroded = expanded.squeeze(1).cpu(), original.squeeze(1).cpu(), eroded.squeeze(1).cpu() 68 | pred = pred.squeeze(1).cpu() 69 | 70 | for image in range(0, last_target.shape[0]): 71 | text_mask = original[image] != 0 72 | 73 | self.metrics["drd_k sum"] += drd(pred[image].unsqueeze(0).unsqueeze(0), text_mask.unsqueeze(0).unsqueeze(0)).sum().item() 74 | self.metrics["#subn"] += subn(text_mask.numpy()) 75 | self.metrics["psnr"] += psnr(pred[image], text_mask) 76 | self.metrics["drd"] += self.metrics["drd_k sum"] / 
max(self.metrics["#subn"], 1) 77 | self.metrics["#images"] += 1 78 | 79 | #label text mask = ground truth 80 | labels = label(original[image].numpy(), connectivity=2) 81 | regions = regionprops(labels) 82 | 83 | regionsDict = {region.label: region for region in regions} 84 | 85 | for region in regionsDict.values(): 86 | region.area # forces the area to be computed; regionprops seems to evaluate lazily and reports 0 later otherwise 87 | region.dataset = "hard" if original[image][region.coords[0][0], region.coords[0][1]] == 2 else "easy" 88 | 89 | addIntersectArea(regionsDict, labels.copy(), pred[image], "intersectArea") 90 | #watershedding to match expanded pixels to a single GT 91 | expansion = addWatershed(regionsDict, labels, expanded[image] != 0, "expandedArea") 92 | #watershedding to match prediction pixels to a single GT 93 | predictionWater = addWatershed(regionsDict, labels, pred[image], "predictionArea") 94 | expansion[expansion != predictionWater] = 0 95 | addRegionInfo(regionsDict, expansion, "expandedIntersectArea") 96 | 97 | labelsClone = labels.copy() 98 | addIntersectArea(regionsDict, labelsClone, eroded[image], "erodedArea") 99 | addIntersectArea(regionsDict, labelsClone, pred[image], "erodedIntersectArea") 100 | 101 | intersections = 0 102 | 103 | for region in list(regionsDict.values()): 104 | if not hasattr(region, "erodedIntersectArea"): 105 | region.erodedIntersectArea = 0 106 | region.intersectArea = 0 107 | region.expandedIntersectArea = 0 108 | 109 | #there is an edge case where there can be no eroded area intersection because image was cut in half, ignore. 
We calculate exact metrics with full resolution later 110 | if not hasattr(region, "erodedArea"): 111 | del regionsDict[region.label] 112 | continue 113 | 114 | assert(region.expandedArea >= region.area) 115 | 116 | erodedCoverage = region.erodedIntersectArea / region.erodedArea 117 | coverage = region.intersectArea / region.area 118 | 119 | self.metrics[region.dataset]["normal"]["#tp expanded"] += region.intersectArea 120 | self.metrics[region.dataset]["normal"]["#tp eroded"] += region.intersectArea 121 | 122 | self.metrics[region.dataset]["ignore"]["#tp expanded"] += region.expandedIntersectArea 123 | self.metrics[region.dataset]["ignore"]["#tp eroded"] += region.erodedIntersectArea 124 | 125 | self.metrics[region.dataset]["normal"]["#fn"] += region.area - region.intersectArea 126 | self.metrics[region.dataset]["ignore"]["#fn"] += region.erodedArea - region.erodedIntersectArea 127 | 128 | self.metrics[region.dataset]["#truth"] += 1 129 | 130 | if region.intersectArea > 0: 131 | assert(region.expandedIntersectArea <= region.predictionArea) 132 | 133 | accuracy = region.intersectArea / region.predictionArea 134 | expandedAccuracy = region.expandedIntersectArea / region.predictionArea 135 | 136 | self.metrics[region.dataset]["#intersections"] += 1 137 | intersections += 1 138 | 139 | self.metrics[region.dataset]["ignore"]["#fp"] += region.predictionArea - region.expandedIntersectArea 140 | self.metrics[region.dataset]["normal"]["#fp"] += region.predictionArea - region.intersectArea 141 | else: 142 | accuracy = 0 143 | expandedAccuracy = 0 144 | 145 | f1 = coverage * accuracy * 2 / max((coverage + accuracy), 1e-6) 146 | f1Ignore = erodedCoverage * expandedAccuracy * 2 / max((erodedCoverage + expandedAccuracy), 1e-6) 147 | 148 | for mode, word, metric in zip(cycle(["normal", "ignore"]),["accuracy", "accuracy", "coverage", "coverage", "f1", "f1"], [accuracy, expandedAccuracy, coverage, erodedCoverage, f1, f1Ignore]): 149 | self.metrics[region.dataset][mode][f'#{word} 
{self.bin(metric)}'] += 1 150 | self.metrics[region.dataset][mode][word] += metric 151 | 152 | self.metrics["#truth"] += len(regionsDict) 153 | 154 | #get prediction connected components that did not intersect with text 155 | pred[image][predictionWater != 0] = 0 156 | labels = label(pred[image]) 157 | regions = regionprops(labels) 158 | 159 | for region in regions: 160 | self.metrics["#fp no intersect"] += region.area 161 | 162 | self.metrics["#pred"] += intersections + len(regions) 163 | 164 | 165 | def on_epoch_begin(self, **kwargs): 166 | self.clearMetrics() 167 | 168 | def addHistogramMetrics(self, k, metrics, metric): 169 | for idx, bin in enumerate(self.bins): 170 | key = f'#{k} {str(round(bin, 2))}-{str(round(bin + self.step, 2))}' 171 | metrics[key] += metric[f"#{k} {idx}"] 172 | 173 | def calculateMetrics(self): 174 | 175 | self.last_metrics["#truth"] = self.metrics["#truth"] 176 | self.last_metrics["#pred"] = self.metrics["#pred"] 177 | self.last_metrics["#fp no intersect"] = self.metrics["#fp no intersect"] 178 | self.last_metrics["#intersections"] = (self.metrics["easy"]["#intersections"] + self.metrics["hard"]["#intersections"]) 179 | 180 | self.last_metrics["#subn"] = self.metrics["#subn"] 181 | self.last_metrics["drd_k sum"] = self.metrics["drd_k sum"] 182 | self.last_metrics["drd"] = self.last_metrics["drd_k sum"] / max(self.metrics["#subn"], 1e-6) 183 | self.last_metrics["average drd"] = self.metrics["drd"] / max(self.metrics["#images"], 1e-6) 184 | self.last_metrics["psnr"] = self.metrics["psnr"] / max(self.metrics["#images"], 1e-6) 185 | 186 | self.last_metrics["global precision quantity %"] = self.last_metrics["#intersections"] / max(self.metrics["#pred"], 1e-6) 187 | self.last_metrics["global recall quantity %"] = self.last_metrics["#intersections"] / max(self.metrics["#truth"], 1e-6) 188 | 189 | for dataset in self.datasets: 190 | for metric in ["#intersections", "#truth"]: 191 | self.last_metrics[dataset + " " + metric] = 
self.metrics[dataset][metric] 192 | 193 | for mode in self.modes: 194 | modeMetrics = defaultdict(float) 195 | 196 | for dataset in self.datasets: 197 | metrics = defaultdict(float) 198 | modeMetrics[dataset] = defaultdict(float) 199 | self.addHistogramMetrics("coverage", metrics, self.metrics[dataset][mode]) 200 | self.addHistogramMetrics("accuracy", metrics, self.metrics[dataset][mode]) 201 | self.addHistogramMetrics("f1", metrics, self.metrics[dataset][mode]) 202 | 203 | for metric in ["#fn", "#fp", "#tp expanded", "#tp eroded"]: 204 | metrics[metric] = self.metrics[dataset][mode][metric] 205 | modeMetrics[metric] += metrics[metric] 206 | 207 | metrics["global recall quantity %"] = self.last_metrics[dataset + " #intersections"] / max(self.metrics[dataset]["#truth"], 1e-6) 208 | 209 | metrics["global quality precision %"] = self.metrics[dataset][mode]["accuracy"] / max(self.last_metrics[dataset + " #intersections"], 1e-6) 210 | metrics["global quality recall %"] = self.metrics[dataset][mode]["coverage"] / max(self.last_metrics[dataset + " #intersections"], 1e-6) 211 | 212 | GP = metrics["global precision %"] = self.last_metrics["global precision quantity %"] * metrics["global quality precision %"] 213 | GR = metrics["global recall %"] = metrics["global recall quantity %"] * metrics["global quality recall %"] 214 | 215 | metrics["global f1 %"] = 2 * GR * GP / max((GP + GR), 1e-6) 216 | 217 | for key in metrics.keys(): 218 | modeMetrics[dataset][key] = metrics[key] 219 | 220 | 221 | acc, cov = [sum(self.metrics[dataset][mode][metric] for dataset in self.datasets) for metric in ["accuracy", "coverage"]] 222 | modeMetrics["global quality precision %"] = acc / max(self.last_metrics["#intersections"], 1e-6) 223 | modeMetrics["global quality recall %"] = cov / max(self.last_metrics["#intersections"], 1e-6) 224 | 225 | GP = modeMetrics["global precision %"] = self.last_metrics["global precision quantity %"] * modeMetrics["global quality precision %"] 226 | GR = 
modeMetrics["global recall %"] = self.last_metrics["global recall quantity %"] * modeMetrics["global quality recall %"] 227 | 228 | modeMetrics["global f1 score %"] = 2 * GR * GP / max((GP + GR), 1e-6) 229 | 230 | modeMetrics["#fp"] += self.metrics["#fp no intersect"] 231 | P = modeMetrics["pixel precision %"] = modeMetrics["#tp expanded"] / max(modeMetrics["#tp expanded"] + modeMetrics["#fp"], 1e-6) 232 | R = modeMetrics["pixel recall %"] = modeMetrics["#tp eroded"] / max(modeMetrics["#tp eroded"] + modeMetrics["#fn"], 1e-6) 233 | modeMetrics["pixel f1 %"] = 2 * P * R / max(P + R, 1e-6) 234 | 235 | 236 | for key in modeMetrics.keys(): 237 | if isinstance(modeMetrics[key], float): 238 | self.last_metrics[mode + " " + key] = modeMetrics[key] 239 | else: 240 | for k in modeMetrics[key]: 241 | self.last_metrics[mode + " " + key + " " + k] = modeMetrics[key][k] 242 | 243 | for key in self.last_metrics.keys(): 244 | self.last_metrics[key] = self.last_metrics[key] * (100 if "%" in key else 1) 245 | self.last_metrics[key] = int(self.last_metrics[key]) if "#" in key else round(self.last_metrics[key], 2) 246 | 247 | 248 | def on_epoch_end(self, last_metrics, **kwargs): 249 | self.calculateMetrics() 250 | if self.init: 251 | return add_metrics(last_metrics, self.last_metrics.values()) 252 | 253 | 254 | def save(self, path, append=False): 255 | path = Path(path) 256 | path.parent.mkdir(parents=True, exist_ok=True) 257 | mode = "a" if append else "w" 258 | exists = path.exists() 259 | with path.open(mode) as f: 260 | if mode == "w" or not exists: 261 | f.write(','.join(self.last_metrics.keys())) 262 | f.write('\n' + ','.join([str(stat) if isinstance(stat, int) else '#na#' if stat is None else f'{stat:.6f}' for stat in self.last_metrics.values()])) 263 | 264 | 265 | def subplots(self, xlabel, ylabel): 266 | params = {'axes.labelsize': 16, 267 | 'axes.titlesize': 16} 268 | plt.rcParams.update(params) 269 | 270 | rows = (len(self.datasets) * len(self.modes) + 1) // 2 271 | fig, 
axes = plt.subplots(nrows=rows, ncols=2, figsize=(14, 12)) 272 | axes = axes.flatten() 273 | indexes = np.arange(len(self.bins)) 274 | 275 | for idx, ax in enumerate(axes): 276 | if idx % 2 == 0 and ylabel is not None: 277 | ax.set(ylabel=ylabel) 278 | if idx // 2 == rows - 1 and xlabel is not None: 279 | ax.set(xlabel=xlabel) 280 | ax.set(ylim = [0, 1], yticks = np.arange(0, 1.1, 0.1), yticklabels = [round(bin, 2) for bin in np.arange(0, 1.1, 0.1)]) 281 | ax.set(xticks=indexes - 0.2, xticklabels=[round(bin, 2) for bin in self.bins]) 282 | ax.xaxis.set_tick_params(labelsize=16) 283 | ax.yaxis.set_tick_params(labelsize=16) 284 | ax.margins(x=0) 285 | 286 | return fig, axes.flatten() 287 | 288 | def data(self, data): 289 | l = 1 290 | 291 | if isinstance(data, list): 292 | data = pd.concat(data) 293 | 294 | if isinstance(data, pd.DataFrame): 295 | l = len(data.index) 296 | data = data.sum().to_dict() 297 | elif isinstance(data, pd.core.groupby.DataFrameGroupBy): 298 | index = data.mean()["ignore global f1 score %"].idxmax() 299 | l = int(data.size()[index]) 300 | data = data.sum().iloc[index] 301 | 302 | return data, l 303 | 304 | def showHistograms(self, configs, metrics, xlabel="Quality Interval", ylabel="Percentage of connected components"): 305 | totalBars = len(configs * len(metrics)) 306 | fig, axes = self.subplots(xlabel, ylabel) 307 | axIdx = 0 308 | indexes = np.arange(len(self.bins)) 309 | patterns = ['/', '*', '-', '+', 'x', '\\', 'o', 'O', '.'][0:totalBars] 310 | 311 | for dataset in self.datasets: 312 | for mode in self.modes: 313 | ax = axes[axIdx] 314 | i = 0 315 | for config in configs: 316 | data, _ = self.data(config['data']) 317 | for metric, label in zip(metrics, config['labels']): 318 | vals = [] 319 | 320 | for idx, bin in enumerate(self.bins): 321 | key = f'{str(round(bin, 2))}-{str(round(bin + self.step, 2))}' 322 | vals.append(data[f"{mode} {dataset} #{metric} {key}"]) 323 | 324 | ax.bar(indexes + i / totalBars, np.array(vals) / sum(vals), 
label = label , width = 1 / totalBars, hatch = patterns[i % len(patterns)]) 325 | i += 1 326 | 327 | ax.set(title = dataset + ' - ' + mode.replace("ignore", "relaxed") + " " + f'({int(sum(vals))} connected components)') 328 | 329 | ax.legend(prop={'size': 16}) 330 | axIdx += 1 331 | 332 | fig.tight_layout() 333 | plt.show() 334 | 335 | def getImageMetrics(pred, truth): 336 | m = MetricsCallback(None) 337 | m.on_train_begin() 338 | m.on_batch_end(False, pred, truth) 339 | m.calculateMetrics() 340 | return m 341 | 342 | def getDatasetMetrics(dataset, learn): 343 | m = MetricsCallback(None) 344 | m.on_train_begin() 345 | for idx in range(len(dataset.valid.x.items)): 346 | pred = learn.predict(dataset.valid.x.get(idx, False))[2] 347 | m.on_batch_end(False, pred, dataset.valid.y.get(idx, False).px) 348 | m.calculateMetrics() 349 | return m 350 | 351 | #modifies labels! 352 | def addIntersectArea(regionsDict, labels, mask, prop): 353 | labels[mask == 0] = 0 354 | #Image(tensor(labels).unsqueeze(0)).show(figsize=(15,15), title=prop) 355 | addRegionInfo(regionsDict, labels, prop) 356 | 357 | def addWatershed(regionsDict, labels, mask, prop): 358 | distance = nd.distance_transform_edt(mask) 359 | water = watershed(-distance, labels, mask=mask, connectivity=square(3)) 360 | addRegionInfo(regionsDict, water, prop) 361 | return water 362 | 363 | def addRegionInfo(regionsDict, labels, prop): 364 | regions = regionprops(labels) 365 | 366 | for region in regions: 367 | setattr(regionsDict[region.label], prop, region.area) 368 | 369 | def printStats(data, name = "", stats = None, cuts = [], pm = True): 370 | try: 371 | index = data.mean()["ignore global f1 score %"].idxmax() 372 | mean, std = data.mean().iloc[index], data.std().iloc[index] 373 | except: 374 | index = 0 375 | mean, std = data.mean(), data.std() 376 | 377 | if stats is None: 378 | stats = ["normal pixel f1 %", "normal global f1 score %", "ignore pixel f1 %", "ignore global f1 score %", "normal global quality 
precision %", "normal global quality recall %", "ignore global quality precision %", "ignore global quality recall %", "global precision quantity %", "global recall quantity %"] 379 | cuts = [3, 7] 380 | 381 | output = "" 382 | 383 | if name: 384 | print(name, index) 385 | 386 | for idx, stat in enumerate(stats): 387 | if not math.isnan(std[stat]) and pm: 388 | output += f" & ${'%.2f' % mean[stat]} \\pm {'%.1f' % std[stat]}$" 389 | else: 390 | output += f" & ${'%.2f' % mean[stat]}$" 391 | if idx in cuts or idx + 1 == len(stats): 392 | print(output + " \\\\") 393 | output = "" 394 | 395 | #drd metric 396 | W = np.hypot(np.arange(-2, 3), np.arange(-2, 3)[:, None]) 397 | np.reciprocal(W, where=W.astype(bool), out=W) 398 | W /= W.sum() 399 | W = tensor(W).unsqueeze(0).unsqueeze(0).float() 400 | 401 | def drd(im, im_gt): 402 | m1 = (im == 1) & (im_gt == 0) 403 | m2 = (im == 0) & (im_gt == 1) 404 | conv1 = F.conv2d((im_gt == 0).float(), W, bias=None, padding=2, stride=(1, 1)) 405 | conv2 = F.conv2d((im_gt == 1).float(), W, bias=None, padding=2, stride=(1, 1)) 406 | return (conv1 * m1 + conv2 * m2).sum(axis=(1,2,3)) 407 | 408 | def subn(im_gt): 409 | height, width = im_gt.shape 410 | block_size = 8 411 | 412 | y = im_gt[0:(height//block_size)*block_size, 0:(width//block_size)*block_size] 413 | sums = y.reshape(y.shape[0]//block_size, block_size, y.shape[1]//block_size, block_size).sum(axis=(1, 3)) 414 | return np.sum((sums != 0) & (sums != (block_size ** 2))) 415 | 416 | #psnr metric 417 | def psnr(im, im_gt): 418 | mse = (im != im_gt).float().mean() 419 | return 10 * math.log10(1. 
/ mse) if mse > 0 else 100 -------------------------------------------------------------------------------- /code/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window unless the channel count or tensor type changed
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel


        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
/code/transforms.py:
--------------------------------------------------------------------------------
from fastai.vision import pil2tensor

from PIL import ImageFont, ImageDraw, Image
from RectangleGenerator import *
from TextGenerator import TextGenerator, Fonts
import numpy as np
from pathlib import Path
import numpy.random as random
from fastai.vision import imagenet_stats, normalize, vision, image2np
import torch
import cv2


def randint(a, b, *args):
    # Inclusive upper bound, unlike numpy.random.randint
    return random.randint(a, b + 1, *args)

def add_spaces(s):
    # Replace about 10% of the characters with spaces (strings shorter than 3 are left alone)
    if len(s) < 3:
        return s
    idxs = set(random.choice(range(0, len(s)), len(s)//10 + 1, False))
    return "".join(map(lambda t: t[1] if t[0] not in idxs else " ", enumerate(s)))

def textify(danImage, fonts):
    # Draw synthetic text on x; y receives the same background boxes but no text
    image = danImage.image

    w, h = image.size
    pil_img_x = image.copy()
    pil_img_y = image.copy()
    draw_x = ImageDraw.Draw(pil_img_x, 'RGBA')
    draw_y = ImageDraw.Draw(pil_img_y, 'RGBA')

    padding = randint(4, 10)

    if random.random_sample() < 0.5:
        # one large rectangle covering most of the image
        x, y = randint(0, w // 6), randint(0, h // 6)
        rects = [Rectangle(x, y, randint(w - 2 * x, w - x), randint(h - 2 * y, h - y))]
    else:
        rects = danImage.rects = RectangleGenerator.generate(w, h, randint(7, 15))

    for rect in rects:

        # mostly small text, occasionally very large
        if random.random_sample() < 0.8:
            text_size = randint(8, 20)
        else:
            text_size = randint(20, min(w, h) * 7 // 10)

        expected_to_fit = rect.area() // (text_size ** 2)

        font = fonts.randomFont()
        text = font.generateText(randint(expected_to_fit // 2, expected_to_fit + 1))

        sizedFont = font.getFont(text_size)

        lines = TextGenerator.text_wrap(text, sizedFont, rect.width - padding * 2, rect.height - padding * 2)
        #lines = list(map(add_spaces, lines))

        border_color = randint(240,255), randint(240,255), randint(240,255)

        if random.random_sample() < 0.8:
            # dark text
            text_color = (randint(0,60), randint(0,60), randint(0,60))
        else:
            if random.random_sample() < 0.5:
                # any color
                text_color = (randint(0,255), randint(0,255), randint(0,255))
            else:
                # light text with a dark border
                text_color = (randint(200,255), randint(200,255), randint(200,255))
                border_color = (randint(0,10), randint(0,10), randint(0,10))

        rotate = random.random_sample() < 0.3
        box = random.random_sample() < 0.1
        angle = randint(0,255)
        alpha = randint(0,200) if random.random_sample() < 0.1 else 255
        border = random.random_sample() < 0.05

        x, y = rect.x + padding, rect.y + padding

        if box:
            # white (possibly translucent) background box, pasted on both x and y
            mask = im = Image.new('RGBA', (rect.width, rect.height), (255, 255, 255, alpha))

            if random.random_sample() < 0.5: # make circle shape
                bigsize = (im.size[0] * 3, im.size[1] * 3)
                mask = Image.new('L', bigsize, 0)
                draw = ImageDraw.Draw(mask)
                draw.ellipse((0, 0) + bigsize, fill=255)
                # LANCZOS is the filter formerly aliased as ANTIALIAS (alias removed in Pillow 10)
                mask = mask.resize(im.size, Image.LANCZOS)
                im.putalpha(mask)

            if rotate:
                mask = rotate_image(im, angle)

            pil_img_x.paste(im, (rect.x, rect.y), mask)
            pil_img_y.paste(im, (rect.x, rect.y), mask)

        # text is drawn on x only
        if rotate:
            im = Image.new('RGBA', (rect.width, rect.height), (255, 255, 255, 0))
            draw_rotated_text(im, angle, "\n".join(lines), text_color, font=sizedFont)
            pil_img_x.paste(im, (rect.x, rect.y), im)
        else:
            if border:
                draw_border(draw_x, (x, y), "\n".join(lines), border_color, sizedFont)
            draw_x.multiline_text((x, y), "\n".join(lines), fill=text_color, font=sizedFont)


    if random.random_sample() < 0.2:
        pil_img_x = pil_img_x.convert('L').convert('RGB')
        pil_img_y = pil_img_y.convert('L').convert('RGB')


    danImage.pil_img_x = pil_img_x
    danImage.pil_img_y = pil_img_y

    danImage.x = pil2tensor(danImage.pil_img_x,np.float32).div_(255)

    return danImage

def tensorize(danImage):
    danImage.x_tensor = pil2tensor(danImage.pil_img_x,np.float32).div_(255)
    danImage.y_tensor = pil2tensor(danImage.pil_img_y,np.float32).div_(255)

    return danImage

def binarize(danImage):
    # Adaptive Gaussian threshold, inverted so dark ink becomes 255

    gray = cv2.cvtColor(np.array(danImage.pil_img_x), cv2.COLOR_RGB2GRAY)
    blur = gray #cv2.GaussianBlur(gray, (5,5), 0)
    thresh = cv2.adaptiveThreshold(blur,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV,13,10)
    danImage.pil_img_x = np.stack([thresh, thresh, thresh]).transpose(1,2,0)

    gray = cv2.cvtColor(np.array(danImage.pil_img_y), cv2.COLOR_RGB2GRAY)
    blur = gray #cv2.GaussianBlur(gray, (5,5), 0)
    thresh = cv2.adaptiveThreshold(blur,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV,13,10)
    danImage.pil_img_y = np.stack([thresh, thresh, thresh]).transpose(1,2,0)

    return danImage

def rotate_image(im, angle):
    if angle % 90 == 0:
        # rotate by multiple of 90 deg is easier
        return im.rotate(angle)
    else:
        # rotate an enlarged mask to minimize jaggies
        bigger_mask = im.resize((im.size[0]*8, im.size[1]*8),
                                resample=Image.BICUBIC)
        return bigger_mask.rotate(angle).resize(
            im.size, resample=Image.LANCZOS)


def draw_rotated_text(image, angle, text, fill, font):
    # get the size of our image
    width, height = image.size

    # build a transparency mask large enough to hold the text
    mask = Image.new('L', image.size, 0)

    # add centred text to mask
    draw = ImageDraw.Draw(mask)

    # note: multiline_textsize was removed in Pillow 10
    size = draw.multiline_textsize(text, font)
    draw.multiline_text(((width - size[0]) // 2, (height - size[1]) // 2), text, 255, font, align='center')

    rotated_mask = rotate_image(mask, angle)

    # paste the appropriate color, with the text transparency mask
    color_image = Image.new('RGBA', image.size, fill)
    image.paste(color_image, mask=rotated_mask)


def draw_border(draw, pos, text, border, font):
    # Outline the text by redrawing it offset in all 8 directions
    x, y = pos

    if random.random_sample() < 0.8:
        # randint is inclusive here, so the +1 gives 1- or 2-pixel outlines
        for adj in range(1, randint(1, 2) + 1):
            #move left
            draw.multiline_text((x-adj, y), text, font=font, fill=border)
            #move right
            draw.multiline_text((x+adj, y), text, font=font, fill=border)
            #move down (PIL's y axis points down)
            draw.multiline_text((x, y+adj), text, font=font, fill=border)
            #move up
            draw.multiline_text((x, y-adj), text, font=font, fill=border)
            #diagonal left down
            draw.multiline_text((x-adj, y+adj), text, font=font, fill=border)
            #diagonal right down
            draw.multiline_text((x+adj, y+adj), text, font=font, fill=border)
            #diagonal left up
            draw.multiline_text((x-adj, y-adj), text, font=font, fill=border)
            #diagonal right up
            draw.multiline_text((x+adj, y-adj), text, font=font, fill=border)

def patchify(img):
    # Crop a random State-sized patch, only when the patch is narrower than the image
    w, h = State.getRandomSize()
    wid, hei = img.image.size

    if w < wid:
        if random.random_sample() < 0.6:
            x, y = randint(wid//4, wid//4*3), randint(hei//4, hei//4*3)
        else:
            x, y = random.randint(0, wid - w), random.randint(0, hei - h)

        x, y = min(x, wid - w), min(y, hei - h)
        img.image = img.image.crop((x, y, x + w, y + h))



def mangacrop(manga):
    # Crop a State-sized window centred on a random text box, shifted to stay in bounds
    w, h = State.getRandomSize()

    target = random.choice(manga.info['text'])

    center = ((target['xmin'] + target['xmax']) * 0.5, (target['ymin'] + target['ymax']) * 0.5)

    xmin, ymin = int(center[0] - w // 2), int(center[1] - h // 2)

    xmax = xmin + w
    ymax = ymin + h

    # fastai's Image.size is (height, width)
    if xmax > manga.image.size[1]:
        xmin -= (xmax - manga.image.size[1])
        xmax -= (xmax - manga.image.size[1])

    if xmin < 0:
        xmax -= xmin
        xmin -= xmin

    if ymax > manga.image.size[0]:
        ymin -= (ymax - manga.image.size[0])
        ymax -= (ymax - manga.image.size[0])

    if ymin < 0:
        ymax -= ymin
        ymin -= ymin

    #print(target, xmin, ymin, xmax, ymax, manga.image.size)

    manga.image = vision.Image(manga.image.px[:,ymin:ymax, xmin:xmax])
    manga.x_tensor = manga.image.px
    manga.y_tensor = manga.x_tensor.clone()


class State:
    # Shared patch size (w, h): multiples of 8 in [minSize, maxSize], cached until resetRandomSize()
    minSize, maxSize = 64, 64
    randSizes = None

    @staticmethod
    def getRandomSize():
        if State.randSizes is None:
            State.randSizes = [State.minSize, State.minSize] if State.minSize == State.maxSize else randint(State.minSize//8, State.maxSize//8, 2) * 8
        return State.randSizes

    @staticmethod
    def resetRandomSize():
        State.randSizes = None
--------------------------------------------------------------------------------
/images/AisazuNihaIrarenai-009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juvian/Manga-Text-Segmentation/de8f148c78978d70ad0e0ae3242566da6d70f3a5/images/AisazuNihaIrarenai-009.jpg
--------------------------------------------------------------------------------
/images/AisazuNihaIrarenaipre-prediction-009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juvian/Manga-Text-Segmentation/de8f148c78978d70ad0e0ae3242566da6d70f3a5/images/AisazuNihaIrarenaipre-prediction-009.png
--------------------------------------------------------------------------------
/images/AisazuNihaIrarenaipre-processed-truth-009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juvian/Manga-Text-Segmentation/de8f148c78978d70ad0e0ae3242566da6d70f3a5/images/AisazuNihaIrarenaipre-processed-truth-009.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fastai==1.0.60
torch==1.4.0
torchvision==0.5.0
--------------------------------------------------------------------------------
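The bounds handling in `mangacrop` in `transforms.py` can be isolated into a small pure-Python sketch. `clamp_window` is a hypothetical helper, not part of this repo: it centres a window of length `size` on `center` and shifts it back inside `[0, bound)`, applying the same two per-axis corrections that `mangacrop` applies (first pull back from the far edge, then from the near edge). It assumes `size <= bound`; for a larger window the result still overflows the far edge.

```python
def clamp_window(center, size, bound):
    """Return (start, end) of a length-`size` window centred on `center`,
    shifted to lie inside [0, bound).

    Hypothetical helper mirroring the per-axis logic of mangacrop;
    assumes size <= bound.
    """
    start = int(center - size // 2)
    end = start + size

    # Window runs past the far edge: shift it back by the overflow
    if end > bound:
        start -= end - bound
        end = bound

    # Window starts before the near edge: shift it forward to 0
    if start < 0:
        end -= start
        start = 0

    return start, end


if __name__ == "__main__":
    # 64-pixel window near the right edge of a 200-pixel-wide page
    print(clamp_window(190.0, 64, 200))
```

For the x axis, `mangacrop` effectively computes `clamp_window((target['xmin'] + target['xmax']) * 0.5, w, manga.image.size[1])`, and the same for y with `h` and `manga.image.size[0]`. Keeping the window size fixed (shifting instead of truncating) is what guarantees every crop has the same shape for batching.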
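The formula in `_ssim` from `ssim.py` can be checked without PyTorch using a simplified, single-window variant. `global_ssim` below is an illustrative sketch, not the repo's function: it uses plain (unweighted) means, variances and covariance over the whole signal instead of the 11x11 Gaussian sliding window, so it returns one score rather than an SSIM map. Its constants match `C1 = 0.01**2` and `C2 = 0.03**2` when `data_range=1.0`, i.e. pixel values in [0, 1] as produced by the `.div_(255)` calls in `transforms.py`.

```python
def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM between two equal-length sequences of pixel values.

    Simplification (assumption): one uniform window over the whole signal,
    not the Gaussian sliding window used by ssim.py.
    """
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(x)

    # Same stabilising constants as _ssim when data_range == 1
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    mu_x = sum(x) / n
    mu_y = sum(y) / n

    # Population (biased) moments, like E[ab] - E[a]E[b] in _ssim
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n

    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


if __name__ == "__main__":
    ramp = [0.0, 0.25, 0.5, 0.75, 1.0]
    print(global_ssim(ramp, ramp))
    print(global_ssim(ramp, [1.0 - v for v in ramp]))
```

Identical inputs score 1 (up to floating-point rounding). An inverted copy of a high-contrast signal scores below zero because the covariance term turns negative, while a uniform brightness shift only lowers the luminance term slightly.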