# ├── .gitignore
# ├── decoder.py
# ├── encoder.py
# ├── huffman.py
# └── utils.py
#
# /.gitignore:
# --------------------------------------------------------------------------------
# __pycache__/
# data/
# *.pyc
# --------------------------------------------------------------------------------
# /decoder.py:
# --------------------------------------------------------------------------------
import argparse
import math
import numpy as np
from utils import *
from scipy import fftpack
from PIL import Image


class JPEGFileReader:
    """Sequential reader for the bit-string files written by encoder.py.

    encoder.py writes every bit as a '0' or '1' character, so one "bit" here
    is one character of the text file.  Use the reader as a context manager
    (or call close()) so the underlying file handle is released.
    """

    TABLE_SIZE_BITS = 16
    BLOCKS_COUNT_BITS = 32

    DC_CODE_LENGTH_BITS = 4
    CATEGORY_BITS = 4

    AC_CODE_LENGTH_BITS = 8
    RUN_LENGTH_BITS = 4
    SIZE_BITS = 4

    def __init__(self, filepath):
        self.__file = open(filepath, 'r')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        self.__file.close()

    def read_int(self, size):
        """Read a signed integer stored in ``size`` bits.

        A leading '1' means the value is positive and stored as plain binary;
        a leading '0' means it is negative and stored with every bit flipped.
        ``size == 0`` encodes the value 0 and consumes no bits.
        """
        if size == 0:
            return 0

        # the most significant bit indicates the sign of the number
        bin_num = self.__read_str(size)
        if bin_num[0] == '1':
            return self.__int2(bin_num)
        else:
            return self.__int2(binstr_flip(bin_num)) * -1

    def read_dc_table(self):
        """Read a DC Huffman table and return it as ``{code: category}``."""
        table = dict()

        table_size = self.__read_uint(self.TABLE_SIZE_BITS)
        for _ in range(table_size):
            category = self.__read_uint(self.CATEGORY_BITS)
            code_length = self.__read_uint(self.DC_CODE_LENGTH_BITS)
            code = self.__read_str(code_length)
            table[code] = category
        return table

    def read_ac_table(self):
        """Read an AC Huffman table as ``{code: (run_length, size)}``."""
        table = dict()

        table_size = self.__read_uint(self.TABLE_SIZE_BITS)
        for _ in range(table_size):
            run_length = self.__read_uint(self.RUN_LENGTH_BITS)
            size = self.__read_uint(self.SIZE_BITS)
            code_length = self.__read_uint(self.AC_CODE_LENGTH_BITS)
            code = self.__read_str(code_length)
            table[code] = (run_length, size)
        return table

    def read_blocks_count(self):
        """Read the number of 8x8 blocks stored in the file."""
        return self.__read_uint(self.BLOCKS_COUNT_BITS)

    def read_huffman_code(self, table):
        """Read bits until they form a code in ``table``; return its value.

        Raises:
            EOFError: if the file ends before a complete code has been read
                (raised by the private string reader), so this loop can no
                longer spin forever on an exhausted file.
        """
        prefix = ''
        while prefix not in table:
            prefix += self.__read_char()
        return table[prefix]

    def __read_uint(self, size):
        if size <= 0:
            raise ValueError("size of unsigned int should be greater than 0")
        return self.__int2(self.__read_str(size))

    def __read_str(self, length):
        data = self.__file.read(length)
        # file.read() returns a short (possibly empty) string at end of file;
        # treat that as truncated input instead of decoding garbage
        if len(data) < length:
            raise EOFError(
                "unexpected end of file: expected {} bits, got {}".format(
                    length, len(data)))
        return data

    def __read_char(self):
        return self.__read_str(1)

    def __int2(self, bin_num):
        return int(bin_num, 2)


