├── README.md ├── LICENSE ├── optical_elements.py ├── errorCorrection.py └── qkd_sim.py /README.md: -------------------------------------------------------------------------------- 1 | # qkd-sim 2 | Simulation of Quantum Key Distribution Protocols 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Madhav Jivrajani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# /optical_elements.py
import numpy as np
import math


class LinearPolarizer:
    """Prepare Jones vectors (2x1 column arrays) for linearly polarized photons."""

    def __init__(self):
        # Horizontal (x) and vertical (y) polarization states as column vectors.
        self.x = np.array([[1], [0]])
        self.y = np.array([[0], [1]])

    def horizontal_vertical(self, bit):
        """Encode ``bit`` in the rectilinear basis.

        Returns the horizontal state for ``bit == 0`` and the vertical state
        for any other value.
        """
        return self.x if bit == 0 else self.y

    def diagonal_polarization(self, bit):
        """Encode ``bit`` in the diagonal basis.

        Applies (1/sqrt(2)) * [[1, 1], [1, -1]] to the horizontal state for
        ``bit == 0`` (giving (1, 1)/sqrt(2)) and to the vertical state for any
        other value (giving (1, -1)/sqrt(2)).
        """
        jones = (1/np.sqrt(2))*np.array([[1, 1], [1, -1]])
        return np.dot(jones, self.x if bit == 0 else self.y)

    def general_polarization(self, angle, basis):
        """Transform the Jones vector ``basis`` by an angle-dependent matrix.

        angle : angle in degrees
        basis : Jones vector (2x1 column array) to transform

        The matrix is [[cos a, sin a], [sin a, -cos a]]; at 45 degrees it is
        the same matrix ``diagonal_polarization`` uses.
        """
        theta = math.radians(angle)
        jones = np.array([[np.cos(theta), np.sin(theta)],
                          [np.sin(theta), -np.cos(theta)]])
        return np.dot(jones, basis)


class PolarizingBeamSplitter:
    """Measure a polarization-encoded photon in the rectilinear or diagonal basis."""

    def __init__(self):
        pass

    def measure(self, vector, basis):
        """Return outcome probabilities for measuring ``vector`` in ``basis``.

        basis  : basis chosen by Bob to measure the polarization-encoded photon
                 0 -> horizontal/vertical
                 1 -> diagonal
        vector : Jones vector (2x1 column array) for the polarized photon

        Returns a dict ``{0: p0, 1: p1}`` with the probabilities of reading the
        encoded bit as 0 or 1, or ``None`` when ``basis`` is neither 0 nor 1.
        """
        if basis == 0:
            # Project onto the horizontal and vertical states.
            horizontal = np.array([[1, 0], [0, 0]])
            vertical = np.array([[0, 0], [0, 1]])
            zero = np.dot(horizontal, vector)[0]
            one = np.dot(vertical, vector)[1]
        elif basis == 1:
            # Map the diagonal basis onto the H/V axes, then read the components.
            plus_minus = (1/np.sqrt(2))*np.array([[1, 1], [1, -1]])
            rotated = np.dot(plus_minus, vector)
            zero = rotated[0]
            one = rotated[1]
        else:
            # Unknown basis: no measurement is defined.
            return None

        # Probability is |amplitude|**2; abs() keeps the result real for complex
        # Jones vectors and leaves real amplitudes unchanged.
        return {0: abs(zero[0])**2, 1: abs(one[0])**2}


