├── README.md
├── LICENSE
├── .gitignore
└── main.py

/README.md:
--------------------------------------------------------------------------------
# Differentiable SHA256

A fully differentiable implementation of SHA256. Unfortunately, it suffers from vanishing gradients.

Update: an alternative implementation without the vanishing-gradient issue now exists (see the adder branch). The gradients are no longer zero; however, the surface of the output hash is rife with local minima, likely because of the avalanche effect. It appears that SHA256 is secure against gradient-descent pre-image attacks.

# Dependencies

* PyTorch
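# Why the gradients vanish

For intuition, here is a small self-contained sketch (not part of `main.py`; the gate chain is contrived for illustration). The soft gates are built from products of values in (0, 1), so a deep chain of them multiplies many factors smaller than one, and the derivative with respect to an early input decays roughly exponentially with depth:

```python
import torch

# The same soft-gate relaxations used in main.py
def and_bit(a, b):
    return a * b

def xor_bit(a, b):
    return 1.0 - (1.0 - a + a * a * b) * (1.0 - b + a * b * b)

x = torch.tensor(0.5, requires_grad=True)
out = x
for _ in range(64):  # depth comparable to SHA256's 64 rounds
    out = xor_bit(and_bit(out, out), torch.tensor(0.5))
out.backward()
print(x.grad)  # numerically ~0: each gate contributes a factor of roughly 0.24
```

The `main.py` in this snapshot instead performs the 32-bit additions numerically via `add_num32` and converts back to soft bits with `num_to_bits_differentiable2`, which keeps the gradients nonzero but runs into the local-minima problem described above.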
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Caleb Helbling

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import random
import hashlib
import math
import numpy as np
import matplotlib.pyplot as plt

def numerically_stable_sigmoid(x):
    # Evaluate the sigmoid without overflowing exp() for large |x|.
    pos_mask = x >= 0
    neg_mask = ~pos_mask
    z = torch.zeros_like(x)
    z[pos_mask] = torch.exp(-x[pos_mask])
    z[neg_mask] = torch.exp(x[neg_mask])
    top = torch.ones_like(x)
    top[neg_mask] = z[neg_mask]
    return top / (1.0 + z)

# Soft relaxations of the Boolean gates: each agrees with the exact gate at the
# corners {0, 1} and is differentiable in between.
def and_bit(a, b):
    return a * b

def xor_bit(a, b):
    return 1.0 - (1.0 - a + a * a * b) * (1.0 - b + a * b * b)

def or_bit(a, b):
    return 1.0 - (1.0 - a * a) * (1.0 - b * b)

def not_bit(a):
    return 1.0 - a

# Element-wise gates over big-endian lists of soft bits.
def xor(a, b):
    return [xor_bit(a_bit, b_bit) for (a_bit, b_bit) in zip(a, b)]

def and_(a, b):
    return [and_bit(a_bit, b_bit) for (a_bit, b_bit) in zip(a, b)]

def add(a, b):
    # Bit-level ripple-carry adder: a half adder on the LSB, then full adders.
    sum_0 = xor_bit(a[-1], b[-1])
    carry = and_bit(a[-1], b[-1])
    def full_adder(a_bit, b_bit, carry_bit):
        r1 = xor_bit(a_bit, b_bit)
        sum = xor_bit(r1, carry_bit)
        r2 = and_bit(r1, carry_bit)
        r3 = and_bit(a_bit, b_bit)
        carry_bit_out = or_bit(r2, r3)
        return (sum, carry_bit_out)

    ret = [sum_0]
    for (a_bit, b_bit) in list(reversed(list(zip(a, b))))[1:]:
        (sum_i, carry) = full_adder(a_bit, b_bit, carry)
        ret.insert(0, sum_i)

    return ret

def add_num(a, b, num_bits):
    # Addition modulo 2**num_bits on numeric words, with a smooth (sigmoid)
    # overflow test in place of a hard comparison.
    sum = a + b
    cond = numerically_stable_sigmoid(sum - 2 ** num_bits)
    return cond * (sum - 2 ** num_bits) + (1.0 - cond) * sum
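# Illustration (informal): far from the wrap-around point the sigmoid
# saturates and add_num is essentially exact, e.g.
#   add_num(torch.tensor([3.0]), torch.tensor([5.0]), 32)  # ~= 8.0
# Within a few tens of units of 2**32 the hard overflow step is smoothed, so
# the output interpolates between the wrapped and unwrapped sums. Note also
# that float32 (the default dtype here) has a spacing of 512 between adjacent
# representable values near 2**32, which blurs the wrap-around further.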
def add_num32(a, b):
    return add_num(a, b, 32)

def num_to_bits_differentiable(x, num_bits):
    # First attempt at differentiable bit extraction: repeatedly peel off the
    # most significant bit with a sharp sigmoid. Superseded by
    # num_to_bits_differentiable2 below.
    ret = []
    y = x / (2 ** (num_bits - 1))
    for _ in range(num_bits):
        bit = numerically_stable_sigmoid(50.0 * (y - 1.0))
        remainder = y - bit
        y = remainder * 2.0
        ret.append(bit)
    return ret

# Smooth approximations of periodic and step functions; delta controls the
# amount of smoothing.
def triangle_wave(x, delta=0.01):
    return 1.0 - 2.0 * torch.acos((1.0 - delta) * torch.sin(2.0 * math.pi * x)) / math.pi

def square_wave(x, delta=0.01):
    return 2.0 * torch.atan(torch.sin(2.0 * math.pi * x) / delta) / math.pi

def sawtooth_wave(x, delta=0.01):
    return (1 + triangle_wave((2.0 * x - 1.0) / 4.0, delta=delta) * square_wave(x / 2.0, delta=delta)) / 2.0

def floor_(x, delta=0.01):
    # Smooth floor: subtract the smoothed fractional part.
    return x - sawtooth_wave(x, delta=delta)

def mod_(x, n, delta=0.01):
    return n * sawtooth_wave(x / n, delta=delta)

def mod2_alt(x):
    # Smooth x mod 2 for integer-valued x: even integers map to ~0, odd to ~1.
    return (torch.sin((x - 0.5) * math.pi) + 1) / 2.0

def num_to_bits_differentiable2(x, num_bits):
    # Extract bit i as mod2_alt(floor(x / 2**i)) using the smooth floor above.
    # Returns a big-endian list of soft bits.
    ret = []
    for i in reversed(range(num_bits)):
        y = x / (2 ** i)
        right_shifted = torch.max(torch.tensor(0.0), floor_(y))
        bit = mod2_alt(right_shifted)
        ret.append(bit)
    return ret

def not_(a):
    return [not_bit(a_bit) for a_bit in a]

def right_rotate(a, c):
    return a[-c:] + a[0:-c]

def right_shift(a, c):
    return ([0.0] * c) + a[0:-c]

def bits_to_num(a):
    # Collapse a big-endian list of soft bits back into a numeric word.
    total = a[-1]
    multiplier = 2.0
    for bit in list(reversed(a))[1:]:
        total = total + bit * multiplier
        multiplier *= 2.0
    return total
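# Illustration (informal): away from integer boundaries the bit extraction and
# bits_to_num approximately invert each other, e.g.
#   bits = num_to_bits_differentiable2(torch.tensor([6.3]), 3)
#   # [float(b) for b in bits] ~= [1.0, 1.0, 0.0]  (binary 110, i.e. floor(6.3))
#   bits_to_num(bits)  # ~= 6.0
# Exactly at integers the smoothed sawtooth sits at the midpoint of its jump,
# so extracted bits there can land near 0.5 rather than 0 or 1.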
def left_shift(a, c):
    return a + [0.0] * c

def num_to_8_bits(num):
    s = "{0:{fill}8b}".format(num, fill='0')
    return [1 if c == '1' else 0 for c in s]

def num_to_32_bits(num):
    s = "{0:{fill}32b}".format(num, fill='0')
    return [1 if c == '1' else 0 for c in s]

def num_to_64_bits(num):
    s = "{0:{fill}64b}".format(num, fill='0')
    return [1 if c == '1' else 0 for c in s]

def flatten(lst_of_lsts):
    return [x for lst in lst_of_lsts for x in lst]

# Initialize hash values
h0n = 0x6a09e667
h1n = 0xbb67ae85
h2n = 0x3c6ef372
h3n = 0xa54ff53a
h4n = 0x510e527f
h5n = 0x9b05688c
h6n = 0x1f83d9ab
h7n = 0x5be0cd19

# Initialize array of round constants
kn = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2]

parallel_tries = 1

parallel_zeros = torch.tensor([0.0]).expand(parallel_tries)
parallel_ones = torch.tensor([1.0]).expand(parallel_tries)

def expand_constant(n):
    # Each bit gets its own tensor (rather than the shared parallel_zeros/ones)
    # so individual bits can be modified independently later in the script.
    #return [parallel_zeros if bit == 0 else parallel_ones for bit in n]
    return [torch.tensor([0.0]).expand(parallel_tries) if bit == 0 else torch.tensor([1.0]).expand(parallel_tries) for bit in n]

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
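# Implementation note (informal reading of the code below): sha256 carries each
# working word in two forms. Rotations, shifts, and XORs operate on big-endian
# lists of 32 soft bits, while the 32-bit additions are done on a single
# numeric word with add_num32, converting back to bits with
# num_to_bits_differentiable2. The fully bit-level alternative is left in as
# commented-out add(...) calls; per the README, the long multiplicative gate
# chains of the ripple-carry adder are what made the gradients vanish.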
def sha256(message):
    """
    h0 = expand_constant(num_to_32_bits(h0n))
    h1 = expand_constant(num_to_32_bits(h1n))
    h2 = expand_constant(num_to_32_bits(h2n))
    h3 = expand_constant(num_to_32_bits(h3n))
    h4 = expand_constant(num_to_32_bits(h4n))
    h5 = expand_constant(num_to_32_bits(h5n))
    h6 = expand_constant(num_to_32_bits(h6n))
    h7 = expand_constant(num_to_32_bits(h7n))
    """
    h0 = torch.tensor(float(h0n)).expand(parallel_tries)
    h1 = torch.tensor(float(h1n)).expand(parallel_tries)
    h2 = torch.tensor(float(h2n)).expand(parallel_tries)
    h3 = torch.tensor(float(h3n)).expand(parallel_tries)
    h4 = torch.tensor(float(h4n)).expand(parallel_tries)
    h5 = torch.tensor(float(h5n)).expand(parallel_tries)
    h6 = torch.tensor(float(h6n)).expand(parallel_tries)
    h7 = torch.tensor(float(h7n)).expand(parallel_tries)

    k = [expand_constant(num_to_32_bits(n)) for n in kn]

    # Pre-processing (Padding)
    # begin with the original message of length L bits
    L = len(message)

    # append a single '1' bit
    message = message + [parallel_ones]

    # append K '0' bits, where K is the minimum number >= 0 such that
    # L + 1 + K + 64 is a multiple of 512
    K = 0
    while True:
        if (L + 1 + K + 64) % 512 == 0:
            break
        K += 1

    message = message + [parallel_zeros] * K

    # append L as a 64-bit big-endian integer, making the total post-processed
    # length a multiple of 512 bits
    message = message + [torch.tensor(bit, dtype=torch.float) for bit in num_to_64_bits(L)]

    for chunk in chunks(message, 512):
        # Message schedule: track each word both as soft bits (w) and as a
        # numeric value (w_num).
        w = list(chunks(chunk, 32)) + [None] * (64 - 16)
        w_num = [bits_to_num(n) for n in chunks(chunk, 32)] + [None] * (64 - 16)
        for i in range(16, 64):
            # s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3)
            s0 = xor(xor(right_rotate(w[i - 15], 7), right_rotate(w[i - 15], 18)), right_shift(w[i - 15], 3))
            # s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10)
            s1 = xor(xor(right_rotate(w[i - 2], 17), right_rotate(w[i - 2], 19)), right_shift(w[i - 2], 10))
            # w[i] := w[i-16] + s0 + w[i-7] + s1
            w_num[i] = add_num32(add_num32(add_num32(w_num[i - 16], bits_to_num(s0)), w_num[i - 7]), bits_to_num(s1))
            w[i] = num_to_bits_differentiable2(w_num[i], 32)
            #w[i] = add(add(add(w[i - 16], s0), w[i - 7]), s1)

        # Initialize working variables to current hash value:
        a = num_to_bits_differentiable2(h0, 32)
        b = num_to_bits_differentiable2(h1, 32)
        c = num_to_bits_differentiable2(h2, 32)
        d = num_to_bits_differentiable2(h3, 32)
        e = num_to_bits_differentiable2(h4, 32)
        f = num_to_bits_differentiable2(h5, 32)
        g = num_to_bits_differentiable2(h6, 32)
        h = num_to_bits_differentiable2(h7, 32)

        # Compression function main loop:
        for i in range(0, 64):
            # S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)
            S1 = xor(xor(right_rotate(e, 6), right_rotate(e, 11)), right_rotate(e, 25))
            # ch := (e and f) xor ((not e) and g)
            ch = xor(and_(e, f), and_(not_(e), g))
            # temp1 := h + S1 + ch + k[i] + w[i]
            #temp1 = add(add(add(add(h, S1), ch), k[i]), w[i])
            temp1 = add_num32(add_num32(add_num32(add_num32(bits_to_num(h), bits_to_num(S1)), bits_to_num(ch)), bits_to_num(k[i])), w_num[i])
            # S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)
            S0 = xor(xor(right_rotate(a, 2), right_rotate(a, 13)), right_rotate(a, 22))
            # maj := (a and b) xor (a and c) xor (b and c)
            maj = xor(xor(and_(a, b), and_(a, c)), and_(b, c))
            # temp2 := S0 + maj
            #temp2 = add(S0, maj)
            temp2 = add_num32(bits_to_num(S0), bits_to_num(maj))

            h = g
            g = f
            f = e
            #e = add(d, temp1)
            e = num_to_bits_differentiable2(add_num32(bits_to_num(d), temp1), 32)
            d = c
            c = b
            b = a
            #a = add(temp1, temp2)
            a = num_to_bits_differentiable2(add_num32(temp1, temp2), 32)

        # Add the compressed chunk to the current hash value:
        h0 = add_num32(h0, bits_to_num(a))
        h1 = add_num32(h1, bits_to_num(b))
        h2 = add_num32(h2, bits_to_num(c))
        h3 = add_num32(h3, bits_to_num(d))
        h4 = add_num32(h4, bits_to_num(e))
        h5 = add_num32(h5, bits_to_num(f))
        h6 = add_num32(h6, bits_to_num(g))
        h7 = add_num32(h7, bits_to_num(h))

    # Eight numeric 32-bit words, not 256 individual bits.
    return [h0, h1, h2, h3, h4, h5, h6, h7]

hello_world = "Hello world"
hello_world_hash = 0x64ec88ca00b268e5ba1a35678a1b5316d212f4f366b2477232534a8aeca37f3c
hello_world_ascii = [ord(c) for c in hello_world]
hello_world_binary = flatten([num_to_8_bits(c) for c in hello_world_ascii])
assert(len(hello_world_binary) % 8 == 0)
hello_world_torch = expand_constant(hello_world_binary)

samples = 10
x = np.linspace(0, 1, samples)
y = np.linspace(0, 1, samples)

z_output = np.zeros((samples, samples))

# Sweep the first bit of each of the first two message bytes over [0, 1] and
# record the first digest word, to visualize the loss surface.
for j in range(samples):
    for i in range(samples):
        hello_world_torch[0][0] = x[i]
        hello_world_torch[1][0] = y[j]
        hello_world_digest = sha256(hello_world_torch)

        #leading_zeros = 32
        #hw_leading_digest = hello_world_digest[0:leading_zeros]
        #hw_leading_float = float(bits_to_num(hw_leading_digest)[0])
        hw_leading_float = float(hello_world_digest[0])

        z_output[i, j] = hw_leading_float

        print("{" + str(x[i]) + "," + str(y[j]) + "," + str(hw_leading_float) + "},")

print(z_output)

input("Continue?")

x_grid, y_grid = np.meshgrid(x, y)

fig = plt.figure()
ax = plt.axes(projection='3d')
# z_output is indexed [x, y]; transpose it to match meshgrid's 'xy' layout.
ax.contour3D(x_grid, y_grid, z_output.T, 50, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

plt.show()

input("Continue?")

# The sweep above left the two probed bits perturbed; restore the original
# message and recompute the digest before checking it.
hello_world_torch = expand_constant(hello_world_binary)
hello_world_digest = sha256(hello_world_torch)

# sha256 returns eight numeric 32-bit words; reassemble them into one integer.
digest_words = [int(round(float(word[0]))) % (2 ** 32) for word in hello_world_digest]
digest_as_int = 0
for word in digest_words:
    digest_as_int = (digest_as_int << 32) | word
digest_as_hex = hex(digest_as_int)

# Sanity check against the true hash; note that float32 rounding in the
# smoothed operations can make this fail even when the logic is correct.
assert(digest_as_int == hello_world_hash)

leading_zeros = 32

# The digest is eight numeric words rather than 256 bits, so this slice takes
# all eight and bits_to_num forms a weighted sum with the most significant
# word weighted highest; it still serves as a differentiable "leading value"
# loss to be minimized.
hw_leading_digest = hello_world_digest[0:leading_zeros]
hw_leading_float = float(bits_to_num(hw_leading_digest)[0])

#raw_nonce = [torch.randn(parallel_tries, requires_grad=True) for _ in range(32)]
#for n in raw_nonce:
#    n[0] = 0.5

raw_nonce = [torch.zeros(parallel_tries, requires_grad=True) for _ in range(32)]
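# The nonce is parameterized by 32 unconstrained reals; torch.sigmoid below
# squashes each into (0, 1), so gradient descent can move every soft bit
# continuously. Starting from zeros places each soft bit at sigmoid(0) = 0.5,
# the maximally undecided point.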
prev_hash_py = [random.choice([0, 1]) for _ in range(256)]
prev_hash = expand_constant(prev_hash_py)

optimizer = torch.optim.SGD(raw_nonce, lr=0.1, momentum=0.9)

def normalize(input, dim):
    norm = torch.norm(input, dim=dim, keepdim=True)
    norm_expanded = norm.expand_as(input)
    return input / norm_expanded

while True:
    nonce = [torch.sigmoid(bit) for bit in raw_nonce]
    message0 = nonce + prev_hash

    optimizer.zero_grad()
    #digest = sha256(sha256(message0))
    digest = sha256(message0)
    leading_digest = digest[0:leading_zeros]
    leading_float = bits_to_num(leading_digest)
    total = torch.sum(leading_float)
    total.backward()

    print([r.grad for r in raw_nonce])
    print(total)
    print(leading_float)
    print(nonce)

    #optimizer.step()

    # Normalize the gradient so the step size is set by the 0.1 factor alone,
    # then take a manual descent step.
    stacked_grad = torch.stack([r.grad for r in raw_nonce])
    normalized_stacked_grad = normalize(stacked_grad, dim=0)

    for (i, raw_nonce_bit) in enumerate(raw_nonce):
        raw_nonce_bit.data -= normalized_stacked_grad[i] * 0.1

    if input("Stop computation?") == "y":
        break

def bitstring_to_bytes(s):
    return int(s, 2).to_bytes(len(s) // 8, byteorder='big')

nonce = [torch.sigmoid(bit) for bit in raw_nonce]

for i in range(parallel_tries):
    bits_py = []
    for bit in nonce:
        b = round(float(bit[i]))
        bits_py.append(b)
    computed_message = bits_py + prev_hash_py
    computed_message_str = ''.join([str(bit) for bit in computed_message])
    message_bytes = bitstring_to_bytes(computed_message_str)
    print("message_bytes:")
    print(message_bytes)
    print("sha256 digest:")
    print(hashlib.sha256(message_bytes).hexdigest())
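# hashlib.sha256 above hashes the rounded (hard) nonce bits, so it is the
# ground-truth check on what the continuous search actually found. Per the
# README, the local minima of the loss surface mean the leading digest words
# rarely get anywhere near zero.

--------------------------------------------------------------------------------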