├── .gitignore ├── Images ├── comp_results.png └── wavelet_decomp.png ├── README.md ├── WaveletImageCoder.py ├── article.ipynb ├── compress.py ├── decompress.py ├── dog.png ├── dog.ztc ├── utils.py └── zerotree.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/* 2 | __pycache__/* 3 | zerotree_test.txt 4 | /*.png 5 | /*.jpg 6 | /*.ztc 7 | *.ipynb -------------------------------------------------------------------------------- /Images/comp_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparande/EZWImageCompression/47dc0c8e515697e496ebc6763605cc79c5b91c68/Images/comp_results.png -------------------------------------------------------------------------------- /Images/wavelet_decomp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparande/EZWImageCompression/47dc0c8e515697e496ebc6763605cc79c5b91c68/Images/wavelet_decomp.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EZWImageCompression 2 | 3 | A Python implementation of an end-to-end image compression system using the Embedded Zero-Trees of Wavelet Transforms algorithm 4 | 5 | You can read about it on [Medium](https://medium.com/swlh/end-to-end-image-compression-using-embedded-zero-trees-of-wavelet-transforms-ezw-2362f2a965f7) 6 | 7 | ![Results](Images/comp_results.png) 8 | 9 | ## Usage 10 | ### Compress an image 11 | ``` 12 | usage: compress.py [-h] [--output OUTPUT] [--max-passes MAX_PASSES] file 13 | 14 | Compress a file to the ZTC format 15 | 16 | positional arguments: 17 | file File to compress 18 | 19 | optional arguments: 20 | -h, --help show this help message and exit 21 | --output OUTPUT Output filename 22 | --max-passes MAX_PASSES 23 | ``` 24 | ### Decompress an image 25 | ``` 26 | usage: decompress.py [-h] [--output OUTPUT] file 27 | 28 | Decompress a file from the ZTC format 29 | 30 | positional arguments: 31 | file File to decompress 32 | 33 | optional arguments: 34 | -h, --help show this help message and exit 35 | --output OUTPUT Output filename 36 | ``` 37 | ## Credits 38 | Dog photo by [Lilian Joore](https://unsplash.com/@lilian66?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText) on [Unsplash](https://unsplash.com/s/photos/dog?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText) 39 | -------------------------------------------------------------------------------- /WaveletImageCoder.py: -------------------------------------------------------------------------------- 1 | from bitarray import bitarray 2 | import numpy as np 3 | import utils 4 | from zerotree import ZeroTreeDecoder, ZeroTreeEncoder, ZeroTreeScan 5 | import pywt 6 | 7 | SOI_MARKER = bytes.fromhex("FFD8") # Start of Image 8 | SOS_MARKER = bytes.fromhex("FFDA") # Start of Scan 9 | EOI_MARKER = bytes.fromhex("FFDC") # End of Image 10 | STUFFED_MARKER = bytes.fromhex("FF00") 11 | 12 | WAVELET = "db2" 13 | 14 | class WaveletImageEncoder(): 15 | def __init__(self, max_passes): 16 | self.max_passes = max_passes 17 | 18 | def encode(self, image, filename): 19 | M, N = image.shape[:2] 20 | 21 | with open(filename, 'wb') as fh: 22 | # Write the header 23 | fh.write(SOI_MARKER) 24 | 25 | fh.write(M.to_bytes(2, "big")) 26 | fh.write(N.to_bytes(2, "big")) 27 | 28 | image = image.astype(np.float64) 29 | 30 | encoders = self.build_encoders(image) 31 | for enc in encoders: 32 | fh.write(int(enc.start_thresh).to_bytes(2, 'big')) 33 | 34 | encoders = [iter(enc) for enc in encoders] 35 | 36 | i = 0 37 | writes = float('inf') 38 | 39 | while writes != 0 and i < self.max_passes: 40 | writes = 0 41 | for enc_iter in encoders: 42 | fh.write(SOS_MARKER) 43 | scan = next(enc_iter, None) 44 | if scan is not None: 45 | scan.tofile(fh) 46 | writes += 1 47 | i += 1 48 | 49 | fh.write(EOI_MARKER) 50 | 51 | def build_encoders(self, image): 52 | ycbcr = utils.RGB2YCbCr(image) 53 | encoders = [] 54 | M, N = image.shape[:2] 55 | for i in range(3): 56 | channel = ycbcr[:, :, i] if i == 0 else utils.resize(ycbcr[:, :, i], M // 2, N // 2) 57 | encoders.append(ZeroTreeEncoder(channel, WAVELET)) 58 | 59 | return encoders 60 | 61 | class WaveletImageDecoder(): 62 | def decode(self, filename): 63 | with open(filename, 'rb') as fh: 64 | soi = fh.read(2) 65 | if soi != SOI_MARKER: 66 | raise Exception("Start of Image marker not found!") 67 | 68 | M = int.from_bytes(fh.read(2), "big") 69 | N = int.from_bytes(fh.read(2), "big") 70 | 71 | thresholds = [int.from_bytes(fh.read(2), 'big') for _ in range(3)] 72 | decoders = self.build_decoders(M, N, thresholds) 73 | 74 | cursor = fh.read(2) 75 | if cursor != SOS_MARKER: 76 | raise Exception("Scan's not found!") 77 | 78 | isDominant = True 79 | while cursor != EOI_MARKER: 80 | for i, dec in enumerate(decoders): 81 | buffer = bytes() 82 | while len(buffer) < 2 or (buffer[-2:] != SOS_MARKER and not (buffer[-2:] == EOI_MARKER and i == 2)): 83 | buffer += fh.read(1) 84 | 85 | buffer, cursor = buffer[:-2], buffer[-2:] 86 | buffer = buffer.replace(STUFFED_MARKER, b'\xff') 87 | 88 | ba = bitarray() 89 | ba.frombytes(buffer) 90 | 91 | if len(ba) != 0: 92 | scan = ZeroTreeScan.from_bits(ba, isDominant) 93 | dec.process(scan) 94 | 95 | isDominant = not isDominant 96 | 97 | image = np.zeros((M, N, 3)) 98 | for i, dec in enumerate(decoders): 99 | image[:, :, i] = dec.getImage() if i == 0 else utils.resize(dec.getImage(), M, N) 100 | 101 | return utils.YCbCr2RGB(image).astype('uint8') 102 | 103 | def build_decoders(self, M, N, thresholds): 104 | decoders = [] 105 | for i in range(3): 106 | max_thresh = thresholds[i] 107 | if i == 0: 108 | decoders.append(ZeroTreeDecoder(M, N, max_thresh, WAVELET)) 109 | else: 110 | decoders.append(ZeroTreeDecoder(M // 2, N // 2, max_thresh, WAVELET)) 111 | return decoders -------------------------------------------------------------------------------- /compress.py: -------------------------------------------------------------------------------- 1 | from WaveletImageCoder import WaveletImageEncoder 2 | import argparse 3 | from PIL import Image 4 | import numpy as np 5 | from utils import psnr, comp_ratio, bpp 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser(description="Compress a file to the ZTC format") 9 | parser.add_argument('file', type=str, help="File to compress") 10 | parser.add_argument('--output', type=str, default="output.ztc", help="Output filename") 11 | parser.add_argument("--max-passes", type=float, default=float('inf'), help="Maximum number of Image Packets") 12 | args = parser.parse_args() 13 | 14 | print("Compressing Image") 15 | 16 | img = np.array(Image.open(args.file)) 17 | encoder = WaveletImageEncoder(args.max_passes) 18 | encoder.encode(img, args.output) 19 | 20 | print(f"Saved output to {args.output}") 21 | print(f"BPP: {bpp(args.output)}") 22 | print(f"Compression Ratio: {comp_ratio(args.file, args.output)}") 23 | -------------------------------------------------------------------------------- /decompress.py: -------------------------------------------------------------------------------- 1 | from WaveletImageCoder import WaveletImageDecoder 2 | import argparse 3 | from PIL import Image 4 | import numpy as np 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description="Decompress a file from the ZTC format") 8 | parser.add_argument('file', type=str, help="File to decompress") 9 | parser.add_argument('--output', type=str, default="output.png", help="Output filename") 10 | args = parser.parse_args() 11 | 12 | decoder = WaveletImageDecoder() 13 | image = decoder.decode(args.file) 14 | Image.fromarray(image).save(args.output) 15 | -------------------------------------------------------------------------------- /dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparande/EZWImageCompression/47dc0c8e515697e496ebc6763605cc79c5b91c68/dog.png -------------------------------------------------------------------------------- /dog.ztc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparande/EZWImageCompression/47dc0c8e515697e496ebc6763605cc79c5b91c68/dog.ztc -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import numpy as np 4 | from os import stat 5 | from bitarray import bitarray 6 | 7 | SOI_MARKER = bytes.fromhex("FFD8") 8 | 9 | def bytestuff(bits): 10 | marker = bitarray('11111111') 11 | zeros = bitarray('00000000') 12 | 13 | stuffed_arr = bitarray() 14 | 15 | idx = 0 16 | while idx < len(bits): 17 | cursor = bits[idx:idx + 8] 18 | stuffed_arr.extend(cursor) 19 | if cursor == marker: 20 | stuffed_arr.extend(zeros) 21 | idx += 8 22 | 23 | return stuffed_arr 24 | 25 | 26 | # Compute video PSNR 27 | def psnr(ref, meas, maxVal=255): 28 | assert np.shape(ref) == np.shape(meas), "Reference image must match measured image dimensions" 29 | 30 | dif = (ref.astype(float)-meas.astype(float)).ravel() 31 | mse = np.linalg.norm(dif)**2/np.prod(np.shape(ref)) 32 | psnr = 10*np.log10(maxVal**2.0/mse) 33 | return psnr 34 | 35 | def bpp(filename): 36 | size = stat(filename).st_size 37 | with open(filename, 'rb') as fh: 38 | soi = fh.read(2) 39 | if soi != SOI_MARKER: 40 | raise Exception("Start of Image marker not found!") 41 | 42 | M = int.from_bytes(fh.read(2), "big") 43 | N = int.from_bytes(fh.read(2), "big") 44 | 45 | return size * 8 / (M * N) 46 | 47 | def comp_ratio(reference, measured): 48 | return stat(reference).st_size / stat(measured).st_size 49 | 50 | 51 | def resize(img, M, N): 52 | return np.array(Image.fromarray(img).resize((N, M), resample=Image.BILINEAR)) 53 | 54 | CONV_MAT = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.4188688, -0.081312]]).T 55 | INV_CONV_MAT = np.linalg.inv(CONV_MAT) 56 | 57 | def RGB2YCbCr(im_rgb): 58 | """ 59 | Input: a 3D float array, im_rgb, representing an RGB image in range [0.0,255.0] 60 | 61 | Output: a 3D float array, im_ycbcr, representing a YCbCr image in range [-128.0,127.0] 62 | """ 63 | 64 | im_ycbcr = np.array([-128, 0, 0]) + im_rgb @ CONV_MAT 65 | im_ycbcr = np.where(im_ycbcr > 127, 127, im_ycbcr) 66 | im_ycbcr = np.where(im_ycbcr < -128, -128, im_ycbcr) 67 | 68 | return im_ycbcr 69 | 70 | def YCbCr2RGB(im_ycbcr): 71 | """ 72 | Input: a 3D float array, im_ycbcr, representing a YCbCr image in range [-128.0,127.0] 73 | 74 | Output: a 3D float array, im_rgb, representing an RGB image in range [0.0,255.0] 75 | """ 76 | 77 | im_rgb = (np.array([128, 0, 0]) + im_ycbcr) @ INV_CONV_MAT 78 | im_rgb = np.where(im_rgb > 255, 255, im_rgb) 79 | im_rgb = np.where(im_rgb < 0, 0, im_rgb) 80 | return im_rgb -------------------------------------------------------------------------------- /zerotree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pywt 3 | from bitarray import bitarray 4 | from utils import bytestuff 5 | 6 | PREFIX_FREE_CODE = { 7 | "T": bitarray('0'), 8 | "Z": bitarray('10'), 9 | "P": bitarray('110'), 10 | "N": bitarray('111') 11 | } 12 | class CoefficientTree: 13 | def __init__(self, value, level, quadrant, loc, children=[]): 14 | self.value = value 15 | self.level = level 16 | self.quadrant = quadrant 17 | self.children = children 18 | self.loc = loc 19 | self.code = None 20 | 21 | def zero_code(self, threshold): 22 | for child in self.children: 23 | child.zero_code(threshold) 24 | 25 | if abs(self.value) >= threshold: 26 | self.code = "P" if self.value > 0 else "N" 27 | else: 28 | self.code = "Z" if any([child.code != "T" for child in self.children]) else "T" 29 | 30 | 31 | @staticmethod 32 | def build_trees(coeffs): 33 | def build_children(level, loc, quadrant): 34 | if level + 1 > len(coeffs): return [] 35 | 36 | i, j = loc 37 | child_locs = [(2*i, 2*j), (2*i, 2*j + 1), (2*i + 1, 2*j), (2*i + 1, 2*j + 1)] 38 | children = [] 39 | for cloc in child_locs: 40 | if cloc[0] >= coeffs[level][quadrant].shape[0] or cloc[1] >= coeffs[level][quadrant].shape[1]: 41 | continue 42 | node = CoefficientTree(coeffs[level][quadrant][cloc], level, quadrant, cloc) 43 | node.children = build_children(level + 1, cloc, quadrant) 44 | children.append(node) 45 | return children 46 | 47 | LL = coeffs[0] 48 | 49 | LL_trees = [] 50 | for i in range(LL.shape[0]): 51 | for j in range(LL.shape[1]): 52 | children = [CoefficientTree(subband[i, j], 1, quad, (i,j), children=build_children(2, (i,j), quad)) 53 | for quad, subband in enumerate(coeffs[1])] 54 | 55 | LL_trees.append(CoefficientTree(LL[i,j], 0, None, (i,j), children=children)) 56 | 57 | return LL_trees 58 | 59 | class ZeroTreeScan(): 60 | def __init__(self, code, isDominant): 61 | self.isDominant = isDominant 62 | self.code = code 63 | self.bits = code if not isDominant else self.code_bits(code) 64 | 65 | def __len__(self): 66 | return len(self.bits) 67 | 68 | def tofile(self, file, padto=16): 69 | bits = self.bits.copy() 70 | 71 | if padto != 0 and len(bits) % padto != 0: 72 | bits.extend([False for _ in range(padto - (len(bits) % padto))]) 73 | 74 | bits = bytestuff(bits) 75 | bits.tofile(file) 76 | 77 | def code_bits(self, code): 78 | bitarr = bitarray() 79 | bitarr.encode(PREFIX_FREE_CODE, code) 80 | return bitarr 81 | 82 | @staticmethod 83 | def from_bits(bits, isDominant): 84 | code = bits.decode(PREFIX_FREE_CODE) if isDominant else bits 85 | return ZeroTreeScan(code, isDominant) 86 | 87 | class ZeroTreeEncoder: 88 | def __init__(self, image, wavelet): 89 | coeffs = pywt.wavedec2(image, wavelet) 90 | coeff_arr, slices = pywt.coeffs_to_array(coeffs) 91 | coeff_arr = np.sign(coeff_arr) * np.floor(np.abs(coeff_arr)) 92 | 93 | coeffs = pywt.array_to_coeffs(coeff_arr, slices, output_format='wavedec2') 94 | 95 | self.trees = CoefficientTree.build_trees(coeffs) 96 | 97 | self.thresh = np.power(2, np.floor(np.log2(np.max(np.abs(coeff_arr))))) 98 | self.start_thresh = self.thresh 99 | 100 | self.secondary_list = [] 101 | self.perform_dominant_pass = True 102 | 103 | def __iter__(self): 104 | return self 105 | 106 | def __next__(self): 107 | if self.thresh <= 0: raise StopIteration 108 | if self.thresh <= 1 and not self.perform_dominant_pass: raise StopIteration 109 | 110 | if self.perform_dominant_pass: 111 | scan, next_coeffs = self.dominant_pass() 112 | 113 | self.secondary_list = np.concatenate((self.secondary_list, next_coeffs)) 114 | 115 | self.perform_dominant_pass = False 116 | return scan 117 | else: 118 | scan = self.secondary_pass() 119 | self.thresh //= 2 120 | self.perform_dominant_pass = True 121 | return scan 122 | 123 | def dominant_pass(self): 124 | sec = [] 125 | 126 | q = [] 127 | for parent in self.trees: 128 | parent.zero_code(self.thresh) 129 | q.append(parent) 130 | 131 | codes = [] 132 | while len(q) != 0: 133 | node = q.pop(0) 134 | codes.append(node.code) 135 | 136 | if node.code != "T": 137 | for child in node.children: 138 | q.append(child) 139 | 140 | if node.code == "P" or node.code == "N": 141 | sec.append(node.value) 142 | node.value = 0 143 | 144 | return ZeroTreeScan(codes, True), np.abs(np.array(sec)) 145 | 146 | def secondary_pass(self): 147 | bits = bitarray() 148 | 149 | middle = self.thresh // 2 150 | for i, coeff in enumerate(self.secondary_list): 151 | if coeff - self.thresh >= 0: 152 | self.secondary_list[i] -= self.thresh 153 | bits.append(self.secondary_list[i] >= middle) 154 | 155 | return ZeroTreeScan(bits, False) 156 | 157 | class ZeroTreeDecoder: 158 | def __init__(self, M, N, start_thres, wavelet): 159 | img = np.zeros((M, N)) 160 | self.wavelet = wavelet 161 | self.coeffs = pywt.wavedec2(img, wavelet) 162 | self.trees = CoefficientTree.build_trees(self.coeffs) 163 | self.T = start_thres 164 | self.processed = [] 165 | 166 | def getImage(self): 167 | return pywt.waverec2(self.coeffs, self.wavelet) 168 | 169 | def process(self, scan): 170 | if scan.isDominant: 171 | self.dominant_pass(scan.code) 172 | else: 173 | self.secondary_pass(scan.code) 174 | 175 | def dominant_pass(self, code_list): 176 | q = [] 177 | for parent in self.trees: 178 | q.append(parent) 179 | 180 | for code in code_list: 181 | if len(q) == 0: 182 | break 183 | node = q.pop(0) 184 | if code != "T": 185 | for child in node.children: 186 | q.append(child) 187 | if code == "P" or code == "N": 188 | node.value = (1 if code == "P" else -1) * self.T 189 | self._fill_coeff(node) 190 | self.processed.append(node) 191 | 192 | def secondary_pass(self, bitarr): 193 | if len(bitarr) != len(self.processed): 194 | bitarr = bitarr[:len(self.processed)] 195 | for bit, node in zip(bitarr, self.processed): 196 | if bit: 197 | node.value += (1 if node.value > 0 else -1) * self.T // 2 198 | self._fill_coeff(node) 199 | 200 | self.T //= 2 201 | 202 | def _fill_coeff(self, node): 203 | if node.quadrant is not None: 204 | self.coeffs[node.level][node.quadrant][node.loc] = node.value 205 | else: 206 | self.coeffs[node.level][node.loc] = node.value --------------------------------------------------------------------------------