├── weight
│   └── .gitkeep
├── data
│   ├── line
│   │   └── 0
│   │       └── .gitkeep
│   └── thin
│       └── 0
│           └── .gitkeep
├── input
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── output
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── predict.py
├── LICENSE
├── README.md
├── utils.py
├── model2.py
├── datagen.py
└── model1.py

--------------------------------------------------------------------------------
/weight/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/line/0/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/thin/0/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/input/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/input/0.png

--------------------------------------------------------------------------------
/input/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/input/1.png

--------------------------------------------------------------------------------
/input/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/input/2.png

--------------------------------------------------------------------------------
/input/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/input/3.png

--------------------------------------------------------------------------------
/output/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/output/0.png

--------------------------------------------------------------------------------
/output/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/output/1.png

--------------------------------------------------------------------------------
/output/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/output/2.png

--------------------------------------------------------------------------------
/output/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hepesu/LineCloser/HEAD/output/3.png

--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import os

# Force running on CPU
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import numpy as np
import cv2
from keras.models import load_model

# The model downsamples 4 times by stride 2, so input dimensions
# must be multiples of R = 2 ** 4 = 16
R = 2 ** 4
MODEL_NAME = './model1.h5'

model = load_model(MODEL_NAME)
model.summary()

for root, dirs, files in os.walk('./input', topdown=False):
    for name in files:
        print(os.path.join(root, name))

        im = cv2.imread(os.path.join(root, name), cv2.IMREAD_GRAYSCALE)

        # Resize down to the nearest multiple of R and scale to [0, 1]
        im_predict = cv2.resize(im, (im.shape[1] // R * R, im.shape[0] // R * R))
        im_predict = np.reshape(im_predict, (1, im_predict.shape[0], im_predict.shape[1], 1))
        im_predict = im_predict.astype(np.float32) / 255.

        result = model.predict(im_predict)

        # Scale back to [0, 255] and restore the original size
        im_res = cv2.resize(result[0] * 255., (im.shape[1], im.shape[0]))

        cv2.imwrite(os.path.join('./output', name), im_res)
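Note that the resize to a multiple of 16 slightly distorts the image. A minimal alternative sketch that reflection-pads instead of resizing, then crops the prediction back; `predict_padded` is a hypothetical helper, not part of this repo, and assumes the `model` and `R` defined in predict.py:

import numpy as np
import cv2

def predict_padded(model, im, R=16):
    # Pad to the next multiple of R; avoids resize distortion at the cost
    # of possible reflection artifacts near the borders.
    h, w = im.shape[:2]
    ph, pw = (R - h % R) % R, (R - w % R) % R
    padded = cv2.copyMakeBorder(im, 0, ph, 0, pw, cv2.BORDER_REFLECT)
    x = padded.astype(np.float32)[None, :, :, None] / 255.
    y = model.predict(x)[0, :h, :w, 0]  # crop back to the original size
    return (y * 255.).astype(np.uint8)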
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 HEPESU

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LineCloser
Unofficial Keras implementation of Joint Gap Detection and Inpainting of Line Drawings.

## Overview
Joint gap detection and inpainting for line drawings. Model 1 uses the network from the paper; for stable training, batch normalization was added after every Conv2D. Model 2 uses a network commonly used for inpainting, built on dilated convolutions.

## Dependencies
* Keras 2 (TensorFlow backend)
* OpenCV 3
* CairoSVG

## Usage
1. Set up the directories.

2. Download the model from the release page and put it in the same folder as the code.

3. Run `predict.py` for prediction. Run `model{NUM}.py` for training.

## Data Preparation
There are three methods for data generation: `DATA_GEN`, `DATA_GAP` and `DATA_THIN`.

0. Use `DATA_GEN` for training; the data is generated online.

1. Collect line drawings with [LineDistiller](https://github.com/hepesu/LineDistiller).

2. Put the line drawings into `data/line`, then use `DATA_GAP` for training.

3. Thin (normalize) the line drawings with [LineNormalizer](https://github.com/hepesu/LineNormalizer) or a traditional thinning method.

4. Manually process the line drawings and thinning results (thresholding, etc.), then crop them into pieces.

5. Put the line drawings into `data/line` and the thinning results into `data/thin`, then use `DATA_THIN` for training.

## Models
Models are licensed under a CC-BY-NC-SA 4.0 international license.

* [LineCloser Release Page](https://github.com/hepesu/LineCloser/releases)

From **Project HAT** by Hepesu With :heart:
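To sanity-check the `DATA_GEN` pipeline before training, the repo's own `gen_data` and `generate_random_gap` can render a few input/target pairs to disk; a minimal sketch (the `check_*` file names are arbitrary):

import numpy as np
import cv2
from datagen import gen_data
from utils import generate_random_gap

gap_configs352 = [[50, 600, 2, 8, 0, 1], [50, 600, 2, 10, 0, 2], [1, 2, 5, 15, 0, 3]]
rng = np.random.RandomState(1)
y = gen_data(rng, 2)                           # clean drawings, (2, 352, 352, 1) in [0, 1]
x, m = generate_random_gap(y, gap_configs352)  # gapped inputs and gap masks
for i in range(len(x)):
    cv2.imwrite('check_x_%d.png' % i, x[i] * 255)
    cv2.imwrite('check_y_%d.png' % i, y[i] * 255)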
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2


def generate_random_gap(imgs, gap_configs, seed=None):
    '''
    Punch random disc-shaped gaps into a batch of line drawings.
    Returns the gapped images and the gap masks.
    '''
    # White background that fills the gaps
    bg = np.full(imgs[0].shape, 1., np.float32)

    imgs_with_gaps = []
    masks = []

    if seed is not None:
        np.random.seed(seed)

    for img in imgs:
        img_height, img_width = img.shape[:2]
        mask = np.zeros_like(img, np.float32)

        for gap_config in gap_configs:
            nb_min, nb_max, r_min, r_max, b_min, b_max = gap_config
            _mask = np.zeros_like(img, np.float32)

            # Draw a random number of filled discs at random positions
            for _ in range(np.random.randint(nb_min, nb_max)):
                center = (np.random.randint(img_width), np.random.randint(img_height))
                radius = np.random.randint(r_min, r_max)
                cv2.circle(_mask, center, radius, 1., -1)

            # Soften the disc edges with a random box blur
            blur_radius = np.random.randint(b_min, b_max) * 2 + 1
            _mask = cv2.blur(_mask, (blur_radius, blur_radius))

            _mask = np.expand_dims(_mask, axis=-1)

            # accumulate masks
            mask = mask + _mask

        mask = np.clip(mask, 0., 1.)

        # composite with mix: blend the white background in over the gaps
        imgs_with_gaps.append(img * (1. - mask) + bg * mask)
        # keep only the mask regions that actually erased dark strokes
        masks.append(mask * (1. - img))

    return np.array(imgs_with_gaps, np.float32), np.array(masks, np.float32)


if __name__ == "__main__":
    y = cv2.imread('./input/0.png', cv2.IMREAD_GRAYSCALE)
    y = np.expand_dims(y, -1) / 255.

    gap_configs352 = [
        [50, 600, 2, 8, 0, 1],
        [50, 600, 2, 10, 0, 2],
        [1, 2, 5, 15, 0, 3]
    ]

    x, m = generate_random_gap([y], gap_configs352, 1)

    cv2.imwrite('./gap_x_check.png', x[0] * 255)
    cv2.imwrite('./gap_y_check.png', y * 255)
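Each `gap_config` row packs six integers that `generate_random_gap` unpacks as `nb_min, nb_max, r_min, r_max, b_min, b_max`. An annotated example (the variable names are illustrative only):

# [nb_min, nb_max, r_min, r_max, b_min, b_max]
#  disc count range | disc radius range (px) | blur range
#  (box-blur kernel size = randint(b_min, b_max) * 2 + 1)
many_small_gaps = [50, 600, 2, 8, 0, 1]  # hundreds of tiny, sharp discs
one_large_gap = [1, 2, 5, 15, 0, 3]      # a single big, possibly blurred disc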
--------------------------------------------------------------------------------
/model2.py:
--------------------------------------------------------------------------------
import numpy as np
import datetime

from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, Activation
from keras.preprocessing import image

from datagen import gen_data
from utils import generate_random_gap

LOAD_WEIGHTS = False
ITERATIONS = 50000
BATCH_SIZE = 4
SEED = 1
IMG_SHAPE = (352, 352, 1)
IMG_HEIGHT, IMG_WIDTH, IMG_CHAN = IMG_SHAPE
DATA_TYPE = 'DATA_GEN'


def build_model():
    input_tensor = Input((None, None, 1))

    x = Conv2D(64, 11, padding='same')(input_tensor)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(128, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Dilated convolutions grow the receptive field without downsampling
    x = Conv2D(256, 3, dilation_rate=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, dilation_rate=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, dilation_rate=4, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, dilation_rate=8, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(128, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(64, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    output_tensor = Conv2D(1, 3, padding='same', activation='sigmoid')(x)

    return Model(input_tensor, output_tensor)


def data_generator(data_type='DATA_GEN'):
    '''
    Generate training batches of the given type.
    DATA_GEN: generate data from random graphics, with small discs as gaps (non-meaningful data)
    DATA_GAP: punch small disc gaps into user line drawings (meaningful data)
    DATA_THIN: directly read offline data generated by normalization (thinning)

    :param data_type: DATA_GEN, DATA_GAP or DATA_THIN
    :return: yields (x_batch, y_batch)
    '''
    # Using both 352 and 176 crops could achieve better performance
    gap_configs352 = [
        [50, 600, 2, 8, 0, 1],
        [50, 600, 2, 10, 0, 2],
        [1, 2, 5, 15, 0, 3]
    ]

    # gap_configs176 = [
    #     [50, 200, 1, 4, 0, 1],
    #     [50, 200, 1, 5, 0, 2],
    #     [1, 2, 5, 10, 0, 3]
    # ]

    # gap_configs128 = [
    #     [50, 200, 2, 4, 0, 1],
    #     [50, 200, 2, 5, 0, 2],
    #     [1, 2, 5, 15, 0, 3]
    # ]

    # gap_configs64 = [
    #     [50, 200, 1, 4, 0, 1],
    #     [50, 200, 1, 5, 0, 2],
    #     [1, 2, 5, 10, 0, 3]
    # ]

    datagen = image.ImageDataGenerator(
        rescale=1 / 255.,
        rotation_range=180,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='reflect'
    )

    if data_type == 'DATA_GAP':
        raw_generator_352 = datagen.flow_from_directory(
            './data/line',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            color_mode='grayscale',
            seed=SEED,
            class_mode=None,
            batch_size=BATCH_SIZE,
            shuffle=True,
            interpolation='bilinear'
        )

        # raw_generator_176 = datagen.flow_from_directory(
        #     './data/line',
        #     target_size=(IMG_HEIGHT // 2, IMG_WIDTH // 2),
        #     color_mode='grayscale',
        #     seed=SEED,
        #     class_mode=None,
        #     batch_size=BATCH_SIZE // 2,
        #     shuffle=True,
        #     interpolation='bilinear'
        # )

        # Seed once, outside the loop; reseeding per batch would reproduce
        # the same gap pattern in every batch
        np.random.seed(SEED)

        while True:
            train_y_batch = next(raw_generator_352)
            train_x_batch, _ = generate_random_gap(train_y_batch, gap_configs352)

            yield train_x_batch, train_y_batch

    elif data_type == 'DATA_GEN':
        # Create the RNG once; constructing RandomState(SEED) per batch
        # would generate identical data every iteration
        rng = np.random.RandomState(SEED)
        np.random.seed(SEED)

        while True:
            # Size config is in datagen.py
            train_y_batch = gen_data(rng, BATCH_SIZE)
            train_x_batch, _ = generate_random_gap(train_y_batch, gap_configs352)

            yield train_x_batch, train_y_batch

    elif data_type == 'DATA_THIN':
        raw_generator_x = datagen.flow_from_directory(
            './data/thin',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            color_mode='grayscale',
            seed=SEED,
            class_mode=None,
            batch_size=BATCH_SIZE,
            shuffle=True,
            interpolation='bilinear'
        )

        raw_generator_y = datagen.flow_from_directory(
            './data/line',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            color_mode='grayscale',
            seed=SEED,
            class_mode=None,
            batch_size=BATCH_SIZE,
            shuffle=True,
            interpolation='bilinear'
        )

        # The shared seed keeps the x and y augmentations in sync
        while True:
            yield next(raw_generator_x), next(raw_generator_y)


def train():
    model = build_model()
    model.summary()

    if LOAD_WEIGHTS:
        model.load_weights('./weight/model2.h5')

    model.compile(loss='MSE', optimizer='Adam')

    data = data_generator(DATA_TYPE)
    start_time = datetime.datetime.now()

    for iteration in range(1, ITERATIONS + 1):
        # The generator yields the gapped input first, the clean target second
        train_x_batch, train_y_batch = next(data)
        loss = model.train_on_batch(train_x_batch, train_y_batch)

        print('[Time: %s] [Iteration: %d] [Loss: %f]' % (datetime.datetime.now() - start_time, iteration, loss))

        if iteration % 200 == 0:
            model.save('./weight/model2_%d.h5' % iteration)

    model.save('./weight/model2.h5')


if __name__ == "__main__":
    train()
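model2 keeps full resolution throughout and relies on the dilated stack (rates 1, 2, 4, 8) for context. A back-of-envelope check of its receptive field, sketched — each k×k conv with dilation d adds (k - 1) * d:

# (kernel, dilation) for each conv in model2.build_model, in order
layers = [(11, 1), (3, 1), (3, 1), (3, 2), (3, 4), (3, 8), (3, 1), (3, 1), (3, 1)]
rf = 1
for k, d in layers:
    rf += (k - 1) * d
print(rf)  # 49: each output pixel sees a 49x49 input neighborhood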
--------------------------------------------------------------------------------
/datagen.py:
--------------------------------------------------------------------------------
'''
Modified from https://github.com/byungsook/vectornet/blob/master/data_line.py

Semantic Segmentation for Line Drawing Vectorization Using Neural Networks
TensorFlow implementation of Semantic Segmentation for Line Drawing Vectorization Using Neural Networks.

Byungsoo Kim (1), Oliver Wang (2), Cengiz Öztireli (1), Markus Gross (1)
(1) ETH Zurich, (2) Adobe Research

Computer Graphics Forum (Proceedings of Eurographics 2018)
'''

import numpy as np
import cv2
from PIL import Image
import cairosvg
import io

SEED = 1
WIDTH, HEIGHT = 352, 352

MIN_STROKE_WIDTH, MAX_STROKE_WIDTH = 0.2, 2
MAX_STROKE_COLOR = 50

NORM_STROKE_WIDTH = 0.5

MAX_NUM_STROKES = 100

# SVG templates. The markup below is a reconstruction: the placeholder names
# come from the draw_* format() calls further down; the white background rect
# and the rotation pivot are assumptions.
SVG_START_TEMPLATE = """<?xml version="1.0" encoding="utf-8"?>
<svg width="{w}" height="{h}" xmlns="http://www.w3.org/2000/svg" version="1.1">
<rect width="100%" height="100%" fill="white" />
<g transform="rotate({rot}, {w}, {h})">\n"""

SVG_RECT_TEMPLATE = """<rect id="{id}" x="{x}" y="{y}" width="{w}" height="{h}" stroke="rgb({r},{g},{b})" stroke-width="{sw}" fill="none" />"""

SVG_ELLIPSE_TEMPLATE = """<ellipse id="{id}" cx="{x}" cy="{y}" rx="{rx}" ry="{ry}" stroke="rgb({r},{g},{b})" stroke-width="{sw}" fill="none" />"""

SVG_LINE_TEMPLATE = """<line id="{id}" x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="rgb({r},{g},{b})" stroke-width="{sw}" />"""

SVG_CUBIC_BEZIER_TEMPLATE = """<path id="{id}" d="M {sx} {sy} C {cx1} {cy1} {cx2} {cy2} {tx} {ty}" stroke="rgb({r},{g},{b})" stroke-width="{sw}" fill="none" />"""

SVG_END_TEMPLATE = """</g>\n</svg>\n"""


def draw_line(id, w, h, rng):
    stroke_color = rng.randint(MAX_STROKE_COLOR)
    stroke_width = rng.rand() * (MAX_STROKE_WIDTH - MIN_STROKE_WIDTH) + MIN_STROKE_WIDTH
    x = rng.randint(w, size=2)
    y = rng.randint(h, size=2)

    return SVG_LINE_TEMPLATE.format(
        id=id,
        x1=x[0], y1=y[0],
        x2=x[1], y2=y[1],
        r=stroke_color, g=stroke_color, b=stroke_color,
        sw=stroke_width
    )


def draw_cubic_bezier_curve(id, w, h, rng):
    stroke_color = rng.randint(MAX_STROKE_COLOR)
    stroke_width = rng.rand() * (MAX_STROKE_WIDTH - MIN_STROKE_WIDTH) + MIN_STROKE_WIDTH
    x = rng.randint(w, size=4)
    y = rng.randint(h, size=4)

    return SVG_CUBIC_BEZIER_TEMPLATE.format(
        id=id,
        sx=x[0], sy=y[0],
        cx1=x[1], cy1=y[1],
        cx2=x[2], cy2=y[2],
        tx=x[3], ty=y[3],
        r=stroke_color, g=stroke_color, b=stroke_color,
        sw=stroke_width
    )


def draw_rect(id, w, h, rng):
    stroke_color = rng.randint(MAX_STROKE_COLOR)
    stroke_width = rng.rand() * (MAX_STROKE_WIDTH - MIN_STROKE_WIDTH) + MIN_STROKE_WIDTH
    x = rng.randint(w)
    y = rng.randint(h)
    w = rng.randint(low=w // 4, high=w // 2)
    h = rng.randint(low=h // 4, high=h // 2)

    return SVG_RECT_TEMPLATE.format(
        id=id,
        x=x, y=y,
        w=w, h=h,
        r=stroke_color, g=stroke_color, b=stroke_color,
        sw=stroke_width
    )


def draw_ellipse(id, w, h, rng):
    stroke_color = rng.randint(MAX_STROKE_COLOR)
    stroke_width = rng.rand() * (MAX_STROKE_WIDTH - MIN_STROKE_WIDTH) + MIN_STROKE_WIDTH
    x = rng.randint(w)
    y = rng.randint(h)
    rx = rng.randint(low=w // 4, high=w // 2)
    ry = rng.randint(low=h // 4, high=h // 2)

    return SVG_ELLIPSE_TEMPLATE.format(
        id=id,
        x=x, y=y,
        rx=rx, ry=ry,
        r=stroke_color, g=stroke_color, b=stroke_color,
        sw=stroke_width
    )


def draw_path(id, w, h, rng):
    path_selector = {
        0: draw_line,
        1: draw_cubic_bezier_curve,
        2: draw_rect,
        3: draw_ellipse
    }

    stroke_type = rng.randint(len(path_selector))

    return path_selector[stroke_type](id, w, h, rng)


def gen_data(rng, batch_size):
    x = []
    y = []

    norm_stroke_width_txt = """stroke-width="{sw}" _stroke-width""".format(sw=NORM_STROKE_WIDTH)
    for file_id in range(batch_size):
        while True:
            svg = SVG_START_TEMPLATE.format(
                w=WIDTH,
                h=HEIGHT,
                rot=rng.randint(0, 180)
            )

            for i in range(rng.randint(MAX_NUM_STROKES) + 1):
                path = draw_path(
                    id=i,
                    w=WIDTH,
                    h=HEIGHT,
                    rng=rng
                )
                svg += path + '\n'

            svg += SVG_END_TEMPLATE

            x_png = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
            x_img = Image.open(io.BytesIO(x_png))
            x_arr = np.array(x_img, np.float32)

            # Renaming the original attribute to _stroke-width makes the
            # renderer use the normalized width for every stroke
            y_png = cairosvg.svg2png(bytestring=svg.replace('stroke-width', norm_stroke_width_txt).encode('utf-8'))
            y_img = Image.open(io.BytesIO(y_png))
            y_arr = np.array(y_img, np.float32)

            # Reject images that are nearly all strokes or nearly empty
            if np.mean(x_arr) < 200 or np.mean(x_arr) > 253:
                continue

            x.append(np.reshape(x_arr[:, :, 0], (HEIGHT, WIDTH, 1)))
            y.append(np.reshape(y_arr[:, :, 0], (HEIGHT, WIDTH, 1)))
            break

    # y (the normalized-width rendering) is currently unused
    return np.array(x).astype(np.float32) / 255.0


def test():
    rnd = np.random.RandomState(SEED)

    for i in range(5):
        x_data = gen_data(rnd, 4)
        for j in range(4):
            cv2.imwrite('data/x_%d_%d.png' % (i, j), x_data[j] * 255)


if __name__ == "__main__":
    test()
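The ground-truth rendering relies on a string-replace trick: every `stroke-width="X"` attribute gains a normalized width in front, while the original attribute is renamed to an inert `_stroke-width`. A worked example:

norm = 'stroke-width="{sw}" _stroke-width'.format(sw=0.5)
print('<line stroke-width="1.3" />'.replace('stroke-width', norm))
# -> <line stroke-width="0.5" _stroke-width="1.3" />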
--------------------------------------------------------------------------------
/model1.py:
--------------------------------------------------------------------------------
import numpy as np
import datetime

from keras.models import Model
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, Activation
from keras.preprocessing import image

from datagen import gen_data
from utils import generate_random_gap

LOAD_WEIGHTS = False
ITERATIONS = 50000
BATCH_SIZE = 4
SEED = 1
IMG_SHAPE = (352, 352, 1)
IMG_HEIGHT, IMG_WIDTH, IMG_CHAN = IMG_SHAPE
DATA_TYPE = 'DATA_GEN'


def build_model():
    input_tensor = Input((None, None, 1))

    # Encoder: four stride-2 convolutions downsample by a factor of 16
    x = Conv2D(24, 5, strides=2, padding='same')(input_tensor)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(64, 3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(128, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(512, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(512, 3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Decoder: four UpSampling2D stages restore the input resolution
    x = UpSampling2D(2)(x)
    x = Conv2D(128, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(64, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = UpSampling2D(2)(x)
    x = Conv2D(32, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(16, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = UpSampling2D(2)(x)
    x = Conv2D(8, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(4, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = UpSampling2D(2)(x)
    x = Conv2D(2, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(1, 3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    output_tensor = Conv2D(1, 3, padding='same', activation='sigmoid')(x)

    return Model(input_tensor, output_tensor)
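
# Note: the encoder above shrinks feature maps by 2 ** 4 = 16 before the four
# UpSampling2D stages restore them, which is exactly the R = 2 ** 4 crop
# constraint in predict.py. Shapes only round-trip when the input height and
# width are multiples of 16, e.g. (sketch):
#
#   out = build_model().predict(np.zeros((1, 352, 352, 1), np.float32))
#   assert out.shape == (1, 352, 352, 1)
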
def data_generator(data_type='DATA_GEN'):
    '''
    Generate training batches of the given type.
    DATA_GEN: generate data from random graphics, with small discs as gaps (non-meaningful data)
    DATA_GAP: punch small disc gaps into user line drawings (meaningful data)
    DATA_THIN: directly read offline data generated by normalization (thinning)

    :param data_type: DATA_GEN, DATA_GAP or DATA_THIN
    :return: yields (x_batch, y_batch)
    '''
    # Using both 352 and 176 crops could achieve better performance
    gap_configs352 = [
        [50, 600, 2, 8, 0, 1],
        [50, 600, 2, 10, 0, 2],
        [1, 2, 5, 15, 0, 3]
    ]

    # gap_configs176 = [
    #     [50, 200, 1, 4, 0, 1],
    #     [50, 200, 1, 5, 0, 2],
    #     [1, 2, 5, 10, 0, 3]
    # ]

    # gap_configs128 = [
    #     [50, 200, 2, 4, 0, 1],
    #     [50, 200, 2, 5, 0, 2],
    #     [1, 2, 5, 15, 0, 3]
    # ]

    # gap_configs64 = [
    #     [50, 200, 1, 4, 0, 1],
    #     [50, 200, 1, 5, 0, 2],
    #     [1, 2, 5, 10, 0, 3]
    # ]

    datagen = image.ImageDataGenerator(
        rescale=1 / 255.,
        rotation_range=180,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='reflect'
    )

    if data_type == 'DATA_GAP':
        raw_generator_352 = datagen.flow_from_directory(
            './data/line',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            color_mode='grayscale',
            seed=SEED,
            class_mode=None,
            batch_size=BATCH_SIZE,
            shuffle=True,
            interpolation='bilinear'
        )

        # raw_generator_176 = datagen.flow_from_directory(
        #     './data/line',
        #     target_size=(IMG_HEIGHT // 2, IMG_WIDTH // 2),
        #     color_mode='grayscale',
        #     seed=SEED,
        #     class_mode=None,
        #     batch_size=BATCH_SIZE // 2,
        #     shuffle=True,
        #     interpolation='bilinear'
        # )

        # Seed once, outside the loop; reseeding per batch would reproduce
        # the same gap pattern in every batch
        np.random.seed(SEED)

        while True:
            train_y_batch = next(raw_generator_352)
            train_x_batch, _ = generate_random_gap(train_y_batch, gap_configs352)

            yield train_x_batch, train_y_batch

    elif data_type == 'DATA_GEN':
        # Create the RNG once; constructing RandomState(SEED) per batch
        # would generate identical data every iteration
        rng = np.random.RandomState(SEED)
        np.random.seed(SEED)

        while True:
            # Size config is in datagen.py
            train_y_batch = gen_data(rng, BATCH_SIZE)
            train_x_batch, _ = generate_random_gap(train_y_batch, gap_configs352)

            yield train_x_batch, train_y_batch

    elif data_type == 'DATA_THIN':
        raw_generator_x = datagen.flow_from_directory(
            './data/thin',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            color_mode='grayscale',
            seed=SEED,
            class_mode=None,
            batch_size=BATCH_SIZE,
            shuffle=True,
            interpolation='bilinear'
        )

        raw_generator_y = datagen.flow_from_directory(
            './data/line',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            color_mode='grayscale',
            seed=SEED,
            class_mode=None,
            batch_size=BATCH_SIZE,
            shuffle=True,
            interpolation='bilinear'
        )

        # The shared seed keeps the x and y augmentations in sync
        while True:
            yield next(raw_generator_x), next(raw_generator_y)
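
# Sketch of an optional check (check_thin_pairs is a hypothetical helper, not
# called anywhere in this repo): DATA_THIN depends on both flow_from_directory
# generators sharing SEED so their augmentations stay in sync, which requires
# the file names under data/thin/0 and data/line/0 to correspond. Writing the
# pairs side by side makes misalignment easy to spot.
def check_thin_pairs(out_prefix='pair_check'):
    import cv2
    x_batch, y_batch = next(data_generator('DATA_THIN'))
    for i in range(len(x_batch)):
        cv2.imwrite('%s_%d.png' % (out_prefix, i),
                    cv2.hconcat([x_batch[i] * 255, y_batch[i] * 255]))
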
def train():
    model = build_model()
    model.summary()

    if LOAD_WEIGHTS:
        model.load_weights('./weight/model1.h5')

    model.compile(loss='MSE', optimizer='Adam')

    data = data_generator(DATA_TYPE)
    start_time = datetime.datetime.now()

    for iteration in range(1, ITERATIONS + 1):
        # The generator yields the gapped input first, the clean target second
        train_x_batch, train_y_batch = next(data)
        loss = model.train_on_batch(train_x_batch, train_y_batch)

        print('[Time: %s] [Iteration: %d] [Loss: %f]' % (datetime.datetime.now() - start_time, iteration, loss))

        if iteration % 200 == 0:
            model.save('./weight/model1_%d.h5' % iteration)

    model.save('./weight/model1.h5')


if __name__ == "__main__":
    train()
--------------------------------------------------------------------------------