# /errorCorrection.py
# NOTE(review): this span is errorCorrection.py followed by most of qkd_sim.py,
# collapsed onto a few lines by the dump. Source text is missing after
# "73 | while(l". The rest of binary_search and of errorCorrection.py, plus the
# start of qkd_sim.py (its imports, the head of the Alice definition, and the
# beginning of generate_and_encode), are not visible. The code below is therefore
# left byte-identical and only annotated.
#
# errorCorrection.py (its sample data runs as a script at import time):
# - calc_error_rate pops the entry at every 4th original index out of all four
#   argument lists (mutating them in place) and divides by `count`. An empty
#   `alice` therefore raises ZeroDivisionError.
# - set_length divides by `error`, so a sample with no mismatches raises
#   ZeroDivisionError.
#
# qkd_sim.py:
# - BB84.__init__: math.ceil(4 + delta)*n rounds before multiplying. For
#   0 < delta < 1 it yields 5*n bits rather than ceil((4 + delta)*n). For
#   delta > 1 it prints and returns without setting any attributes.
# - BB84.eve_interfere: the trailing if/else always executes and overwrites any
#   diagonal state chosen by the two preceding ifs, so Eve always resends an
#   H/V state.
# - BB84.distribute: the bare `except:` converts every exception (including
#   programming errors) into "Protocol aborted" and ([], []). The intended abort
#   path relies on Reconciliation.error_correction returning None from its
#   abort branches, which makes the 3-way unpacking fail.
# - Reconciliation.error_correction: the Hamming syndrome is computed from
#   parity bits Bob has just derived from his own string. As traced,
#   detectError then returns 0, so req_bob[self.n - 1] is flipped whenever
#   0 < error_rate < error_threshold, regardless of where mismatches are.
#   string_alice and the local list `bob` are built but never used.
# - calcRedundantBits(m) returns None for m < 4, because no i in range(m)
#   satisfies 2**i >= m + i + 1. posRedundantBits then fails on the None.
# - Privacy_amplification.privacy_amplification: subset_size = self.n - k - s
#   is 0 when self.n == k + s, which makes range() raise ValueError. Negative
#   values silently produce no subsets.
# - The script at the end formats `a` for both the "Alice:" and "Bob:" keys.
-------------------------------------------------------------------------------- 1 | #Alice and Bob announce some of their measured bits 2 | #Every 4th basis reconciled bit in our case. Remove the reading from the list because Eve knows about the measurements for those readings. 3 | #Error rate = number of errors/total compared cases 4 | #Find optimal length of set where the probability of finding more than 1 error is least. 5 | #Find parity in each of the sets. If equal, most likely the measurements are same for ALice and Bob. If different remove from list. 6 | #Also remove one reading from the set for which parity was announced. To maintain security (so that Eve cannot guess) 7 | #if block size is large enough, and parity is different, we can do bisective searching till the error is found and reject that state for both alice and bob. 8 | 9 | 10 | alice = ['H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H', 'D'] 11 | bob = ['H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H','H', 'D', 'D', 'H', 'H', 'D'] 12 | alice_b = [0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1] 13 | bob_b = [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1] 14 | def calc_error_rate(alice, bob, alice_b, bob_b): 15 | error = 0 16 | count = 0 17 | for j in range(0, len(alice), 4): 18 | i = j-count 19 | #i = alice_b.index(j) 20 | #print i 21 | if(alice_b[i]!=bob_b[i]): 22 | error+=1 23 | count+=1 24 | alice_b.pop(i) 25 | alice.pop(i) 26 | bob_b.pop(i) 27 | bob.pop(i) 28 | error_rate = float(error/count) 29 | return error, count, error_rate, alice, bob, alice_b, bob_b 30 | 31 | def set_length(error, count): 32 | s_len = int(count/error) 33 | if s_len<3: 34 | return 3 35 | return s_len 36 | def find_parity(bits): 37 | count = 0 38 | for i in bits: 39 | count+=i 40 | par = count%2 41 | return par 42 | 43 | def remove_last(bits):
44 | bits.pop() 45 | return bits 46 | 47 | 48 | 49 | error, count, error_rate, alice, bob, alice_b, bob_b = calc_error_rate(alice, bob, alice_b, bob_b) 50 | setLen = set_length(error, count) 51 | j = 0 52 | k = 0 53 | par_alice = [] 54 | par_bob = [] 55 | bits = [] 56 | #m = len(alice) - len(alice)%setLen 57 | 58 | def binary_search(alice_sub, bob_sub): 59 | # if len(bob_sub)>1: 60 | # if find_parity(alice_sub) != find_parity(bob_sub): 61 | # binary_search(alice_sub[:int(len(alice_sub)/2)], bob_sub[:int(len(bob_sub)/2)]) 62 | # #print(alice_sub) 63 | # binary_search(alice_sub[int(len(alice_sub)/2):], bob_sub[int(len(bob_sub)/2):]) 64 | # #print(bob_sub) 65 | # elif len(bob_sub)==1: 66 | # #print(alice_sub, bob_sub) 67 | # #if find_parity(alice_sub)!=find_parity(bob_sub): 68 | # bob_sub[0] = not bob_sub[0] 69 | # return 70 | r = len(alice_sub) 71 | l = 0 72 | 73 | while(l 17 | 0-> horizontal/vertical polarization 18 | 1-> diagonal polarization 19 | 20 | Should generate a dictionary of the form self.alice mentioned above 21 | """ 22 | LP = LinearPolarizer() 23 | encode = [] 24 | count = self.n 25 | 26 | while count!= 0: 27 | self.alice[count] = [ random.randint(0,1), random.randint(0,1)] 28 | if self.alice[count][1] == 0: 29 | encode.append(LP.horizontal_vertical(self.alice[count][0])) 30 | else: 31 | encode.append(LP.diagonal_polarization(self.alice[count][0])) 32 | count-=1 33 | 34 | return encode 35 | 36 | 37 | class Bob: 38 | def __init__(self, n): 39 | self.n = n 40 | self.bob = {} #{no.
: [bit after measurement, basis chosen to measure in]} 41 | #{1:[1,0],2:[0,0],3:[1,0]} Example 42 | 43 | def choose_basis_and_measure(self, received): 44 | """ 45 | received : the data received by bob 46 | Dependency for measurement: 47 | 48 | 0-> horizontal/vertical polarization 49 | 1-> diagonal polarization 50 | 51 | Should generate a dictionary of the form self.bob mentioned above 52 | """ 53 | #self.bob[n][0] is the measured bit 54 | 55 | PBS = PolarizingBeamSplitter() 56 | count = self.n 57 | i = 0 58 | while count!= 0: 59 | self.bob[count] = [0, random.randint(0,1)] 60 | measure = PBS.measure(received[i], self.bob[count][1]) 61 | if measure[0] == measure[1]: 62 | self.bob[count][0] = random.randint(0,1) 63 | elif measure[0] > measure[1]: 64 | self.bob[count][0] = 0 65 | else: 66 | self.bob[count][0] = 1 67 | i += 1 68 | count-=1 69 | 70 | class Privacy_amplification: 71 | def __init__(self, n): 72 | 73 | self.n = n 74 | 75 | def find_parity(self, bits): 76 | count = 0 77 | for i in bits: 78 | count+=i 79 | par = count%2 80 | return par 81 | 82 | 83 | def privacy_amplification(self, error_rate, s, alice_bit, bob_bit): 84 | k = int(error_rate * 2) 85 | subset_size = self.n - k - s 86 | final_alice = [] 87 | final_bob = [] 88 | 89 | alice_b = [] 90 | bob_b = [] 91 | 92 | alice_subsets = [] 93 | bob_subsets = [] 94 | 95 | for i in alice_bit: 96 | alice_b.append(i) 97 | 98 | for i in bob_bit: 99 | bob_b.append(i) 100 | 101 | for i in range(0, self.n, subset_size): 102 | alice_subsets.append(alice_b[i:i+subset_size]) 103 | bob_subsets.append(bob_b[i:i+subset_size]) 104 | 105 | bob_parity = 0 106 | alice_parity = 0 107 | 108 | #calculate parities of sets and compare and eliminate if parities dont match 109 | for i in range(len(alice_subsets)): 110 | 111 | alice = self.find_parity(alice_subsets[i]) 112 | bob = self.find_parity(bob_subsets[i]) 113 | 114 | if alice == bob: 115 | final_alice.append(alice) 116 | final_bob.append(bob) 117 | 118 | return final_alice,
final_bob 119 | 120 | class BB84: 121 | def __init__(self, n, delta, error_threshold): 122 | """ 123 | Alice generates (4+delta)n bits 124 | delta: small fraction less than one 125 | error_threshold: if error while announcing n bits from 2n bits is greater than this 126 | key generation is aborted 127 | """ 128 | if delta > 1: 129 | print("Value for delta should be lesser than 1") 130 | return 131 | 132 | self.n = n 133 | self.total = math.ceil(4 + delta)*n 134 | self.alice = Alice(self.total) 135 | self.bob = Bob(self.total) 136 | 137 | self.error_rate = 0 138 | self.error = error_threshold 139 | 140 | def eve_interfere(self, intercept, intensity): 141 | """ 142 | intercept: the encoeded bits alice sends to bob 143 | intensity: number of bits to interfere with 144 | """ 145 | 146 | PBS = PolarizingBeamSplitter() 147 | lp = LinearPolarizer() 148 | 149 | indices = random.sample(list(range(self.total)), intensity) 150 | 151 | for i in indices: 152 | basis = random.randint(0, 1) 153 | measure = PBS.measure(intercept[i], basis) 154 | 155 | if measure[0] == measure[1]: 156 | intercept[i] = lp.diagonal_polarization(0) 157 | 158 | if measure[0] == -1 * measure[1]: 159 | intercept[i] = lp.diagonal_polarization(1) 160 | 161 | if measure[0] > measure[1]: 162 | intercept[i] = lp.horizontal_vertical(0) 163 | 164 | else: 165 | intercept[i] = lp.horizontal_vertical(1) 166 | 167 | return intercept 168 | 169 | 170 | 171 | def distribute(self, eve, intensity, priv_amp): 172 | """ 173 | eve: if an evesdropper is present or not 174 | """ 175 | encoded = self.alice.generate_and_encode() 176 | 177 | if eve==1: 178 | encoded = self.eve_interfere(encoded, intensity) 179 | 180 | self.bob.choose_basis_and_measure(encoded) 181 | 182 | recon = Reconciliation(self.error, self.alice.alice, self.bob.bob, self.n) 183 | 184 | recon_alice, recon_bob = recon.basis_reconciliation(self.alice.alice, self.bob.bob) 185 | try: 186 | final_alice, final_bob, error_rate = recon.error_correction(recon_alice,
recon_bob) 187 | self.error_rate = error_rate 188 | if priv_amp: 189 | priv = Privacy_amplification(self.n) 190 | 191 | final_priv_alice, final_priv_bob = priv.privacy_amplification(error_rate, 2, final_alice, final_bob) 192 | return final_priv_alice, final_priv_bob 193 | else: 194 | return final_alice, final_bob 195 | 196 | except: 197 | self.abort() 198 | return [], [] 199 | 200 | 201 | def abort(self): 202 | print("Protocol aborted") 203 | return 204 | 205 | def calcRedundantBits(m): 206 | 207 | # Use the formula 2 ^ r >= m + r + 1 208 | # to calculate the no of redundant bits. 209 | # Iterate over 0 .. m and return the value 210 | # that satisfies the equation 211 | 212 | for i in range(m): 213 | if(2**i >= m + i + 1): 214 | return i 215 | 216 | def posRedundantBits(data, r): 217 | 218 | j = 0 219 | k = 1 220 | m = len(data) 221 | res = '' 222 | 223 | for i in range(1, m + r+1): 224 | if(i == 2**j): 225 | res = res + '0' 226 | j += 1 227 | else: 228 | res = res + data[-1 * k] 229 | k += 1 230 | 231 | return res[::-1] 232 | 233 | def calcParityBits(arr, r): 234 | n = len(arr) 235 | 236 | # For finding rth parity bit, iterate over 237 | # 0 to r - 1 238 | for i in range(r): 239 | val = 0 240 | for j in range(1, n + 1): 241 | 242 | # If position has 1 in ith significant 243 | # position then Bitwise OR the array value 244 | # to find parity bit value. 245 | if(j & (2**i) == (2**i)): 246 | val = val ^ int(arr[-1 * j]) 247 | # -1 * j is given since array is reversed 248 | 249 | # String Concatenation 250 | # (0 to n - 2^r) + parity bit + (n - 2^r + 1 to n) 251 | arr = arr[:n-(2**i)] + str(val) + arr[n-(2**i)+1:] 252 | return arr 253 | 254 | def detectError(arr, nr): 255 | n = len(arr) 256 | res = 0 257 | 258 | # Calculate parity bits again 259 | for i in range(nr): 260 | val = 0 261 | for j in range(1, n + 1): 262 | if(j & (2**i) == (2**i)): 263 | val = val ^ int(arr[-1 * j]) 264 | 265 | # Create a binary no by appending 266 | # parity bits together.
267 | 268 | res = res + val*(10**i) 269 | 270 | return int(str(res), 2) 271 | 272 | class Reconciliation: 273 | def __init__(self, error_threshold, alice, bob, n): 274 | 275 | self.alice = alice 276 | self.bob = bob 277 | self.n = n 278 | self.error_threshold = error_threshold 279 | 280 | 281 | def basis_reconciliation(self, alice, bob): 282 | """ 283 | alice: {no. : [bit encoded, basis chosen to encode in ]} 284 | bob : {no. : [bit after measurement, basis chosen to measure in]} 285 | 286 | First check if the length of both lists are the same 287 | -> if yes, keep only those bits for alice and bob for which 288 | the basis encoded in and measured in is the same. 289 | """ 290 | basis_bit_alice = list(alice.values()) 291 | basis_bit_bob = list(bob.values()) 292 | 293 | if len(basis_bit_alice) == len(basis_bit_bob): 294 | raw_key_alice = [] 295 | raw_key_bob = [] 296 | 297 | for i in range(len(basis_bit_alice)): 298 | if basis_bit_alice[i][1] == basis_bit_bob[i][1]: 299 | raw_key_alice.append(basis_bit_alice[i][0]) 300 | raw_key_bob.append(basis_bit_bob[i][0]) 301 | 302 | return raw_key_alice, raw_key_bob 303 | 304 | else: 305 | return None, None 306 | 307 | def abort(self): 308 | print("Protocol aborted here") 309 | return 310 | 311 | def sampling(self, raw_key_alice, raw_key_bob, n): 312 | 313 | sampled_key_alice, sampled_key_bob, sampled_key_index = [], [], [] 314 | sampled_key_index = random.sample(list(enumerate(raw_key_alice)), n) 315 | indices = [] 316 | 317 | for idx, val in sampled_key_index: 318 | sampled_key_alice.append(val) 319 | sampled_key_bob.append(raw_key_bob[idx]) 320 | indices.append(idx) 321 | 322 | return sampled_key_alice, sampled_key_bob, indices 323 | 324 | def error_correction(self, raw_key_alice, raw_key_bob): 325 | 326 | if len(raw_key_alice)<2*self.n: 327 | self.abort() 328 | 329 | else: 330 | 331 | sampled_key_alice, sampled_key_bob, sample_indices = self.sampling(raw_key_alice, raw_key_bob, 2*self.n) 332 | 333 | check_alice, check_bob,
indices = self.sampling(sampled_key_alice, sampled_key_bob, self.n) 334 | 335 | error = 0 336 | 337 | for i in range(len(check_alice)): 338 | if check_alice[i] != check_bob[i]: 339 | error+=1 340 | 341 | error_rate = error/self.n 342 | if error_rate >= self.error_threshold: 343 | self.abort() 344 | 345 | else: 346 | 347 | req_alice = [sampled_key_alice[i] for i in range(len(sampled_key_alice)) if i not in indices] 348 | req_bob = [sampled_key_bob[i] for i in range(len(sampled_key_bob)) if i not in indices] 349 | 350 | if error_rate == 0.0: 351 | return req_alice, req_bob, error_rate 352 | 353 | string_alice = "".join(list(map(str, req_alice))) 354 | string_bob = "".join(list(map(str, req_bob))) 355 | 356 | m = len(string_bob) 357 | r = calcRedundantBits(m) 358 | arr = posRedundantBits(string_bob, r) 359 | arr = calcParityBits(arr, r) 360 | 361 | bob = [] 362 | 363 | k = 0 364 | for i in range(self.n): 365 | if i!=(2**k-1): 366 | bob.append(req_bob[i]) 367 | else: 368 | k += 1 369 | 370 | 371 | correction = self.n - detectError(arr, r) - 1 372 | 373 | req_bob[correction] = int(not req_bob[correction]) 374 | 375 | 376 | 377 | return req_alice, req_bob, error_rate 378 | 379 | bb84 = BB84(10, 0.6, 0.3) 380 | a, b = bb84.distribute(0, 3, 0) 381 | count = 0 382 | 383 | for i in range(len(a)): 384 | if a[i] != b[i]: 385 | count += 1 386 | 387 | print("Weight of transmission: ", count) 388 | print("\nDistributed Keys:\nAlice: %s \nBob: %s\n" % ("".join(list(map(str, a))), "".join(list(map(str, a))))) 389 | print("Error rate: ", bb84.error_rate) 390 | --------------------------------------------------------------------------------