def read_image_file(filepath):
    """Parse a file written by encoder.py.

    Returns:
        (dc, ac, tables, blocks_count) where ``dc`` has shape
        (blocks_count, 3), ``ac`` has shape (blocks_count, 63, 3) (both
        int32, last axis = Y/Cb/Cr component) and ``tables`` maps
        'dc_y'/'ac_y'/'dc_c'/'ac_c' to the decoded Huffman tables.

    Raises:
        EOFError: if the file is truncated.
        ValueError: if a run length points past the end of a block.
    """
    with JPEGFileReader(filepath) as reader:
        tables = dict()
        for table_name in ['dc_y', 'ac_y', 'dc_c', 'ac_c']:
            if 'dc' in table_name:
                tables[table_name] = reader.read_dc_table()
            else:
                tables[table_name] = reader.read_ac_table()

        blocks_count = reader.read_blocks_count()

        dc = np.empty((blocks_count, 3), dtype=np.int32)
        # zero-initialised: cells skipped by run lengths or by end-of-block
        # need no explicit write
        ac = np.zeros((blocks_count, 63, 3), dtype=np.int32)

        for block_index in range(blocks_count):
            for component in range(3):
                dc_table = tables['dc_y'] if component == 0 else tables['dc_c']
                ac_table = tables['ac_y'] if component == 0 else tables['ac_c']

                category = reader.read_huffman_code(dc_table)
                dc[block_index, component] = reader.read_int(category)

                cells_count = 0
                while cells_count < 63:
                    run_length, size = reader.read_huffman_code(ac_table)

                    if (run_length, size) == (0, 0):
                        # end-of-block: every remaining cell stays zero
                        break

                    # skip over the run of zeros preceding this coefficient
                    cells_count += run_length
                    if cells_count >= 63:
                        raise ValueError(
                            "corrupt AC data: run length past end of block")

                    # the encoder only emits size 0 as (15, 0), i.e. the
                    # current cell is a 16th zero, which is already in place
                    if size != 0:
                        ac[block_index, cells_count, component] = \
                            reader.read_int(size)
                    cells_count += 1

    return dc, ac, tables, blocks_count


def zigzag_to_block(zigzag):
    """Inverse of the encoder's zigzag flattening: rebuild a square block."""
    # assuming that the width and the height of the block are equal
    rows = cols = int(math.sqrt(len(zigzag)))

    if rows * cols != len(zigzag):
        raise ValueError("length of zigzag should be a perfect square")

    block = np.empty((rows, cols), np.int32)

    for i, point in enumerate(zigzag_points(rows, cols)):
        block[point] = zigzag[i]

    return block


def dequantize(block, component):
    """Multiply a quantized block by the 'lum' or 'chrom' table."""
    q = load_quantization_table(component)
    return block * q


def idct_2d(image):
    """Apply an orthonormal inverse DCT along both axes of ``image``."""
    return fftpack.idct(fftpack.idct(image.T, norm='ortho').T, norm='ortho')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="path to the input image")
    args = parser.parse_args()

    dc, ac, tables, blocks_count = read_image_file(args.input)

    # assuming that the block is a 8x8 square
    block_side = 8

    # NOTE: the file format stores only the block count, not the image
    # dimensions, so the image is assumed to be square.  Non-square images
    # (which encoder.py accepts) cannot be reconstructed here.
    image_side = int(math.sqrt(blocks_count)) * block_side

    blocks_per_line = image_side // block_side

    npmat = np.empty((image_side, image_side, 3), dtype=np.uint8)

    for block_index in range(blocks_count):
        i = block_index // blocks_per_line * block_side
        j = block_index % blocks_per_line * block_side

        for c in range(3):
            zigzag = [dc[block_index, c]] + list(ac[block_index, :, c])
            quant_matrix = zigzag_to_block(zigzag)
            dct_matrix = dequantize(quant_matrix, 'lum' if c == 0 else 'chrom')
            block = idct_2d(dct_matrix)
            # lossy IDCT output can land slightly outside [0, 255]; casting
            # such floats to uint8 is platform-dependent, so round and clip
            npmat[i:i+8, j:j+8, c] = np.clip(np.round(block + 128), 0, 255)

    image = Image.fromarray(npmat, 'YCbCr')
    image = image.convert('RGB')
    image.show()


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------
# /encoder.py:
# --------------------------------------------------------------------------------
import argparse
import os
import math
import numpy as np
from utils import *
from scipy import fftpack
from PIL import Image
from huffman import HuffmanTree


def quantize(block, component):
    """Divide a DCT block by the 'lum'/'chrom' table and round to int32."""
    q = load_quantization_table(component)
    return (block / q).round().astype(np.int32)


def block_to_zigzag(block):
    """Flatten a 2-D block into a 1-D array in zigzag order."""
    return np.array([block[point] for point in zigzag_points(*block.shape)])


def dct_2d(image):
    """Apply an orthonormal DCT along both axes of ``image``."""
    return fftpack.dct(fftpack.dct(image.T, norm='ortho').T, norm='ortho')


