├── README.md
├── examples
│   ├── flower.jpg
│   ├── flower_result.png
│   └── simple.py
├── pixel_to_svg
│   ├── __init__.py
│   └── pixel_to_svg.py
├── requirements.txt
└── setup.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This is a simple method, based on unsupervised segmentation, for turning a raster image into SVG.

![](https://raw.githubusercontent.com/mehdidc/pixel_to_svg/master/examples/flower_result.png)

# How to install?

You first need to install pypotrace.

Here are the steps to install pypotrace:

1. `sudo apt-get install build-essential python3-dev libagg-dev libpotrace-dev pkg-config`
2. `git clone https://github.com/mehdidc/pypotrace`
3. `cd pypotrace`
4. `git checkout to_xml`
5. `rm -f potrace/*.c potrace/*.cpp potrace/agg/*.cpp potrace/*.so potrace/agg/*.so`
6. `pip install .`


Once pypotrace is available, you can install this repo.
Here are the steps:

1. `git clone https://github.com/mehdidc/pixel_to_svg`
2. `cd pixel_to_svg`
3. `python setup.py install`


# How to use?
28 | 29 | Please check the example in 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /examples/flower.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehdidc/pixel_to_svg/4110f22df427d77392982d4d4f3569b2d4731030/examples/flower.jpg -------------------------------------------------------------------------------- /examples/flower_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehdidc/pixel_to_svg/4110f22df427d77392982d4d4f3569b2d4731030/examples/flower_result.png -------------------------------------------------------------------------------- /examples/simple.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.io import imread 3 | import matplotlib.pyplot as plt 4 | from pixel_to_svg import graph_seg 5 | from pixel_to_svg import render_svg 6 | from pixel_to_svg import save_svg 7 | from pixel_to_svg import to_svg 8 | img = imread("flower.jpg") 9 | if img.shape[2] == 4: 10 | img = img[:,:,0:3] 11 | 12 | # segmentation step. 13 | # here each pixel is mapped to a category. 14 | # this internally uses quickshift and hierarchical_merge 15 | # from scikit-image (see the code for more info). 16 | # In principle, any segmentation method can be used 17 | # Feel free to replace this with you preferred method. 18 | seg = graph_seg( 19 | img, 20 | thresh=80, 21 | ) 22 | # Given a segmented image, turn it to SVG. 23 | # This internally uses `potrace`. 
24 | svg = to_svg(img, seg) 25 | 26 | # Convert SVG back to raster to display it 27 | # and compare it to original image 28 | img2 = render_svg(svg) 29 | 30 | fig = plt.subplots(nrows=1, ncols=2) 31 | plt.subplot(1,2,1) 32 | plt.imshow(img) 33 | plt.subplot(1,2,2) 34 | plt.imshow(img2) 35 | plt.show() 36 | 37 | # save the SVG 38 | save_svg(svg, "out.svg") 39 | -------------------------------------------------------------------------------- /pixel_to_svg/__init__.py: -------------------------------------------------------------------------------- 1 | from .pixel_to_svg import to_svg, render_svg, graph_seg, save_svg 2 | -------------------------------------------------------------------------------- /pixel_to_svg/pixel_to_svg.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import os 3 | from collections import namedtuple 4 | from subprocess import call 5 | from io import StringIO, BytesIO 6 | 7 | import numpy as np 8 | from imageio import imread, imsave 9 | from skimage import segmentation 10 | from skimage.future import graph 11 | 12 | from svgpathtools import svg2paths2 13 | from svgpathtools import disvg, wsvg 14 | from cairosvg import svg2png 15 | 16 | SVG = namedtuple("SVG", "paths attributes") 17 | 18 | def to_svg(img, seg, nb_layers=None, palette=None, opacity=None): 19 | if len(seg.shape) == 2: 20 | if nb_layers is None: 21 | nb_layers = seg.max() + 1 22 | masks = np.zeros((seg.shape[0], seg.shape[1], nb_layers)).astype(bool) 23 | m = masks.reshape((-1,nb_layers)) 24 | s = seg.reshape((-1,)) 25 | m[np.arange(len(m)), s] = 1 26 | assert np.all(masks.argmax(axis=2) == seg) 27 | else: 28 | masks = seg 29 | P = [] 30 | A = [] 31 | for layer in range(masks.shape[2]): 32 | mask = masks[:,:,layer] 33 | if np.all(mask==0): 34 | continue 35 | paths, attrs, svg_attrs = binary_image_to_svg2(mask) 36 | for attr in attrs: 37 | if palette is None: 38 | r, g, b, *rest = img[mask].mean(axis=0) 39 | else: 40 | r, g, b = 
palette[layer] 41 | r = int(r) 42 | g = int(g) 43 | b = int(b) 44 | col = f"rgb({r},{g},{b})" 45 | attr["stroke"] = col 46 | attr["fill"] = col 47 | if opacity: 48 | attr["opacity"] = opacity[layer] 49 | P.extend(paths) 50 | A.extend(attrs) 51 | return SVG(paths=P, attributes=A) 52 | 53 | def render_svg(svg, width=None, height=None): 54 | drawing = wsvg( 55 | paths=svg.paths, 56 | attributes=svg.attributes, 57 | paths2Drawing=True, 58 | ) 59 | fd = StringIO() 60 | drawing.write(fd) 61 | fo = BytesIO() 62 | svg2png(bytestring=fd.getvalue(), write_to=fo, output_width=width, output_height=height) 63 | fo.seek(0) 64 | return imread(fo, format="png") 65 | 66 | 67 | def wsvg(paths=None, colors=None, 68 | filename=os.path.join(os.getcwd(), 'disvg_output.svg'), 69 | stroke_widths=None, nodes=None, node_colors=None, node_radii=None, 70 | openinbrowser=False, timestamp=False, 71 | margin_size=0.1, mindim=600, dimensions=None, 72 | viewbox=None, text=None, text_path=None, font_size=None, 73 | attributes=None, svg_attributes=None, svgwrite_debug=False, paths2Drawing=False): 74 | #NB: this code is originally from . 75 | # Thanks tho @mathandy 76 | """Convenience function; identical to disvg() except that 77 | openinbrowser=False by default. 
See disvg() docstring for more info.""" 78 | return disvg(paths, colors=colors, filename=filename, 79 | stroke_widths=stroke_widths, nodes=nodes, 80 | node_colors=node_colors, node_radii=node_radii, 81 | openinbrowser=openinbrowser, timestamp=timestamp, 82 | margin_size=margin_size, mindim=mindim, dimensions=dimensions, 83 | viewbox=viewbox, text=text, text_path=text_path, font_size=font_size, 84 | attributes=attributes, svg_attributes=svg_attributes, 85 | svgwrite_debug=svgwrite_debug, paths2Drawing=paths2Drawing) 86 | 87 | def save_svg(svg, out="output.svg"): 88 | wsvg( 89 | paths=svg.paths, 90 | attributes=svg.attributes, 91 | filename=out, 92 | ) 93 | 94 | def binary_image_to_svg(seg): 95 | seg = (1-seg) 96 | seg = (seg*255).astype("uint8") 97 | seg = seg[::-1] 98 | name = str(uuid.uuid4()) 99 | bmp = name + ".bmp" 100 | svg = name + ".svg" 101 | imsave(bmp, seg) 102 | call(f"potrace -s {bmp}", shell=True) 103 | paths = svg2paths2(svg) 104 | os.remove(bmp) 105 | os.remove(svg) 106 | return paths 107 | 108 | def binary_image_to_svg2(mask): 109 | """ 110 | same as binary_image_to_svg, but use `pypotrace` 111 | instead of calling `potrace` from shell 112 | it is more convenient and faster, this way 113 | """ 114 | import potrace 115 | bmp = potrace.Bitmap(mask) 116 | bmp.trace() 117 | xml = bmp.to_xml() 118 | fo = StringIO() 119 | fo.write(xml) 120 | fo.seek(0) 121 | paths = svg2paths2(fo) 122 | return paths 123 | 124 | 125 | 126 | def graph_seg(img, max_dist=200, thresh=80, sigma=255.0): 127 | """ 128 | segment an image using quickshift and merge_hierarchical 129 | from scikit-image. In principle, any segmentation method 130 | can be used, this is just one example. 
131 | """ 132 | img = img.astype("float") 133 | seg = segmentation.quickshift( 134 | img, 135 | max_dist=max_dist, 136 | ) 137 | g = graph.rag_mean_color( 138 | img, 139 | seg, 140 | sigma=sigma, 141 | ) 142 | seg = graph.merge_hierarchical( 143 | seg, 144 | g, 145 | thresh=thresh, 146 | rag_copy=False, 147 | in_place_merge=True, 148 | merge_func=_merge_mean_color, 149 | weight_func=_weight_mean_color 150 | ) 151 | return seg 152 | 153 | 154 | def _weight_mean_color(graph, src, dst, n): 155 | """ 156 | NB: this code is originally from . 157 | Thanks to scikit-image authors. 158 | 159 | Callback to handle merging nodes by recomputing mean color. 160 | 161 | The method expects that the mean color of `dst` is already computed. 162 | 163 | Parameters 164 | ---------- 165 | graph : RAG 166 | The graph under consideration. 167 | src, dst : int 168 | The vertices in `graph` to be merged. 169 | n : int 170 | A neighbor of `src` or `dst` or both. 171 | 172 | Returns 173 | ------- 174 | data : dict 175 | A dictionary with the `"weight"` attribute set as the absolute 176 | difference of the mean color between node `dst` and `n`. 177 | """ 178 | 179 | diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color'] 180 | diff = np.linalg.norm(diff) 181 | return {'weight': diff} 182 | 183 | 184 | def _merge_mean_color(graph, src, dst): 185 | """ 186 | NB: this code is originally from . 187 | Thanks to scikit-image authors. 188 | 189 | Callback called before merging two nodes of a mean color distance graph. 190 | 191 | This method computes the mean color of `dst`. 192 | 193 | Parameters 194 | ---------- 195 | graph : RAG 196 | The graph under consideration. 197 | src, dst : int 198 | The vertices in `graph` to be merged. 
199 | """ 200 | graph.nodes[dst]['total color'] += graph.nodes[src]['total color'] 201 | graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count'] 202 | graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] / 203 | graph.nodes[dst]['pixel count']) 204 | 205 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | imageio 3 | scikit-image 4 | cairosvg 5 | git+https://github.com/mehdidc/svgpathtools 6 | git+https://github.com/mehdidc/pypotrace 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'pixel_to_svg' 16 | DESCRIPTION = 'My short description for my project.' 17 | URL = 'https://github.com/mehdidc/pixel_to_svg' 18 | EMAIL = 'mehdicherti@gmail.com' 19 | AUTHOR = 'Mehdi Cherti' 20 | REQUIRES_PYTHON = '>=3.6.0' 21 | VERSION = '0.1.0' 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = [ 25 | "numpy", 26 | "imageio", 27 | "scikit-image", 28 | "cairosvg", 29 | "svgpathtools", 30 | "pypotrace", 31 | ] 32 | # What packages are optional? 33 | EXTRAS = { 34 | # 'fancy feature': ['django'], 35 | } 36 | 37 | # The rest you shouldn't have to touch too much :) 38 | # ------------------------------------------------ 39 | # Except, perhaps the License and Trove Classifiers! 40 | # If you do change the License, remember to change the Trove Classifier for that! 
41 | 42 | here = os.path.abspath(os.path.dirname(__file__)) 43 | 44 | # Import the README and use it as the long-description. 45 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 46 | try: 47 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 48 | long_description = '\n' + f.read() 49 | except FileNotFoundError: 50 | long_description = DESCRIPTION 51 | 52 | # Load the package's __version__.py module as a dictionary. 53 | about = {} 54 | if not VERSION: 55 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 56 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 57 | exec(f.read(), about) 58 | else: 59 | about['__version__'] = VERSION 60 | 61 | 62 | class UploadCommand(Command): 63 | """Support setup.py upload.""" 64 | 65 | description = 'Build and publish the package.' 66 | user_options = [] 67 | 68 | @staticmethod 69 | def status(s): 70 | """Prints things in bold.""" 71 | print('\033[1m{0}\033[0m'.format(s)) 72 | 73 | def initialize_options(self): 74 | pass 75 | 76 | def finalize_options(self): 77 | pass 78 | 79 | def run(self): 80 | try: 81 | self.status('Removing previous builds…') 82 | rmtree(os.path.join(here, 'dist')) 83 | except OSError: 84 | pass 85 | 86 | self.status('Building Source and Wheel (universal) distribution…') 87 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 88 | 89 | self.status('Uploading the package to PyPI via Twine…') 90 | os.system('twine upload dist/*') 91 | 92 | self.status('Pushing git tags…') 93 | os.system('git tag v{0}'.format(about['__version__'])) 94 | os.system('git push --tags') 95 | 96 | sys.exit() 97 | 98 | 99 | # Where the magic happens: 100 | setup( 101 | name=NAME, 102 | version=about['__version__'], 103 | description=DESCRIPTION, 104 | long_description=long_description, 105 | long_description_content_type='text/markdown', 106 | author=AUTHOR, 107 | author_email=EMAIL, 108 | python_requires=REQUIRES_PYTHON, 
109 | url=URL, 110 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 111 | # If your package is a single module, use this instead of 'packages': 112 | # py_modules=['mypackage'], 113 | 114 | # entry_points={ 115 | # 'console_scripts': ['mycli=mymodule:cli'], 116 | # }, 117 | install_requires=REQUIRED, 118 | extras_require=EXTRAS, 119 | include_package_data=True, 120 | license='MIT', 121 | classifiers=[ 122 | # Trove classifiers 123 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 124 | 'License :: OSI Approved :: MIT License', 125 | 'Programming Language :: Python', 126 | 'Programming Language :: Python :: 3', 127 | 'Programming Language :: Python :: 3.6', 128 | 'Programming Language :: Python :: Implementation :: CPython', 129 | 'Programming Language :: Python :: Implementation :: PyPy' 130 | ], 131 | # $ setup.py publish support. 132 | cmdclass={ 133 | 'upload': UploadCommand, 134 | }, 135 | ) 136 | --------------------------------------------------------------------------------