├── README.md └── QuGan-Notebook.ipynb /README.md: -------------------------------------------------------------------------------- 1 | The following code is used to replicate the results of the paper: 2 | 3 | Qugan: A generative adversarial network through quantum states, Samuel A. Stein and Betis Baheri and Daniel Chen and Ying Mao and Qiang Guan and Ang Li and Bo Fang and Shuai Xu, 2021 IEEE International Conference on Quantum Computing and Engineering (QCE). 4 | 5 | Bibtex: 6 | 7 | @inproceedings{stein2021qugan, 8 | title={QuGAN: A Generative Adversarial Network Through Quantum States}, 9 | author={Samuel A. Stein and Betis Baheri and Daniel Chen and Ying Mao and Qiang Guan and Ang Li and Bo Fang and Shuai Xu}, 10 | year={2021}, 11 | booktitle={2021 IEEE International Conference on Quantum Computing and Engineering (QCE)}, 12 | year={2021}, 13 | organization={IEEE} 14 | } 15 | 16 | 17 | The code is run through an interactive Jupyter Notebook with comments. 18 | 19 | 20 | - Prerequisite packages include: 21 | - Tensorflow 22 | - Qiskit 23 | - Matplotlib 24 | - Numpy 25 | 26 | -------------------------------------------------------------------------------- /QuGan-Notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import random\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import seaborn as sns\n", 14 | "from sklearn import datasets\n", 15 | "from qiskit import QuantumRegister, ClassicalRegister\n", 16 | "from qiskit import QuantumRegister\n", 17 | "from qiskit import QuantumCircuit\n", 18 | "from qiskit import Aer, execute\n", 19 | "from qiskit.providers.aer import QasmSimulator\n", 20 | "from math import pi\n", 21 | "from qiskit import * \n", 22 | "import tensorflow as tf\n", 23 | "from qutip import *\n", 24 | "from sklearn.decomposition import PCA\n", 25 | "import time\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import matplotlib.patheffects as PathEffects\n", 30 | "from qiskit import IBMQ\n", 31 | "\n", 32 | "# Set this to the backend you are choosing for qiskit.\n", 33 | "# For real IBMQ Evaluation, use a provider\n", 34 | "backend = Aer.get_backend('qasm_simulator')\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "scrolled": false 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "# --------------------------------------------------\n", 46 | "# The following section we prepare the MNIST dataset\n", 47 | "# and normalize the dataset to be in the bound 0-1\n", 48 | "# Following this, the data is transformed using the \n", 49 | "# PCA algorithm down to k dimensions \n", 50 | "# --------------------------------------------------\n", 51 | "test_images,test_labels = tf.keras.datasets.mnist.load_data()\n", 52 | "train_images = test_images[0].reshape(60000,784)\n", 53 | "train_labels = test_images[1]\n", 54 | "labels = test_images[1]\n", 55 | "train_images = train_images/255\n", 56 | "\n", 57 | "# --------------------------------------------------\n", 58 | "# ---------------- PCA Section ---------------------\n", 59 | "# --------------------------------------------------\n", 60 | "k=2\n", 61 | "pca = PCA(n_components=k)\n", 62 | "pca.fit(train_images)\n", 63 | "pca_data = pca.transform(train_images)[:10000]\n", 64 | "train_labels = train_labels[:10000]\n", 65 | "t_pca_data = pca_data.copy()\n", 66 | "pca_descaler = [[] for _ in range(k)]\n", 67 | "\n", 68 | "for i in range(k):\n", 69 | " if pca_data[:,i].min() < 0:\n", 70 | " pca_descaler[i].append(pca_data[:,i].min())\n", 71 | " pca_data[:,i] += np.abs(pca_data[:,i].min())\n", 72 | " else:\n", 73 | " pca_descaler[i].append(pca_data[:,i].min())\n", 74 | " pca_data[:,i] -= pca_data[:,i].min()\n", 75 | " pca_descaler[i].append(pca_data[:,i].max())\n", 76 | " pca_data[:,i] /= pca_data[:,i].max()\n", 77 | "\n", 78 | "# --------------------------------------------------\n", 79 | "# ----- Transform PCA data to rotations ----------\n", 80 | "# --------------------------------------------------\n", 81 | "pca_data_rot= 2*np.arcsin(np.sqrt(pca_data))\n", 82 | "valid_labels = None\n", 83 | "valid_labels = train_labels==9\n", 84 | "valid_labels = train_labels == 3 \n", 85 | "\n", 86 | "pca_data_rot = pca_data_rot[valid_labels]\n", 87 | "pca_data = pca_data[valid_labels]\n", 88 | "\n", 89 | "print(f\"The Total Explained Variance of {k} Dimensions is {sum(pca.explained_variance_ratio_).round(3)}\")\n", 90 | "\n", 91 | "# --------------------------------------------------\n", 92 | "# Define a function that can take in PCA'ed data and return an image\n", 93 | "# --------------------------------------------------\n", 94 | "def descale_points(d_point,scales=pca_descaler,tfrm=pca):\n", 95 | " for col in range(d_point.shape[1]):\n", 96 | " d_point[:,col] *= scales[col][1]\n", 97 | " d_point[:,col] += scales[col][0]\n", 98 | " reconstruction = tfrm.inverse_transform(d_point)\n", 99 | " return reconstruction" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## These Functions Lead to Qubits Encoding 1 Dimension of Data\n", 107 | "## Next Section Is For Dual Qubit Encoding" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "#All functions needed for the functionality of the circuit simulation\n", 117 | "def generate_and_save_images(model, epoch, test_input):\n", 118 | " # Notice `training` is set to False.\n", 119 | " # This is so all layers run in inference mode (batchnorm).\n", 120 | " predictions = model(test_input, training=False)\n", 121 | "\n", 122 | " fig = plt.figure(figsize=(4,4))\n", 123 | "\n", 124 | " for i in range(predictions.shape[0]):\n", 125 | " plt.subplot(4, 4, i+1)\n", 126 | " dp = np.array((predictions[i] * 127.5) + 127.5).astype('uint8')\n", 127 | " plt.imshow(dp)\n", 128 | " plt.axis('off')\n", 129 | "\n", 130 | " #plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", 131 | " plt.show()\n", 132 | " \n", 133 | "\n", 134 | " \n", 135 | "def ran_ang():\n", 136 | " #return np.pi/2\n", 137 | " return np.random.rand()*np.pi\n", 138 | "\n", 139 | "def single_qubit_unitary(circ_ident,qubit_index,values):\n", 140 | " circ_ident.ry(values[0],qubit_index)\n", 141 | "\n", 142 | "def dual_qubit_unitary(circ_ident,qubit_1,qubit_2,values):\n", 143 | " circ_ident.ryy(values[0],qubit_1,qubit_2)\n", 144 | "\n", 145 | "def controlled_dual_qubit_unitary(circ_ident,control_qubit,act_qubit,values):\n", 146 | " circ_ident.cry(values[0],control_qubit,act_qubit)\n", 147 | " #circ_ident.cry(values[0],act_qubit,control_qubit)\n", 148 | " \n", 149 | "def traditional_learning_layer(circ_ident,num_qubits,values,style=\"Dual\",qubit_start=1,qubit_end=5):\n", 150 | " if style == \"Dual\":\n", 151 | " for qub in np.arange(qubit_start,qubit_end):\n", 152 | " single_qubit_unitary(circ_ident,qub,values[str(qub)])\n", 153 | " for qub in np.arange(qubit_start,qubit_end-1):\n", 154 | " dual_qubit_unitary(circ_ident,qub,qub+1,values[str(qub)+\",\"+str(qub+1)])\n", 155 | " elif style ==\"Single\":\n", 156 | " for qub in np.arange(qubit_start,qubit_end):\n", 157 | " single_qubit_unitary(circ_ident,qub,values[str(qub)])\n", 158 | " elif style==\"Controlled-Dual\":\n", 159 | " for qub in np.arange(qubit_start,qubit_end):\n", 160 | " single_qubit_unitary(circ_ident,qub,values[str(qub)])\n", 161 | " for qub in np.arange(qubit_start,qubit_end-1):\n", 162 | " dual_qubit_unitary(circ_ident,qub,qub+1,values[str(qub)+\",\"+str(qub+1)])\n", 163 | " for qub in np.arange(qubit_start,qubit_end-1):\n", 164 | " controlled_dual_qubit_unitary(circ_ident,qub,qub+1,values[str(qub)+\"--\"+str(qub+1)])\n", 165 | "\n", 166 | "def data_loading_circuit(circ_ident,num_qubits,values,qubit_start=1,qubit_end=5):\n", 167 | " k = 0\n", 168 | " for qub in np.arange(qubit_start,qubit_end):\n", 169 | " circ_ident.ry(values[k],qub)\n", 170 | " k += 1\n", 171 | "\n", 172 | "def swap_test(circ_ident,num_qubits):\n", 173 | " num_swap = num_qubits//2\n", 174 | " for i in range(num_swap):\n", 175 | " circ_ident.cswap(0,i+1,i+num_swap+1)\n", 176 | " circ_ident.h(0)\n", 177 | " circ_ident.measure(0,0)\n", 178 | " \n", 179 | "def init_random_variables(q,style):\n", 180 | " trainable_variables = {}\n", 181 | " if style==\"Single\":\n", 182 | " for i in np.arange(1,q+1):\n", 183 | " trainable_variables[str(i)] = [ran_ang()]\n", 184 | " elif style==\"Dual\":\n", 185 | " for i in np.arange(1,q+1):\n", 186 | " trainable_variables[str(i)] = [ran_ang()]\n", 187 | " if i != q:\n", 188 | " trainable_variables[str(i)+\",\"+str(i+1)] = [ran_ang()]\n", 189 | " elif style==\"Controlled-Dual\":\n", 190 | " for i in np.arange(1,q+1):\n", 191 | " trainable_variables[str(i)] = [ran_ang()]\n", 192 | " if i != q:\n", 193 | " trainable_variables[str(i)+\",\"+str(i+1)] = [ran_ang()]\n", 194 | " trainable_variables[str(i)+\"--\"+str(i+1)] = [ran_ang()]\n", 195 | " return trainable_variables\n", 196 | " \n", 197 | "def get_probabilities(circ,counts=5000):\n", 198 | " job = execute(circ, backend, shots=counts)\n", 199 | " results = job.result().get_counts(circ)\n", 200 | " try:\n", 201 | " prob = results['0']/(results['1']+results['0'])\n", 202 | " prob = (prob-0.5)\n", 203 | " if prob <= 0.005:\n", 204 | " prob = 0.005\n", 205 | " else:\n", 206 | " prob = prob*2\n", 207 | " except:\n", 208 | " prob = 1\n", 209 | " return prob\n", 210 | " \n", 211 | "# Define loss function. SWAP Test returns probability, so minmax probability is logical\n", 212 | "def cost_function(p,yreal,trimming):\n", 213 | " if yreal == 0:\n", 214 | " return -np.log(p)\n", 215 | " #return 1-p\n", 216 | " elif yreal == 1:\n", 217 | " return -np.log(1-p)\n", 218 | " #return p\n", 219 | " \n", 220 | "def generator_cost_function(p):\n", 221 | " return -np.log(p)\n", 222 | "\n", 223 | "def update_weights(init_value,lr,grad):\n", 224 | " while lr*grad > 2*np.pi:\n", 225 | " lr /= 10\n", 226 | " print(\"Warning - Gradient taking steps that are very large. Drop learning rate\")\n", 227 | " weight_update = lr*grad\n", 228 | " new_value = init_value\n", 229 | " print(\"Updating with a new value of \" + str(weight_update))\n", 230 | " if new_value-weight_update > 2*np.pi:\n", 231 | " new_value = (new_value-weight_update) - 2*np.pi\n", 232 | " elif new_value-weight_update < 0:\n", 233 | " new_value = (new_value-weight_update) + 2*np.pi\n", 234 | " else:\n", 235 | " new_value = new_value - weight_update\n", 236 | " return new_value \n", 237 | "\n", 238 | "# Define loss function. SWAP Test returns probability, so minmax probability is logical\n", 239 | "def cost_function(p,yreal,trimming):\n", 240 | " if yreal == 0:\n", 241 | " return -np.log(p)\n", 242 | " #return 1-p\n", 243 | " elif yreal == 1:\n", 244 | " return -np.log(1-p)\n", 245 | " #return p\n", 246 | " \n", 247 | "def generator_cost_function(p):\n", 248 | " return -np.log(p)\n", 249 | "\n", 250 | "def update_weights(init_value,lr,grad):\n", 251 | " while lr*grad > 2*np.pi:\n", 252 | " lr /= 10\n", 253 | " print(\"Warning - Gradient taking steps that are very large. Drop learning rate\")\n", 254 | " weight_update = lr*grad\n", 255 | " new_value = init_value\n", 256 | " print(\"Updating with a new value of \" + str(weight_update))\n", 257 | " if new_value-weight_update > 2*np.pi:\n", 258 | " new_value = (new_value-weight_update) - 2*np.pi\n", 259 | " elif new_value-weight_update < 0:\n", 260 | " new_value = (new_value-weight_update) + 2*np.pi\n", 261 | " else:\n", 262 | " new_value = new_value - weight_update\n", 263 | " return new_value \n", 264 | "\n" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# ------------------------------------------------------------------------------------\n", 274 | "# We treat the first n qubits are the discriminators state. n is always defined as the\n", 275 | "# integer division floor of the qubit count.\n", 276 | "# This is due to the fact that a state will always be k qubits, therefore the \n", 277 | "# number of total qubits must be 2k+1. 2k as we need k for the disc, and k to represent\n", 278 | "# either the other learned quantum state, or k to represent a data point\n", 279 | "# then +1 to perform the SWAP test. Therefore, we know that we will always end up\n", 280 | "# with an odd number of qubits. We take the floor to solve for k. 1st k represents \n", 281 | "# disc, 2nd k represents the \"loaded\" state be it gen or real data\n", 282 | "# ------------------------------------------------------------------------------------\n", 283 | "# Use different function calls to represent training a GENERATOR or training a DISCRIMINATOR\n", 284 | "# ------------------------------------------------------------------------------------\n", 285 | "# THIS SECTION IS FOR THE ONLINE GENERATION OF QUANTUM CIRCUITS\n", 286 | "\n", 287 | "def disc_fake_training_circuit(trainable_variables,key,key_value,diff=False,fwd_diff = False,Sample=False):\n", 288 | " if Sample:\n", 289 | " z = q//2\n", 290 | " circ = QuantumCircuit(q,z)\n", 291 | " else:\n", 292 | " circ = QuantumCircuit(q,c)\n", 293 | " circ.h(0)\n", 294 | " if diff == True and fwd_diff == True:\n", 295 | " trainable_variables[key][key_value] += par_shift\n", 296 | " if diff == True and fwd_diff == False:\n", 297 | " trainable_variables[key][key_value] -= par_shift\n", 298 | " traditional_learning_layer(circ,q,trainable_variables,style=layer_style,qubit_start=1,qubit_end=q//2 +1)\n", 299 | " traditional_learning_layer(circ,q,trainable_variables,style=layer_style,qubit_start=q//2 +1,qubit_end=q)\n", 300 | " if Sample:\n", 301 | " for qub in range(q//2):\n", 302 | " circ.measure(q//2 + 1 + qub,qub)\n", 303 | " else:\n", 304 | " swap_test(circ,q)\n", 305 | " if diff == True and fwd_diff == True:\n", 306 | " trainable_variables[key][key_value] -= par_shift\n", 307 | " if diff == True and fwd_diff == False:\n", 308 | " trainable_variables[key][key_value] += par_shift\n", 309 | " return circ\n", 310 | "\n", 311 | "def disc_real_training_circuit(training_variables,data,key,key_value,diff,fwd_diff):\n", 312 | " circ = QuantumCircuit(q,c)\n", 313 | " circ.h(0)\n", 314 | " if diff == True & fwd_diff == True:\n", 315 | " training_variables[key][key_value] += par_shift\n", 316 | " if diff == True & fwd_diff == False:\n", 317 | " training_variables[key][key_value] -= par_shift\n", 318 | " traditional_learning_layer(circ,q,training_variables,style=layer_style,qubit_start=1,qubit_end=q//2 +1)\n", 319 | " data_loading_circuit(circ,q,data,qubit_start=q//2 +1,qubit_end=q)\n", 320 | " if diff == True & fwd_diff == True:\n", 321 | " training_variables[key][key_value] -= par_shift\n", 322 | " if diff == True & fwd_diff == False:\n", 323 | " training_variables[key][key_value] += par_shift\n", 324 | " swap_test(circ,q)\n", 325 | " return circ\n", 326 | "\n", 327 | "def generate_kl_divergence_hist(actual_data, epoch_results_data):\n", 328 | " plt.clf() # clears current figure\n", 329 | " sns.set()\n", 330 | " kl_div_vec = []\n", 331 | " for kl_dim in range(actual_data.shape[1]):\n", 332 | " kl_div = kl_divergence(actual_data[:,kl_dim],epoch_results_data[:,kl_dim])\n", 333 | " kl_div_vec.append(kl_div)\n", 334 | " return kl_div_vec\n", 335 | "\n", 336 | "def bin_data(dataset):\n", 337 | " bins = np.zeros(10)\n", 338 | " for point in dataset:\n", 339 | " indx = int(str(point).split('.')[-1][0]) # The shittest way imaginable to extract the first val aft decimal\n", 340 | " bins[indx] +=1 \n", 341 | " bins /= sum(bins)\n", 342 | " return bins\n", 343 | "\n", 344 | "def kl_divergence(p_dist, q_dist):\n", 345 | " p = bin_data(p_dist)\n", 346 | " q = bin_data(q_dist)\n", 347 | " kldiv = 0\n", 348 | " for p_point,q_point in zip(p,q):\n", 349 | " kldiv += (np.sqrt(p_point) - np.sqrt(q_point))**2\n", 350 | " kldiv = (1/np.sqrt(2))*kldiv**0.5 \n", 351 | " return kldiv\n", 352 | " #return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p))) # ?... are we confident in this... \n", 353 | "\n", 354 | " \n", 355 | "# Checkpointing code\n", 356 | "def save_variables(var_dict,epoch):\n", 357 | " with open(f\"Epoch-{epoch}-Variables-numbers-9\",'w') as file:\n", 358 | " file.write(str(train_var))" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": { 365 | "scrolled": true 366 | }, 367 | "outputs": [], 368 | "source": [ 369 | "q=5 # Set it to dimensionality of data *n 2 +. 1\n", 370 | "c=1\n", 371 | "tracked_kl_div_1 = []\n", 372 | "tracked_kl_div_2 = []\n", 373 | "# Initialize a quantum circuit with q qubits, and c cbits\n", 374 | "# Initializt Ancilla qubit in equi-superposition\n", 375 | "circ = QuantumCircuit(q,c)\n", 376 | "circ.h(0)\n", 377 | "layer_style = \"Controlled-Dual\"\n", 378 | "train_var = init_random_variables(q-1,layer_style)\n", 379 | "\n", 380 | "# Initial Learning Settings such as alpha etc.\n", 381 | "tracked_d_loss = []\n", 382 | "gradients = []\n", 383 | "learning_rate=0.01\n", 384 | "train_iter = 250\n", 385 | "tracked_g_loss = []\n", 386 | "gradients_g = []\n", 387 | "corr = 0\n", 388 | "wrong= 0 \n", 389 | "loss_d_to_g = 0\n", 390 | "loss_d_to_real = 0\n", 391 | "tracked_loss_d_to_g = []\n", 392 | "tracked_loss_d_to_real = []\n", 393 | "train_on_fake = 5\n", 394 | "df = [0,0]\n", 395 | "print('Starting Training')\n", 396 | "print('-'*20)\n", 397 | "\n", 398 | "for epoch in np.arange(1,100):\n", 399 | " par_shift = 0.5*np.pi*np.sqrt(1/(epoch+1))\n", 400 | " # ------------------------------------------------------------------------------------------\n", 401 | " # This section is the discriminator training section\n", 402 | " # Each data point is tested against a random number, of which it decidesa wheter to \n", 403 | " # Train against discerning between fake or real \n", 404 | " # This causes \"unstable\" loss functions, but not very \"unstable\". Just slightly inconsistent\n", 405 | " # ------------------------------------------------------------------------------------------\n", 406 | " counter = 0\n", 407 | " for _ in range(1):\n", 408 | " for key,value in train_var.items():\n", 409 | " if str(q//2 + 1 ) in key:\n", 410 | " break\n", 411 | " for key_value in range(len(value)):\n", 412 | " forward_diff = cost_function(get_probabilities(disc_fake_training_circuit(train_var,key,key_value,diff=True,fwd_diff=True)),1,None)\n", 413 | " backward_diff = cost_function(get_probabilities(disc_fake_training_circuit(train_var,key,key_value,diff=True,fwd_diff=False)),1,None)\n", 414 | " df = 0.5*(forward_diff-backward_diff)\n", 415 | " if abs(df)>1:\n", 416 | " df = df/abs(df)\n", 417 | " #train_var[key][key_value] -= df*learning_rate/10\n", 418 | " for index,point in enumerate(pca_data_rot):\n", 419 | " df = [0,0]\n", 420 | " gradients = []\n", 421 | " loss= [0,0]\n", 422 | " #Training the Discriminator:\n", 423 | " for key,value in train_var.items():\n", 424 | " if str(q//2 + 1) in key:\n", 425 | " break\n", 426 | " for key_value in range(len(value)):\n", 427 | " #TRAIN ON REAL DATA\n", 428 | " # BETIS HERE\n", 429 | " # _________\n", 430 | " forward_diff = cost_function(get_probabilities(disc_real_training_circuit(train_var,point,key,key_value,diff=True,fwd_diff=True)),0,None)\n", 431 | " backward_diff = cost_function(get_probabilities(disc_real_training_circuit(train_var,point,key,key_value,diff=True,fwd_diff=False)),0,None)\n", 432 | " df = 0.5*(forward_diff-backward_diff)\n", 433 | " train_var[key][key_value] -= learning_rate*df\n", 434 | " loss[0] += cost_function(get_probabilities(disc_real_training_circuit(train_var,point,key,key_value,diff=False,fwd_diff=False)),0,None)\n", 435 | " loss[1] += 1\n", 436 | " loss_g = [0,0]\n", 437 | " # ------------------------------------------------------------------------------------------\n", 438 | " # This section is the generator training section\n", 439 | " # The discriminator just looks to fool the state we learnt above \n", 440 | " # This means that instead of learning 10000 times, we could up the learning rate and just learn a few more times\n", 441 | " # We dont want it to be too large so it spins around the qubits state\n", 442 | " # ------------------------------------------------------------------------------------------\n", 443 | " #Train the generator now as much as we trained the Disc\n", 444 | " for _ in range(len(pca_data_rot)//10):\n", 445 | " gen_params=True\n", 446 | " for key,value in train_var.items():\n", 447 | " if str(q//2 + 1) not in key and gen_params:\n", 448 | " #print(f\"{key} is not a GAN parameter\")\n", 449 | " continue\n", 450 | " else: \n", 451 | " gen_params = False\n", 452 | " for key_value in range(len(value)):\n", 453 | " #TRAIN ON FAKE DATA\n", 454 | " forward_diff = generator_cost_function(get_probabilities(disc_fake_training_circuit(train_var,key,key_value,diff=True,fwd_diff=True)))\n", 455 | " backward_diff = generator_cost_function(get_probabilities(disc_fake_training_circuit(train_var,key,key_value,diff=True,fwd_diff=False)))\n", 456 | " df = 0.5*(forward_diff-backward_diff)\n", 457 | " train_var[key][key_value] -= df*learning_rate*2.5\n", 458 | " loss_g[0] += generator_cost_function(get_probabilities(disc_fake_training_circuit(train_var,key,key_value,diff=False,fwd_diff=False)))\n", 459 | " loss_g[1] +=1\n", 460 | " print(f\"Generator Loss: {loss_g[0]/loss_g[1]}\")\n", 461 | " tracked_g_loss.append(loss_g[0]/loss_g[1])\n", 462 | " loss_qgan = cost_function(get_probabilities(disc_fake_training_circuit(train_var,key,key_value,diff=False,fwd_diff=False)),1,None) \n", 463 | " t_loss = loss_qgan + (loss[0]/loss[1])\n", 464 | " tracked_loss_d_to_real.append(loss[0]/loss[1])\n", 465 | " tracked_loss_d_to_g.append(loss_qgan)\n", 466 | " print(f\"Discriminator Loss: {t_loss}\")\n", 467 | " tracked_d_loss.append(t_loss)\n", 468 | " print(\"-\"*20)\n", 469 | " data = []\n", 470 | " circ = disc_fake_training_circuit(train_var,point,key,key_value,Sample=True)\n", 471 | " n_results = q//2\n", 472 | " for _ in range(500):\n", 473 | " job = execute(circ, backend, shots=20)\n", 474 | " results = job.result().get_counts(circ)\n", 475 | " bins = [[0,0] for _ in range(n_results)]\n", 476 | " for key,value in results.items():\n", 477 | " for i in range(n_results):\n", 478 | " if key[-i-1]== '1':\n", 479 | " bins[i][0] += value\n", 480 | " bins[i][1] += value\n", 481 | " for i,pair in enumerate(bins):\n", 482 | " bins[i]= pair[0]/pair[1]\n", 483 | " data.append(bins)\n", 484 | " data = np.array(data)\n", 485 | " try:\n", 486 | " graph = sns.jointplot(x=data[:,0],y=data[:,1],kind=\"kde\",ylim=(0,1),xlim=(0,1))\n", 487 | " graph.x = pca_data[:,0]\n", 488 | " graph.y = pca_data[:,1]\n", 489 | " graph.plot_joint(plt.scatter, marker='o', c='r', s=5)\n", 490 | " plt.show()\n", 491 | " except:\n", 492 | " pass\n", 493 | " plt.savefig(\"qgan_ICLR_-epoch-mnist-{}-generated-distribution\".format(epoch))\n", 494 | " dim1_kl_div = generate_kl_divergence_hist(pca_data, data)\n", 495 | " print(dim1_kl_div)\n", 496 | " tracked_kl_div_1.append(np.mean(np.array(dim1_kl_div)))\n", 497 | " print(tracked_kl_div_1)\n", 498 | " # For accurate KL Div we need to usue higher shots\n", 499 | " data = []\n", 500 | " for _ in range(16):\n", 501 | " job = execute(circ, backend, shots=20)\n", 502 | " results = job.result().get_counts(circ)\n", 503 | " bins = [[0,0] for _ in range(n_results)]\n", 504 | " for key,value in results.items():\n", 505 | " for i in range(n_results):\n", 506 | " if key[-i-1]== '1':\n", 507 | " bins[i][0] += value\n", 508 | " bins[i][1] += value\n", 509 | " for i,pair in enumerate(bins):\n", 510 | " bins[i]= pair[0]/pair[1]\n", 511 | " data.append(bins)\n", 512 | " data = np.array(data)\n", 513 | " new_info = descale_points(data[:16])\n", 514 | " new_info = new_info.reshape(new_info.shape[0],28,28)\n", 515 | " print(f\"Epoch {epoch} Generated Images\")\n", 516 | " for i in range(new_info.shape[0]):\n", 517 | " plt.subplot(4, 4, i+1)\n", 518 | " plt.imshow(new_info[i, :, :], cmap='gray')\n", 519 | " plt.axis('off')\n", 520 | " plt.savefig(\"qgan_ICLR_-epoch-mnist-{}-generated-images\".format(epoch))\n", 521 | " plt.show()\n", 522 | " with open('new_qgan_results_mnis_epoch_ICLR_{}.txt'.format(epoch), 'w') as file:\n", 523 | " file.write(\"Tracked KL Divergence\\n\")\n", 524 | " file.write(str(tracked_kl_div_1)+\"\\n\")\n", 525 | " file.write(\"Loss Of Generator\\n\")\n", 526 | " file.write(str(tracked_g_loss)+\"\\n\")\n", 527 | " file.write(\"Loss Of Discriminator\\n\")\n", 528 | " file.write(str(tracked_d_loss)+\"\\n\")\n", 529 | " save_variables(train_var,epoch)" 530 | ] 531 | } 532 | ], 533 | "metadata": { 534 | "kernelspec": { 535 | "display_name": "Python 3", 536 | "language": "python", 537 | "name": "python3" 538 | }, 539 | "language_info": { 540 | "codemirror_mode": { 541 | "name": "ipython", 542 | "version": 3 543 | }, 544 | "file_extension": ".py", 545 | "mimetype": "text/x-python", 546 | "name": "python", 547 | "nbconvert_exporter": "python", 548 | "pygments_lexer": "ipython3", 549 | "version": "3.8.3" 550 | } 551 | }, 552 | "nbformat": 4, 553 | "nbformat_minor": 4 554 | } 555 | --------------------------------------------------------------------------------