def run_length_encode(arr):
    """Run-length encode the AC coefficients of one block component.

    Returns ``(symbols, values)``:
      * symbols -- ``(run_length, size)`` tuples; run_length counts the zeros
        skipped before the coefficient and size is ``bits_required`` of it.
        ``(15, 0)`` stands for sixteen zeros.  A ``(0, 0)`` end-of-block
        symbol is appended when the remaining coefficients are all zero; it
        is omitted when the last coefficient is non-zero.
      * values -- ``int_to_binstr`` of each coefficient ('' for zeros).
    """
    # determine where the sequence is ending prematurely
    last_nonzero = -1
    for i, elem in enumerate(arr):
        if elem != 0:
            last_nonzero = i

    # each symbol is a (RUNLENGTH, SIZE) tuple
    symbols = []

    # values are binary representations of array elements using SIZE bits
    values = []

    run_length = 0

    for i, elem in enumerate(arr):
        if i > last_nonzero:
            symbols.append((0, 0))
            values.append(int_to_binstr(0))
            break
        elif elem == 0 and run_length < 15:
            run_length += 1
        else:
            size = bits_required(elem)
            symbols.append((run_length, size))
            values.append(int_to_binstr(elem))
            run_length = 0
    return symbols, values


def write_to_file(filepath, dc, ac, blocks_count, tables):
    """Write the Huffman tables and entropy-coded blocks as '0'/'1' text.

    Raises:
        FileNotFoundError: if the directory of ``filepath`` does not exist.
    """
    try:
        f = open(filepath, 'w')
    except FileNotFoundError as e:
        raise FileNotFoundError(
            "No such directory: {}".format(
                os.path.dirname(filepath))) from e

    # `with` guarantees the file is closed even if encoding fails midway
    with f:
        for table_name in ['dc_y', 'ac_y', 'dc_c', 'ac_c']:

            # 16 bits for 'table_size'
            f.write(uint_to_binstr(len(tables[table_name]), 16))

            for key, value in tables[table_name].items():
                if table_name in {'dc_y', 'dc_c'}:
                    # 4 bits for the 'category'
                    # 4 bits for 'code_length'
                    # 'code_length' bits for 'huffman_code'
                    f.write(uint_to_binstr(key, 4))
                    f.write(uint_to_binstr(len(value), 4))
                    f.write(value)
                else:
                    # 4 bits for 'run_length'
                    # 4 bits for 'size'
                    # 8 bits for 'code_length'
                    # 'code_length' bits for 'huffman_code'
                    f.write(uint_to_binstr(key[0], 4))
                    f.write(uint_to_binstr(key[1], 4))
                    f.write(uint_to_binstr(len(value), 8))
                    f.write(value)

        # 32 bits for 'blocks_count'
        f.write(uint_to_binstr(blocks_count, 32))

        for b in range(blocks_count):
            for c in range(3):
                category = bits_required(dc[b, c])
                symbols, values = run_length_encode(ac[b, :, c])

                dc_table = tables['dc_y'] if c == 0 else tables['dc_c']
                ac_table = tables['ac_y'] if c == 0 else tables['ac_c']

                f.write(dc_table[category])
                f.write(int_to_binstr(dc[b, c]))

                for i in range(len(symbols)):
                    f.write(ac_table[tuple(symbols[i])])
                    f.write(values[i])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="path to the input image")
    parser.add_argument("output", help="path to the output image")
    args = parser.parse_args()

    input_file = args.input
    output_file = args.output

    image = Image.open(input_file)
    ycbcr = image.convert('YCbCr')

    # widen to int32: uint8 arithmetic wraps modulo 256, so `pixel - 128`
    # below would turn e.g. 0 into 128 instead of -128
    npmat = np.array(ycbcr, dtype=np.uint8).astype(np.int32)

    rows, cols = npmat.shape[0], npmat.shape[1]

    # block size: 8x8
    if rows % 8 == cols % 8 == 0:
        blocks_count = (rows // 8) * (cols // 8)
    else:
        raise ValueError(("the width and height of the image "
                          "should both be mutiples of 8"))

    # dc is the top-left cell of the block, ac are all the other cells
    dc = np.empty((blocks_count, 3), dtype=np.int32)
    ac = np.empty((blocks_count, 63, 3), dtype=np.int32)

    # blocks are numbered in row-major order
    block_index = 0
    for i in range(0, rows, 8):
        for j in range(0, cols, 8):
            for k in range(3):
                # split 8x8 block and center the data range on zero
                # [0, 255] --> [-128, 127]
                block = npmat[i:i+8, j:j+8, k] - 128

                dct_matrix = dct_2d(block)
                quant_matrix = quantize(dct_matrix,
                                        'lum' if k == 0 else 'chrom')
                zz = block_to_zigzag(quant_matrix)

                dc[block_index, k] = zz[0]
                ac[block_index, :, k] = zz[1:]
            block_index += 1

    H_DC_Y = HuffmanTree(np.vectorize(bits_required)(dc[:, 0]))
    H_DC_C = HuffmanTree(np.vectorize(bits_required)(dc[:, 1:].flat))
    H_AC_Y = HuffmanTree(
        flatten(run_length_encode(ac[i, :, 0])[0]
                for i in range(blocks_count)))
    H_AC_C = HuffmanTree(
        flatten(run_length_encode(ac[i, :, j])[0]
                for i in range(blocks_count) for j in [1, 2]))

    tables = {'dc_y': H_DC_Y.value_to_bitstring_table(),
              'ac_y': H_AC_Y.value_to_bitstring_table(),
              'dc_c': H_DC_C.value_to_bitstring_table(),
              'ac_c': H_AC_C.value_to_bitstring_table()}

    write_to_file(output_file, dc, ac, blocks_count, tables)


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------
# /huffman.py:
# --------------------------------------------------------------------------------
from queue import PriorityQueue


class HuffmanTree:
    """Huffman code built from the frequencies of the elements of ``arr``.

    ``value_to_bitstring_table()`` returns ``{value: code}`` where each code
    is a string of '0'/'1' characters ('0' = left child, '1' = right child).
    When ``arr`` holds a single distinct value its code is the empty string.

    Raises:
        ValueError: if ``arr`` is empty.
    """

    class __Node:
        def __init__(self, value, freq, left_child, right_child):
            self.value = value
            self.freq = freq
            self.left_child = left_child
            self.right_child = right_child

        @classmethod
        def init_leaf(cls, value, freq):
            return cls(value, freq, None, None)

        @classmethod
        def init_node(cls, left_child, right_child):
            freq = left_child.freq + right_child.freq
            return cls(None, freq, left_child, right_child)

        def is_leaf(self):
            # internal nodes are created with value=None by init_node
            return self.value is not None

        def __eq__(self, other):
            stup = self.value, self.freq, self.left_child, self.right_child
            otup = other.value, other.freq, other.left_child, other.right_child
            return stup == otup

        # `__ne__` (not `__nq__`) is the method Python calls for `!=`
        def __ne__(self, other):
            return not (self == other)

        # ordering looks at frequency only, which is all the queue needs
        def __lt__(self, other):
            return self.freq < other.freq

        def __le__(self, other):
            return self.freq <= other.freq

        def __gt__(self, other):
            return not (self <= other)

        def __ge__(self, other):
            return not (self < other)

    def __init__(self, arr):
        freqs = self.__calc_freq(arr)
        if not freqs:
            # PriorityQueue.get() would otherwise block forever
            raise ValueError("cannot build a Huffman tree from an empty sequence")

        q = PriorityQueue()

        # insert one leaf per distinct value into a priority queue
        for val, freq in freqs.items():
            q.put(self.__Node.init_leaf(val, freq))

        # repeatedly merge the two least frequent subtrees
        while q.qsize() >= 2:
            u = q.get()
            v = q.get()

            q.put(self.__Node.init_node(u, v))

        self.__root = q.get()

        # huffman table, filled lazily by value_to_bitstring_table()
        self.__value_to_bitstring = dict()

    def value_to_bitstring_table(self):
        """Return the ``{value: bitstring}`` table, building it on first use."""
        if not self.__value_to_bitstring:
            self.__create_huffman_table()
        return self.__value_to_bitstring

    def __create_huffman_table(self):
        def tree_traverse(current_node, bitstring=''):
            if current_node is None:
                return
            if current_node.is_leaf():
                self.__value_to_bitstring[current_node.value] = bitstring
                return
            tree_traverse(current_node.left_child, bitstring + '0')
            tree_traverse(current_node.right_child, bitstring + '1')

        tree_traverse(self.__root)

    def __calc_freq(self, arr):
        freq_dict = dict()
        for elem in arr:
            freq_dict[elem] = freq_dict.get(elem, 0) + 1
        return freq_dict
# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import numpy as np


def load_quantization_table(component):
    """Return the 8x8 quantization table for 'lum' or 'chrom'.

    Raises:
        ValueError: for any other component name.
    """
    # Quantization Table for: Photoshop - (Save For Web 080)
    # (http://www.impulseadventure.com/photo/jpeg-quantization.html)
    if component == 'lum':
        q = np.array([[2, 2, 2, 2, 3, 4, 5, 6],
                      [2, 2, 2, 2, 3, 4, 5, 6],
                      [2, 2, 2, 2, 4, 5, 7, 9],
                      [2, 2, 2, 4, 5, 7, 9, 12],
                      [3, 3, 4, 5, 8, 10, 12, 12],
                      [4, 4, 5, 7, 10, 12, 12, 12],
                      [5, 5, 7, 9, 12, 12, 12, 12],
                      [6, 6, 9, 12, 12, 12, 12, 12]])
    elif component == 'chrom':
        q = np.array([[3, 3, 5, 9, 13, 15, 15, 15],
                      [3, 4, 6, 11, 14, 12, 12, 12],
                      [5, 6, 9, 14, 12, 12, 12, 12],
                      [9, 11, 14, 12, 12, 12, 12, 12],
                      [13, 14, 12, 12, 12, 12, 12, 12],
                      [15, 12, 12, 12, 12, 12, 12, 12],
                      [15, 12, 12, 12, 12, 12, 12, 12],
                      [15, 12, 12, 12, 12, 12, 12, 12]])
    else:
        raise ValueError((
            "component should be either 'lum' or 'chrom', "
            "but '{comp}' was found").format(comp=component))

    return q


def zigzag_points(rows, cols):
    """Yield every (row, col) cell of a rows x cols block in zigzag order.

    Starts at (0, 0) and sweeps the anti-diagonals alternately up-right and
    down-left.  When a sweep leaves the block it steps right (or down if
    right is out of bounds) after an up-right sweep, and down (or right)
    after a down-left sweep.
    """
    def inbounds(row, col):
        return 0 <= row < rows and 0 <= col < cols

    # start in the top-left cell
    row, col = 0, 0

    # True when moving up-right, False when moving down-left
    move_up = True

    for _ in range(rows * cols):
        yield row, col
        if move_up:
            if inbounds(row - 1, col + 1):
                row, col = row - 1, col + 1
            else:
                move_up = False
                if inbounds(row, col + 1):
                    col += 1
                else:
                    row += 1
        else:
            if inbounds(row + 1, col - 1):
                row, col = row + 1, col - 1
            else:
                move_up = True
                if inbounds(row + 1, col):
                    row += 1
                else:
                    col += 1


def bits_required(n):
    """Number of bits in the binary representation of abs(n) (0 for 0)."""
    n = abs(n)
    result = 0
    while n > 0:
        n >>= 1
        result += 1
    return result


def binstr_flip(binstr):
    """Swap every '0' and '1' in ``binstr``.

    Raises:
        ValueError: if ``binstr`` contains characters other than '0'/'1'.
    """
    # check if binstr is a binary string
    if not set(binstr).issubset('01'):
        raise ValueError("binstr should have only '0's and '1's")
    return binstr.translate(str.maketrans('01', '10'))


def uint_to_binstr(number, size):
    """Encode a non-negative integer as exactly ``size`` '0'/'1' characters.

    Raises:
        ValueError: if ``number`` is negative or needs more than ``size``
            bits.  Silently keeping only the low bits (as this function used
            to) would write a corrupt file that decodes to wrong values.
    """
    if number < 0:
        raise ValueError(
            "number should be non-negative, but {} was found".format(number))
    if number >= 1 << size:
        raise ValueError(
            "{} does not fit in {} bits".format(number, size))
    if size == 0:
        return ''
    return bin(number)[2:].zfill(size)


def int_to_binstr(n):
    """Encode a signed integer with bits_required(n) characters.

    Positive values use plain binary (leading '1'); negative values use the
    binary of abs(n) with every bit flipped (leading '0'); 0 encodes as ''.
    """
    if n == 0:
        return ''

    binstr = bin(abs(n))[2:]

    # change every 0 to 1 and vice versa when n is negative
    return binstr if n > 0 else binstr_flip(binstr)


def flatten(lst):
    """Concatenate an iterable of iterables into one list."""
    return [item for sublist in lst for item in sublist]
# --------------------------------------------------------------------------------