├── .gitignore ├── LICENSE ├── README.md ├── ict ├── KMeans.py ├── MMCQ.py ├── OQ.py └── __init__.py ├── imgs ├── avatar_282x282.png ├── mmcq_vs_kmeans.png ├── photo1.jpg ├── photo2.jpg ├── photo3.jpg └── photo4.jpg └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | .DS_Store 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Yusheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ImageColorTheme 2 | --- 3 | 4 | [Extract Color Themes from Images](http://blog.rainy.im/2015/11/25/extract-color-themes-from-images/) 5 | 6 | ### `pixData` 7 | 8 | ```py 9 | import numpy as np 10 | 11 | pixData = np.array([[R, G, B], [R, G, B],...], dtype=np.uint8) 12 | print(pixData.shape) 13 | # (h, w, d) 14 | ``` 15 | 16 | ### MMCQ 17 | 18 | ```py 19 | from ict.MMCQ import MMCQ 20 | mmcq = MMCQ(pixData, maxColor) 21 | theme= mmcq.quantize() 22 | ``` 23 | 24 | ### k-means 25 | 26 | ```py 27 | from ict.KMeans import KMeans 28 | km = KMeans(pixData, maxColor) 29 | theme = km.quantize() 30 | ``` 31 | 32 | ### Results 33 | 34 | ![](imgs/mmcq_vs_kmeans.png) 35 | -------------------------------------------------------------------------------- /ict/KMeans.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans as KM 2 | import numpy as np 3 | 4 | class KMDiy(object): 5 | """KMDiy = KMeans DIY""" 6 | MAX_ITER = 300 7 | def __init__(self, n_clusters=8, max_iter=300): 8 | super(KMDiy, self).__init__() 9 | self.n_clusters = n_clusters 10 | self.cluster_centers_ = None 11 | self.MAX_ITER = max_iter 12 | 13 | def randCent(self, data): 14 | dim = data.shape[1] 15 | centroids = np.zeros((self.n_clusters, dim)) 16 | for j in range(dim): 17 | minJ = min(data[:, j]) 18 | maxJ = max(data[:, j]) 19 | centroids[:, j] = minJ + float(maxJ - minJ) * np.random.rand(self.n_clusters) 20 | return centroids 21 | def fit(self, data): 22 | self.cluster_centers_ = self.randCent(data) 23 | size, dim = data.shape 24 | clusterAssment = np.zeros((size, 2)) 25 | clusterChanged = True 26 | iters = 0 27 | while clusterChanged: 28 | clusterChanged = False 29 | if iters > self.MAX_ITER: 30 | print("Reach MAX_ITER #{0}".format(self.MAX_ITER)) 31 | break 32 | 33 | for i in range(size): 34 | minDist = np.inf 35 | centIdx = -1 36 | for j in range(dim): 37 | d = self.distMeas(self.cluster_centers_[j, :], data[i, :]) 38 | if d < minDist: 39 | minDist = d 40 | centIdx = j 41 | if clusterAssment[i, 0] != centIdx: clusterChanged = True 42 | clusterAssment[i, :] = centIdx, minDist 43 | for j in range(dim): 44 | ptsInCluster = data[np.nonzero(clusterAssment[:,0] == j)] 45 | self.cluster_centers_[j, :] = np.mean(ptsInCluster)#new cluster centroid 46 | iters += 1 47 | 48 | def distMeas(self, v1, v2): 49 | return np.sqrt(sum([pow(x-y, 2) for x in v1 for y in v2])) 50 | 51 | class KMeans(object): 52 | """docstring for KMeans""" 53 | def __init__(self, pixData, maxColor, useSklearn=True): 54 | super(KMeans, self).__init__() 55 | h, w, d = pixData.shape 56 | self.pixData = np.reshape(pixData, (h * w, d)) 57 | self.maxColor = maxColor 58 | if useSklearn: 59 | self._KMeans = KM(n_clusters = maxColor) 60 | else: 61 | self._KMeans = KMDiy(n_clusters = maxColor) 62 | 63 | def quantize(self): 64 | self._KMeans.fit(self.pixData) 65 | return np.array(self._KMeans.cluster_centers_, dtype=np.uint8) 66 | -------------------------------------------------------------------------------- /ict/MMCQ.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from queue import PriorityQueue as PQueue 3 | from functools import reduce 4 | 5 | # PQueue lowest first! 6 | DEBUG = False 7 | 8 | class VBox(object): 9 | """ 10 | The color space is divided up into a set of 3D rectangular regions (called `vboxes`) 11 | """ 12 | def __init__(self, r1, r2, g1, g2, b1, b2, histo): 13 | super(VBox, self).__init__() 14 | self.r1 = r1 15 | self.r2 = r2 16 | self.g1 = g1 17 | self.g2 = g2 18 | self.b1 = b1 19 | self.b2 = b2 20 | self.histo = histo 21 | 22 | ziped = [(r1, r2), (g1, g2), (b1, b2)] 23 | sides = list(map(lambda t: abs(t[0] - t[1]) + 1, ziped)) 24 | self.vol = reduce(lambda x, y: x*y, sides) 25 | self.mAxis = sides.index(max(sides)) 26 | self.plane = ziped[:self.mAxis] + ziped[self.mAxis+1:] 27 | self.npixs = self.population() 28 | self.priority = self.npixs * -1 29 | def population(self): 30 | s = 0 31 | for r in range(self.r1, self.r2+1): 32 | for g in range(self.g1, self.g2+1): 33 | for b in range(self.b1, self.b2+1): 34 | s += self.histo[MMCQ.getColorIndex(r, g, b)] 35 | return int(s) 36 | def __lt__(self, vbox): #实现<操作 37 | return self.priority < vbox.priority 38 | def contains(self, r, g, b): 39 | # real r, g, b here 40 | pass 41 | 42 | class MMCQ(object): 43 | """ 44 | Modified Median Cut Quantization(MMCQ) 45 | Leptonica: http://tpgit.github.io/UnOfficialLeptDocs/leptonica/color-quantization.html 46 | """ 47 | MAX_ITERATIONS = 1000 48 | SIGBITS = 5 49 | def __init__(self, pixData, maxColor, fraction=0.85, sigbits=5): 50 | """ 51 | @pixData Image data [[R, G, B], ...] 52 | @maxColor Between [2, 256] 53 | @fraction Between [0.3, 0.9] 54 | @sigbits 5 or 6 55 | """ 56 | super(MMCQ, self).__init__() 57 | self.pixData = pixData 58 | if not 2 <= maxColor <= 256: 59 | raise AttributeError("maxColor should between [2, 256]!") 60 | self.maxColor = maxColor 61 | if not 0.3 <= fraction <= 0.9: 62 | raise AttributeError("fraction should between [0.3, 0.9]!") 63 | self.fraction = fraction 64 | if sigbits != 5 and sigbits != 6: 65 | raise AttributeError("sigbits should be either 5 or 6!") 66 | self.SIGBITS = sigbits 67 | self.rshift = 8 - sigbits 68 | 69 | self.h, self.w, _ = self.pixData.shape 70 | def getPixHisto(self): 71 | pixHisto = np.zeros(1 << (3 * self.SIGBITS)) 72 | for y in range(self.h): 73 | for x in range(self.w): 74 | r = self.pixData[y, x, 0] >> self.rshift 75 | g = self.pixData[y, x, 1] >> self.rshift 76 | b = self.pixData[y, x, 2] >> self.rshift 77 | 78 | pixHisto[self.getColorIndex(r, g, b)] += 1 79 | return pixHisto 80 | @classmethod 81 | def getColorIndex(self, r, g, b): 82 | return (r << (2 * self.SIGBITS)) + (g << self.SIGBITS) + b 83 | def createVbox(self, pixData): 84 | rmax = np.max(pixData[:,:,0]) >> self.rshift 85 | rmin = np.min(pixData[:,:,0]) >> self.rshift 86 | gmax = np.max(pixData[:,:,1]) >> self.rshift 87 | gmin = np.min(pixData[:,:,1]) >> self.rshift 88 | bmax = np.max(pixData[:,:,2]) >> self.rshift 89 | bmin = np.min(pixData[:,:,2]) >> self.rshift 90 | 91 | if DEBUG: 92 | print("Red range: {0}-{1}".format(rmin, rmax)) 93 | print("Green range: {0}-{1}".format(gmin, gmax)) 94 | print("Blue range: {0}-{1}".format(bmin, bmax)) 95 | return VBox(rmin, rmax, gmin, gmax, bmin, bmax,self.pixHisto) 96 | def medianCutApply(self, vbox): 97 | npixs = 0 98 | if vbox.mAxis == 0: 99 | # Red axis is largest 100 | plane = 0 101 | for r in range(vbox.r1, vbox.r2+1): 102 | for g in range(vbox.g1, vbox.g2+1): 103 | for b in range(vbox.b1, vbox.b2+1): 104 | h = vbox.histo[self.getColorIndex(r, g, b)] 105 | plane += h 106 | npixs += h 107 | if npixs >= vbox.npixs / 2.: 108 | left = r - vbox.r1 109 | right = vbox.r2 - r 110 | if left >= right: 111 | r2 = int(max(vbox.r1, r - 1 - left / 2)) 112 | else: 113 | r2 = int(min(vbox.r2 - 1, r + right / 2)) 114 | vbox1 = VBox(vbox.r1, r2, vbox.g1, vbox.g2, vbox.b1, vbox.b2, vbox.histo) 115 | vbox2 = VBox(r2+1, vbox.r2, vbox.g1, vbox.g2, vbox.b1, vbox.b2, vbox.histo) 116 | return vbox1, vbox2 117 | elif vbox.mAxis == 1: 118 | # Green axis is largest 119 | for g in range(vbox.g1, vbox.g2+1): 120 | plane = 0 121 | for r in range(vbox.r1, vbox.r2+1): 122 | for b in range(vbox.b1, vbox.b2+1): 123 | h = vbox.histo[self.getColorIndex(r, g, b)] 124 | plane += h 125 | npixs += h 126 | if npixs >= vbox.npixs / 2.: 127 | left = g - vbox.g1 128 | right = vbox.g2 - g 129 | if left >= right: 130 | g2 = int(max(vbox.g1, g - 1 - left / 2)) 131 | else: 132 | g2 = int(min(vbox.g2 - 1, g + right / 2)) 133 | vbox1 = VBox(vbox.r1, vbox.r2, vbox.g1, g2, vbox.b1, vbox.b2, vbox.histo) 134 | vbox2 = VBox(vbox.r1, vbox.r2, g2+1, vbox.g2, vbox.b1, vbox.b2, vbox.histo) 135 | return vbox1, vbox2 136 | else: 137 | # Blue axis is largest 138 | for b in range(vbox.b1, vbox.b2+1): 139 | plane = 0 140 | for r in range(vbox.r1, vbox.r2+1): 141 | for g in range(vbox.b1, vbox.b2+1): 142 | h = vbox.histo[self.getColorIndex(r, g, b)] 143 | plane += h 144 | npixs += h 145 | if npixs >= vbox.npixs / 2.: 146 | left = b - vbox.b1 147 | right = vbox.b2 - b 148 | if left >= right: 149 | b2 = int(max(vbox.b1, b - 1 - left / 2)) 150 | else: 151 | b2 = int(min(vbox.b2 - 1, b + right / 2)) 152 | vbox1 = VBox(vbox.r1, vbox.r2, vbox.g1, vbox.g2, vbox.b1, b2, vbox.histo) 153 | vbox2 = VBox(vbox.r1, vbox.r2, vbox.g1, vbox.g2, b2+1, vbox.b2, vbox.histo) 154 | return vbox1, vbox2 155 | def iterCut(self, maxColor, boxQueue, vol=False): 156 | ncolors = 1 157 | niters = 0 158 | while True: 159 | if ncolors >= maxColor: 160 | break 161 | vbox0 = boxQueue.get_nowait()[1] 162 | if vbox0.npixs == 0: 163 | print("Vbox has no pixels") 164 | boxQueue.put((vbox0.priority, vbox0)) 165 | continue 166 | vbox1, vbox2 = self.medianCutApply(vbox0) 167 | 168 | if vol: 169 | vbox1.priority *= vbox1.vol 170 | boxQueue.put((vbox1.priority, vbox1)) 171 | if vbox2 is not None: 172 | ncolors += 1 173 | if vol: 174 | vbox2.priority *= vbox2.vol 175 | boxQueue.put((vbox2.priority, vbox2)) 176 | niters += 1 177 | if niters >= self.MAX_ITERATIONS: 178 | print("infinite loop; perhaps too few pixels!") 179 | break 180 | return boxQueue 181 | def boxAvgColor(self, vbox): 182 | ntot = 0 183 | mult = 1 << self.rshift 184 | rsum = 0 185 | gsum = 0 186 | bsum = 0 187 | for r in range(vbox.r1, vbox.r2+1): 188 | for g in range(vbox.g1, vbox.g2+1): 189 | for b in range(vbox.b1, vbox.b2+1): 190 | h = vbox.histo[self.getColorIndex(r, g, b)] 191 | ntot += h 192 | rsum += int(h * (r + 0.5) * mult) 193 | gsum += int(h * (g + 0.5) * mult) 194 | bsum += int(h * (b + 0.5) * mult) 195 | if ntot == 0: 196 | avgs = map(lambda x: x * mult / 2, [vbox.r1 + vbox.r2 + 1, vbox.g1 + vbox.g2 + 1, vbox.b1 + vbox.b2 + 1]) 197 | else: 198 | avgs = map(lambda x : x / ntot, [rsum, gsum, bsum]) 199 | return list(map(lambda x: int(x), avgs)) 200 | 201 | def quantize(self): 202 | if self.h * self.w < self.maxColor: 203 | raise AttributeError("Image({0}x{1}) too small to be quantized".format(self.w, self.h)) 204 | self.pixHisto = self.getPixHisto() 205 | 206 | orgVbox = self.createVbox(self.pixData) 207 | pOneQueue = PQueue(self.maxColor) 208 | pOneQueue.put((orgVbox.priority, orgVbox)) 209 | popcolors = int(self.maxColor * self.fraction) 210 | 211 | pOneQueue = self.iterCut(popcolors, pOneQueue) 212 | 213 | boxQueue = PQueue(self.maxColor) 214 | while not pOneQueue.empty(): 215 | vbox = pOneQueue.get()[1] 216 | vbox.priority *= vbox.vol 217 | boxQueue.put((vbox.priority, vbox)) 218 | boxQueue = self.iterCut(self.maxColor - popcolors + 1, boxQueue, True) 219 | 220 | theme = [] 221 | while not boxQueue.empty(): 222 | theme.append(self.boxAvgColor(boxQueue.get()[1])) 223 | return theme 224 | -------------------------------------------------------------------------------- /ict/OQ.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class OctNode(object): 4 | """docstring for OctNode""" 5 | isLeaf = False 6 | n = 0 7 | r = 0 8 | g = 0 9 | b = 0 10 | next = None 11 | def __init__(self): 12 | super(OctNode, self).__init__() 13 | self.children = [] 14 | for i in range(8): 15 | self.children.append(None) 16 | 17 | class OQ(object): 18 | """Octree Quantization""" 19 | NCOLORS = 10 20 | def __init__(self, pixData, maxColor): 21 | super(OQ, self).__init__() 22 | self.pixData = pixData 23 | if not 2 <= maxColor <= 256: 24 | raise AttributeError("maxColor should be [2, 256]") 25 | 26 | self.maxColor = maxColor 27 | self.H, self.W, _ = self.pixData.shape 28 | self.leafNum = 0 29 | self.reducible = [] 30 | self.theme = [] 31 | 32 | for i in range(7): 33 | self.reducible.append(None) 34 | 35 | self.octree = OctNode() 36 | def buildOctree(self): 37 | for y in range(self.H): 38 | for x in range(self.W): 39 | pix = self.pixData[y, x, :] 40 | self.addColor(self.octree, pix, 0) 41 | def addColor(self, node, pix, level): 42 | if node.isLeaf: 43 | node.n += 1 44 | node.r += pix[0] 45 | node.g += pix[1] 46 | node.b += pix[2] 47 | else: 48 | rc = ((pix[0] >> (7 - level)) & 0x1) << 2 49 | gc = ((pix[1] >> (7 - level)) & 0x1) << 1 50 | bc = (pix[2] >> (7 - level)) & 0x1 51 | idx = rc | gc | bc 52 | 53 | if node.children[idx] is None: 54 | node.children[idx] = self.createOctNode(level + 1) 55 | self.addColor(node.children[idx], pix, level + 1) 56 | def createOctNode(self, level): 57 | node = OctNode() 58 | if level == 7: 59 | node.isLeaf = True 60 | self.leafNum += 1 61 | else: 62 | node.next = self.reducible[level] 63 | self.reducible[level] = node 64 | return node 65 | def reduceTree(self): 66 | lv = 6 67 | while self.reducible[lv] is None: 68 | lv -= 1 69 | node = self.reducible[lv] 70 | self.reducible[lv] = node.next 71 | 72 | r, g, b, c = (0, 0, 0, 0) 73 | for i in range(8): 74 | child = node.children[i] 75 | if child is None: 76 | continue 77 | r += child.r 78 | g += child.g 79 | b += child.b 80 | c += child.n 81 | self.leafNum -= 1 82 | 83 | node.isLeaf = True 84 | node.r = r 85 | node.g = g 86 | node.b = b 87 | node.n += c 88 | self.leafNum += 1 89 | def getColors(self, node): 90 | if node.isLeaf: 91 | [r, g, b] = list(map(lambda n: int(n[0] / n[1]), zip([node.r, node.g, node.b], [node.n]*3))) 92 | self.theme.append([r,g,b, node.n]) 93 | else: 94 | for i in range(8): 95 | if node.children[i] is not None: 96 | self.getColors(node.children[i]) 97 | def quantize(self): 98 | self.buildOctree() 99 | if self.leafNum <= self.maxColor: 100 | raise AttributeError("Image too small to be quantized!") 101 | while self.leafNum > (self.maxColor + self.NCOLORS): 102 | self.reduceTree() 103 | # print("leafNum = {0}".format(self.leafNum)) 104 | self.getColors(self.octree) 105 | # print(len(self.theme)) 106 | self.theme = sorted(self.theme, key=lambda c: -1*c[1]) 107 | return list(map(lambda l: l[:-1],self.theme[:self.maxColor])) 108 | -------------------------------------------------------------------------------- /ict/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/ict/__init__.py -------------------------------------------------------------------------------- /imgs/avatar_282x282.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/imgs/avatar_282x282.png -------------------------------------------------------------------------------- /imgs/mmcq_vs_kmeans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/imgs/mmcq_vs_kmeans.png -------------------------------------------------------------------------------- /imgs/photo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/imgs/photo1.jpg -------------------------------------------------------------------------------- /imgs/photo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/imgs/photo2.jpg -------------------------------------------------------------------------------- /imgs/photo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/imgs/photo3.jpg -------------------------------------------------------------------------------- /imgs/photo4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainyear/ImageColorTheme/39be9d7c449f1815d5c2f934eab24218779f8196/imgs/photo4.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!../venv3/bin/python 2 | import numpy as np 3 | import cv2 as cv 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.mplot3d import Axes3D 6 | import matplotlib.gridspec as gridspec 7 | 8 | import time 9 | 10 | from ict.MMCQ import MMCQ 11 | from ict.OQ import OQ 12 | from ict.KMeans import KMeans 13 | 14 | def doWhat(): 15 | pixData = getPixData('imgs/avatar_282x282.png') 16 | theme = MMCQ(pixData, 16).quantize() 17 | h, w, _ = pixData.shape 18 | 19 | mask = np.zeros(pixData.shape, dtype=np.uint8) 20 | def dist(a, b): 21 | return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2 + (a[2] - b[2]) ** 2 22 | for y in range(h): 23 | for x in range(w): 24 | p = pixData[y,x,:] 25 | dists = list(map(lambda t: dist(p, t), theme)) 26 | mask[y,x,:] = np.array(theme[dists.index(min(dists))], np.uint8) 27 | plt.subplot(121), plt.imshow(pixData) 28 | plt.subplot(122), plt.imshow(mask) 29 | plt.show() 30 | def imgPixInColorSpace(pixData): 31 | fig = plt.figure() 32 | gs = gridspec.GridSpec(1, 3) 33 | 34 | im = fig.add_subplot(gs[0,0]) 35 | im.imshow(pixData) 36 | im.set_title("2D Image") 37 | 38 | ax = fig.add_subplot(gs[0,1:3], projection='3d') 39 | colors = np.reshape(pixData, (pixData.shape[0] * pixData.shape[1], pixData.shape[2])) 40 | colors = colors / 255.0 41 | ax.scatter(pixData[:,:,0], pixData[:,:,1], pixData[:,:,2], c=colors) 42 | ax.set_xlabel("Red", color='red') 43 | ax.set_ylabel("Green", color='green') 44 | ax.set_zlabel("Blue", color='blue') 45 | 46 | ax.set_title("Image in Color Space") 47 | 48 | ax.set_xlim(0, 255) 49 | ax.set_ylim(0, 255) 50 | ax.set_zlim(0, 255) 51 | 52 | ax.xaxis.set_ticks([]) 53 | ax.yaxis.set_ticks([]) 54 | ax.zaxis.set_ticks([]) 55 | 56 | 57 | plt.show() 58 | 59 | def imgPalette(imgs, themes, titles): 60 | N = len(imgs) 61 | 62 | fig = plt.figure() 63 | gs = gridspec.GridSpec(len(imgs), len(themes)+1) 64 | print(N) 65 | for i in range(N): 66 | im = fig.add_subplot(gs[i, 0]) 67 | im.imshow(imgs[i]) 68 | im.set_title("Image %s" % str(i+1)) 69 | im.xaxis.set_ticks([]) 70 | im.yaxis.set_ticks([]) 71 | 72 | t = 1 73 | for themeLst in themes: 74 | theme = themeLst[i] 75 | pale = np.zeros(imgs[i].shape, dtype=np.uint8) 76 | h, w, _ = pale.shape 77 | ph = h / len(theme) 78 | for y in range(h): 79 | pale[y,:,:] = np.array(theme[int(y / ph)], dtype=np.uint8) 80 | pl = fig.add_subplot(gs[i, t]) 81 | pl.imshow(pale) 82 | pl.set_title(titles[t-1]) 83 | pl.xaxis.set_ticks([]) 84 | pl.yaxis.set_ticks([]) 85 | 86 | t += 1 87 | 88 | plt.show() 89 | 90 | def getPixData(imgfile='imgs/avatar_282x282.png'): 91 | return cv.cvtColor(cv.imread(imgfile, 1), cv.COLOR_BGR2RGB) 92 | 93 | def testColorSpace(): 94 | imgfile = 'imgs/avatar_282x282.png' 95 | pixData = getPixData(imgfile) 96 | imgPixInColorSpace(cv.resize(pixData, None, fx=0.2, fy=0.2)) 97 | 98 | def testMMCQ(pixDatas, maxColor): 99 | start = time.process_time() 100 | themes = list(map(lambda d: MMCQ(d, maxColor).quantize(), pixDatas)) 101 | print("MMCQ Time cost: {0}".format(time.process_time() - start)) 102 | return themes 103 | # imgPalette(pixDatas, themes) 104 | def testOQ(pixDatas, maxColor): 105 | start = time.process_time() 106 | themes = list(map(lambda d: OQ(d, maxColor).quantize(), pixDatas)) 107 | print("OQ Time cost: {0}".format(time.process_time() - start)) 108 | return themes 109 | # imgPalette(pixDatas, themes) 110 | def testKmeans(pixDatas, maxColor, skl=True): 111 | start = time.process_time() 112 | themes = list(map(lambda d: KMeans(d, maxColor, skl).quantize(), pixDatas)) 113 | print("KMeans Time cost: {0}".format(time.process_time() - start)) 114 | return themes 115 | def vs(): 116 | imgs = map(lambda i: 'imgs/photo%s.jpg' % i, range(1,5)) 117 | pixDatas = list(map(getPixData, imgs)) 118 | maxColor = 7 119 | themes = [testMMCQ(pixDatas, maxColor), testOQ(pixDatas, maxColor), testKmeans(pixDatas, maxColor)] 120 | imgPalette(pixDatas, themes, ["MMCQ Palette", "OQ Palette", "KMeans Palette"]) 121 | 122 | def kmvs(): 123 | imgs = map(lambda i: 'imgs/photo%s.jpg' % i, range(1,5)) 124 | pixDatas = list(map(getPixData, imgs)) 125 | maxColor = 7 126 | themes = [testKmeans(pixDatas, maxColor), testKmeans(pixDatas, maxColor, False)] 127 | imgPalette(pixDatas, themes, ["KMeans Palette", "KMeans DIY"]) 128 | 129 | 130 | if __name__ == '__main__': 131 | # testColorSpace() 132 | # testMMCQ() 133 | # kmvs() 134 | print(testKmeans([getPixData()], 7, False)) 135 | print(testKmeans([getPixData()], 7)) 136 | --------------------------------------------------------------------------------