├── Image_classification_with_CNN.ipynb ├── Image_classification_with_QNN.ipynb ├── Image_classification_with_QNN_2.ipynb ├── README.md ├── quantum_circuit_examples.ipynb └── quantum_circuit_simulator.py /Image_classification_with_QNN_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "89319537", 6 | "metadata": {}, 7 | "source": [ 8 | "# Image classification \n", 9 | "\n", 10 | "Image classification is a computer vision task that involves categorizing images into predefined classes or categories. The goal is to develop algorithms or models that can accurately identify and assign labels to images based on their visual features and content. This task is commonly used in various applications, such as object recognition, facial recognition, medical imaging, and autonomous driving, to enable machines to understand and interpret visual information in a similar way to humans. The output of an image classification task is a prediction or probability distribution indicating the likelihood of each class for a given image." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "349fb80c", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "\n", 23 | "import time, copy\n", 24 | "\n", 25 | "import torch\n", 26 | "from torch import nn\n", 27 | "import torch.nn.functional as F\n", 28 | "from torch.utils.data import DataLoader\n", 29 | "from torchvision import datasets, transforms\n", 30 | "from torch.utils.data.dataset import random_split\n", 31 | "\n", 32 | "\n", 33 | "from quantum_circuit_simulator import quantum_circuit\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "dcbdabb7", 39 | "metadata": {}, 40 | "source": [ 41 | "## Load dataset and transform\n", 42 | "\n", 43 | "The \"Modified National Institute of Standards and Technology\" (__MNIST__) dataset consists of 60,000 images of handwritten digits for training and an additional 10,000 for testing.\n", 44 | "The dataset has served as a benchmark for evaluating and comparing machine learning algorithms, particularly in image classification. It has played a crucial role in the development of deep learning models, especially convolutional neural networks (CNNs), and is widely used for teaching and as a baseline for assessing new algorithms and techniques.\n", 45 | "\n", 46 | "\n", 47 | "_Image Format_: Each image in the MNIST dataset is a grayscale image with a resolution of 28x28 pixels. This results in a total of 784 pixels per image.\n", 48 | "\n", 49 | "_Digit Classes_: The dataset covers ten classes representing the digits from 0 to 9. 
Each image is labeled with the corresponding digit class, providing the ground truth for training and evaluation.\n", 50 | "\n", 51 | "_Data Distribution_: The dataset is approximately balanced: each digit class has a similar, though not identical, number of samples (the training set, for example, contains 5,923 zeros and 6,742 ones). The model is therefore exposed to a comparable number of examples of each digit during training.\n", 52 | "\n", 53 | "\n", 54 | "-------------------\n", 55 | "\n", 56 | "__FashionMNIST__ is intended to be a more challenging dataset than MNIST, as it requires models to recognize and classify images of various clothing items. FashionMNIST has the same image format and the same training- and test-set sizes as MNIST, and its ten classes are exactly balanced.\n", 57 | "\n", 58 | "_Clothing Categories_: The dataset covers ten clothing categories: T-shirts/tops, trousers, pullovers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots. Each image is labeled with the corresponding clothing category, providing the ground truth for training and evaluation." 
59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "id": "0b35907a", 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "MNIST \n", 72 | "\n", 73 | "number of (train, test) examples = (60000, 10000)\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "# Load the dataset, convert each image to a tensor, and normalize pixel values to the range [-1, 1].\n", 79 | "\n", 80 | "size = 16\n", 81 | "\n", 82 | "transform = transforms.Compose([\n", 83 | "    transforms.Resize((size, size)),\n", 84 | "    transforms.ToTensor(),\n", 85 | "    transforms.Normalize(mean=(0.5,), std=(0.5,)) \n", 86 | "])\n", 87 | "\n", 88 | "\n", 89 | "def load_dataset(name):\n", 90 | "    print(name,'\\n')\n", 91 | "    if name == \"FashionMNIST\":\n", 92 | "        train_dataset = datasets.FashionMNIST(root=\"FashionMNIST\", train=True, download=True, transform=transform)\n", 93 | "        test_dataset = datasets.FashionMNIST(root=\"FashionMNIST\", train=False, download=True, transform=transform)\n", 94 | "    elif name == \"MNIST\":\n", 95 | "        train_dataset = datasets.MNIST(root=\"MNIST\", train=True, download=True, transform=transform)\n", 96 | "        test_dataset = datasets.MNIST(root=\"MNIST\", train=False, download=True, transform=transform)\n", 97 | "    return train_dataset, test_dataset \n", 98 | "\n", 99 | "\n", 100 | "# Choose either \"MNIST\" or \"FashionMNIST\" \n", 101 | "train_dataset, test_dataset = load_dataset(\"MNIST\")\n", 102 | "\n", 103 | "\n", 104 | "\n", 105 | "print(f'number of (train, test) examples = {len(train_dataset), len(test_dataset)}')\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "97f5b368", 111 | "metadata": {}, 112 | "source": [ 113 | "## Let us consider only two classes, 0 and 1, for training and testing" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 3, 119 | "id": "8cc59968", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": 
"stream", 125 | "text": [ 126 | "number of (train, test) examples = (12665, 2115)\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "def generate_subset(dataset):\n", 132 | "    subset = []\n", 133 | "    for i in range(len(dataset)):\n", 134 | "        x, y = dataset[i]\n", 135 | "        if y in [0, 1]:\n", 136 | "            subset.append((x, torch.tensor(y, dtype=torch.float32)))\n", 137 | "    return subset\n", 138 | "\n", 139 | "train_dataset = generate_subset(train_dataset)\n", 140 | "test_dataset = generate_subset(test_dataset)\n", 141 | "\n", 142 | "print(f'number of (train, test) examples = {len(train_dataset), len(test_dataset)}')" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "09367beb", 148 | "metadata": {}, 149 | "source": [ 150 | "### View a training example:" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 4, 156 | "id": "0eaa7b82", 157 | "metadata": { 158 | "scrolled": true 159 | }, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "x of torch.Size([1, 16, 16]) :\n" 166 | ] 167 | }, 168 | { 169 | "data": { 170 | "image/png": "", 171 | "text/plain": [ 172 | "
" 173 | ] 174 | }, 175 | "metadata": { 176 | "needs_background": "light" 177 | }, 178 | "output_type": "display_data" 179 | }, 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "true label = y = 1.0\n", 185 | "\n", 186 | "(x_min, x_max) = (-1.0, 0.961)\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "idx = np.random.choice(len(train_dataset))\n", 192 | "\n", 193 | "x = train_dataset[idx][0]\n", 194 | "print(f'x of {x.shape} :')\n", 195 | "plt.imshow(x[0], cmap='gray')\n", 196 | "plt.show()\n", 197 | "\n", 198 | "print(f'true label = y = {train_dataset[idx][1]}\\n')\n", 199 | "\n", 200 | "print(f'(x_min, x_max) = {x.min().item(), round(x.max().item(),3)}')" 201 | ] 202 | }, 203 | { 204 | "attachments": { 205 | "QNN2.png": { 206 | "image/png": "" 207 | } 208 | }, 209 | "cell_type": "markdown", 210 | "id": "519ac1b9", 211 | "metadata": {}, 212 | "source": [ 213 | "# Define model (QNN), training and test loops\n", 214 | "\n", 215 | "![QNN2.png](attachment:QNN2.png)\n", 216 | "\n", 217 | "### QNN = parameterized quantum circuit (PQC)\n", 218 | "\n", 219 | "\n", 220 | "\n", 221 | "A PQC is made of $L$ quantum layers. \n", 222 | "Here, each quantum layer $l$ has two parts: the rotation $\\text{R}_y(\\theta^{[l]})$ in green and the entangling $\\text{CXs}$ in blue.\n", 223 | "\n", 224 | "-----------\n", 225 | "\n", 226 | "__The rotation part__ is the tensor product \n", 227 | "$\\ \\text{R}_y(\\theta^{[l]}) = \\displaystyle\\bigotimes_{i} \\text{R}_y(\\theta^{[l]}_i)\\ $\n", 228 | "of single-qubit rotations (unitary gates) \n", 229 | "$\\ \\text{R}_y(\\theta) = \\cos(\\theta) I + \\texttt{i}\\sin(\\theta)\\,Y$ around the $y$-axis,\n", 230 | "where $i\\in\\{1,\\cdots,n\\}$ is the qubit index, the rotation angles $\\theta^{[l]}=(\\theta^{[l]}_1,\\cdots,\\theta^{[l]}_n)$ are learnable parameters, and $n$ is the number of qubits.\n", 231 | "$I$ is the identity and $X, Y$ are Pauli operators.\n", 232 | "
\n", 233 | "\n", 234 | "\n", 235 | "__The entangling part__ is the product (sequential composition) \n", 236 | "$\\text{CXs} = \\displaystyle\\prod_{(i,j)} \\text{CX}_{(i,j)}$\n", 237 | "of two-qubit\n", 238 | "$\\text{CX}_{(i,j)} = |0\\rangle_i\\langle 0|\\otimes I_j +|1\\rangle_i\\langle 1|\\otimes X_j$ gates,\n", 239 | "where $(i,j)$ denotes the pair of control qubit $i$ and target qubit $j$.\n", 240 | "This part creates entanglement between qubits and facilitates the transfer of quantum information between them. \n", 241 | "However, it carries no learnable parameters.\n", 242 | "\n", 243 | "-------------\n", 244 | "\n", 245 | "__Input to the PQC__ is an $l_2$-normalized (quantum state) vector $|x\\rangle$ of $2^n$ components. The ket $|x\\rangle$ is obtained by flattening a $\\text{size}\\times\\text{size}$ image from the data into a vector $x$ and normalizing it.\n", 246 | "Since $\\text{size}\\times\\text{size}=:2^n=\\text{dim}$ by definition and $\\text{size}=16$, $n=8$.\n", 247 | "\n", 248 | "\n", 249 | "\n", 250 | "__After the PQC__, we measure the output state $|x,\\Theta\\rangle$ in the $z$-basis (computational basis) and get a probability vector $\\textbf{p}(x,\\Theta)=(p_0,\\cdots,p_{\\text{dim}-1})$. Then we compute the marginal probability\n", 251 | "$\\sum_{i \\text{ is even}} p_i$. This whole procedure is equivalent to performing a $z$-measurement on the last qubit, shown in red, and calculating the probability of the 0 outcome. With the marginal probability, we compute $\\text{loss}(\\Theta|x)$. 
All the learnable parameters are made of $\\Theta=\\{\\theta^{[1]}, \\cdots, \\theta^{[L]}\\}$ associated with the quantum layer.\n", 252 | "\n", 253 | "\n", 254 | "\n", 255 | "For details on the parameterized quantum circuit, see https://iopscience.iop.org/article/10.1088/2058-9565/ab4eb5\n", 256 | " \n" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 5, 262 | "id": "077f1b6b", 263 | "metadata": { 264 | "scrolled": true 265 | }, 266 | "outputs": [ 267 | { 268 | "data": { 269 | "text/plain": [ 270 | "tensor([1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 271 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 272 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 273 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 274 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 275 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 276 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 277 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 278 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 279 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 280 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 281 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 282 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 283 | " 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,\n", 284 | " 1., 0., 1., 0.])" 285 | ] 286 | }, 287 | "execution_count": 5, 288 | "metadata": {}, 289 | "output_type": "execute_result" 290 | } 291 | ], 292 | "source": [ 293 | "\n", 294 | "n = int(2*np.log2(size))\n", 295 | "dim = 2**n # dimension of the n-qubit Hilbert space\n", 296 | "\n", 
297 | "\n", 298 | "#--------------------------------------------------------------------------------------\n", 299 | "\n", 300 | "'''0 outcome projector on the last qubit'''\n", 301 | "\n", 302 | "last_qubit_proj = torch.tensor([(1 + (-1)**i)/2 for i in range(dim)], dtype=torch.float32)\n", 303 | "last_qubit_proj\n" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 6, 309 | "id": "2905df86", 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | "Using cpu device\n", 317 | "\n" 318 | ] 319 | } 320 | ], 321 | "source": [ 322 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # Get gpu or cpu device for training\n", 323 | "print(f\"Using {device} device\\n\")\n", 324 | "\n", 325 | "#=====================================================================================\n", 326 | "\n", 327 | "\n", 328 | "class QNN(torch.nn.Module): # Define model\n", 329 | "    def __init__(self, n, L): # number of qubits = n, number of quantum layers = L\n", 330 | "        super().__init__()\n", 331 | "        self.n, self.L = n, L # keep the sizes on the module so forward() does not rely on globals\n", 332 | "        self.flatten = nn.Flatten()\n", 333 | " \n", 334 | "        angles = torch.empty((L, n), dtype=torch.float64)\n", 335 | "        torch.nn.init.uniform_(angles, -0.01, 0.01)\n", 336 | "        self.angles = torch.nn.Parameter(angles) # registers the angles as learnable parameters\n", 337 | " \n", 338 | "\n", 339 | "    def forward(self, x):\n", 340 | "        x = self.flatten(x)\n", 341 | "        x = x / torch.linalg.norm(x, ord=2, dim=1, keepdim=True) # L2 normalization to change x --> |x⟩ (out of place, so the input batch is not modified)\n", 342 | " \n", 343 | "        '''initializing parameterized quantum circuits (PQC)'''\n", 344 | "        qc = quantum_circuit(num_qubits = self.n, state_vector = x.T) # each column is a feature-vector of an example\n", 345 | "        for l in range(self.L):\n", 346 | "            qc.Ry_layer(self.angles[l].to(torch.cfloat)) # rotation part of lth quantum layer\n", 347 | "            qc.cx_linear_layer() # entangling part of lth quantum layer\n", 348 | "\n", 349 | "        'after passing 
through the PQC, measurement on the output-ket in the computational basis'\n", 350 | "        x = torch.real(qc.probabilities()) # each column is a probabilities-vector for an example \n", 351 | "        # x.shape = (dim, batch size)\n", 352 | " \n", 353 | "        #print(torch.sum(x, dim=0)) # to see whether probabilities add up to 1 or not\n", 354 | "\n", 355 | "        x = torch.matmul(x.T, last_qubit_proj.to(x.device)) # probability of getting 0 outcome on the last qubit\n", 356 | "        return x \n", 357 | "\n" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 7, 363 | "id": "eec9caac", 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "\n", 368 | "def performance_estimate(dataset, model, loss_fn, train_or_test):\n", 369 | "    '''this function computes accuracy and loss of a model on the training or test set'''\n", 370 | "    data_size = len(dataset)\n", 371 | " \n", 372 | "    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True) # uses the global batch_size\n", 373 | "    num_batches = len(dataloader)\n", 374 | " \n", 375 | "    model.eval()\n", 376 | "    loss, accuracy = 0, 0\n", 377 | "    with torch.no_grad():\n", 378 | "        for X, y in dataloader:\n", 379 | "            X, y = X.to(device), y.to(device)\n", 380 | "            pred = model(X)\n", 381 | "            loss += loss_fn(pred, y).item()\n", 382 | " \n", 383 | "            pred = torch.stack([1-pred, pred]).T\n", 384 | "            accuracy += (pred.argmax(1) == y).sum().item() \n", 385 | "    accuracy /= data_size # accuracy lies in the interval [0, 1] \n", 386 | "    loss /= num_batches\n", 387 | "    print(f\"{train_or_test} accuracy: {round(accuracy, 3)}, {train_or_test} loss: {round(loss,3)}\")\n", 388 | "    return accuracy, loss\n", 389 | "\n", 390 | "\n", 391 | "\n", 392 | "\n", 393 | "def one_epoch(model, loss_fn, optimizer, dataset, batch_size):\n", 394 | " \n", 395 | "    A_train, L_train, A_test, L_test = [], [], [], []\n", 396 | "\n", 397 | "    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)\n", 398 | " \n", 399 | "    model.train()\n", 400 | "    for batch, 
(X, y) in enumerate(dataloader):\n", 401 | "        X, y = X.to(device), y.to(device)\n", 402 | "\n", 403 | "        out = model(X) # Perform a single forward pass \n", 404 | "        loss = loss_fn(out, y) \n", 405 | " \n", 406 | "        optimizer.zero_grad() # Clear gradients\n", 407 | "        loss.backward() # Derive gradients, backpropagation\n", 408 | "        optimizer.step() # Update parameters based on gradients\n", 409 | " \n", 410 | " \n", 411 | "        if batch % batch_size == 0: # evaluate once every batch_size batches\n", 412 | "            # As training progresses, compute and record the loss and accuracy of the model on the train and test sets\n", 413 | "            accuracy_train, loss_train = performance_estimate(train_dataset, model, loss_fn, 'train')\n", 414 | "            accuracy_test, loss_test = performance_estimate(test_dataset, model, loss_fn, 'test ')\n", 415 | "            print()\n", 416 | " \n", 417 | "            A_train.append(accuracy_train) \n", 418 | "            L_train.append(loss_train)\n", 419 | "            A_test.append(accuracy_test)\n", 420 | "            L_test.append(loss_test)\n", 421 | " \n", 422 | "        #print(f\"train loss: {round(loss,3)}\")\n", 423 | " \n", 424 | "    return A_train, L_train, A_test, L_test \n", 425 | "\n", 426 | "\n", 427 | "\n", 428 | "\n", 429 | "'''Binary Cross Entropy Loss''' \n", 430 | "\n", 431 | "def training(dataset, batch_size, n, L, lr_, weight_decay_, epochs):\n", 432 | " \n", 433 | "    model = QNN(n=n, L=L).to(device)\n", 434 | "    loss_fn = nn.BCELoss()\n", 435 | "    optimizer = torch.optim.Adam(model.parameters(), lr=lr_, weight_decay=weight_decay_)\n", 436 | " \n", 437 | "    A_Train, L_Train, A_Test, L_Test = [], [], [], []\n", 438 | "    for t in range(epochs): \n", 439 | "        print(f\"Epoch {t+1} ---------------------------------- \\n\")\n", 440 | "        # As training progresses, compute and record the loss and accuracy of the model on the train and test sets\n", 441 | "        A_train, L_train, A_test, L_test = one_epoch(model, loss_fn, optimizer, dataset, batch_size)\n", 442 | "        A_Train += A_train\n", 443 | "        L_Train += L_train \n", 444 | "        A_Test += A_test\n", 445 | "        L_Test += L_test\n", 446 | 
" \n", 447 | "    #accuracy, loss = performance_estimate(test_dataset, model, loss_fn, 'test ')\n", 448 | " \n", 449 | "    model_state_dict = model.state_dict() # for saving or loading the trained model\n", 450 | " \n", 451 | "    return A_Train, L_Train, A_Test, L_Test, model_state_dict\n", 452 | " " 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "id": "834b1b08", 458 | "metadata": {}, 459 | "source": [ 460 | "# training..." 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 8, 466 | "id": "fb62ba99", 467 | "metadata": { 468 | "scrolled": false 469 | }, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "number of qubits = 8\n", 476 | "number of quantum layers = 3\n", 477 | "number of angles (learnable parameters of quantum circuit) = 24\n", 478 | " \n", 479 | "batch_size = 64\n", 480 | "\n", 481 | "Epoch 1 ---------------------------------- \n", 482 | "\n", 483 | "train accuracy: 0.629, train loss: 0.681\n", 484 | "test accuracy: 0.652, test loss: 0.679\n", 485 | "\n", 486 | "train accuracy: 0.826, train loss: 0.652\n", 487 | "test accuracy: 0.841, test loss: 0.649\n", 488 | "\n", 489 | "train accuracy: 0.82, train loss: 0.622\n", 490 | "test accuracy: 0.833, test loss: 0.618\n", 491 | "\n", 492 | "train accuracy: 0.861, train loss: 0.582\n", 493 | "test accuracy: 0.864, test loss: 0.577\n", 494 | "\n", 495 | "Epoch 2 ---------------------------------- \n", 496 | "\n", 497 | "train accuracy: 0.865, train loss: 0.579\n", 498 | "test accuracy: 0.871, test loss: 0.574\n", 499 | "\n", 500 | "train accuracy: 0.944, train loss: 0.543\n", 501 | "test accuracy: 0.948, test loss: 0.538\n", 502 | "\n", 503 | "train accuracy: 0.967, train loss: 0.521\n", 504 | "test accuracy: 0.967, test loss: 0.515\n", 505 | "\n", 506 | "train accuracy: 0.929, train loss: 0.509\n", 507 | "test accuracy: 0.93, test loss: 0.505\n", 508 | "\n", 509 | "Epoch 3 ---------------------------------- \n", 510 | 
"\n", 511 | "train accuracy: 0.928, train loss: 0.508\n", 512 | "test accuracy: 0.929, test loss: 0.497\n", 513 | "\n", 514 | "train accuracy: 0.921, train loss: 0.501\n", 515 | "test accuracy: 0.926, test loss: 0.496\n", 516 | "\n", 517 | "train accuracy: 0.913, train loss: 0.497\n", 518 | "test accuracy: 0.916, test loss: 0.495\n", 519 | "\n", 520 | "train accuracy: 0.914, train loss: 0.494\n", 521 | "test accuracy: 0.915, test loss: 0.491\n", 522 | "\n", 523 | " ~~~~~ training is done ~~~~~\n", 524 | "\n", 525 | "CPU times: user 7min 14s, sys: 183 ms, total: 7min 14s\n", 526 | "Wall time: 1min 48s\n" 527 | ] 528 | } 529 | ], 530 | "source": [ 531 | "%%time\n", 532 | "\n", 533 | "\n", 534 | "L = 3\n", 535 | "\n", 536 | "n_angs = n*L\n", 537 | "\n", 538 | "print(\"number of qubits = \", n)\n", 539 | "print(\"number of quantum layers = \", L)\n", 540 | "print(f\"number of angles (learnable parameters of quantum circuit) = {n_angs}\\n \")\n", 541 | "\n", 542 | "#--------------------------------------------------------------------------------------\n", 543 | "\n", 544 | "\n", 545 | "batch_size = 64\n", 546 | "print(f'batch_size = {batch_size}\\n')\n", 547 | "\n", 548 | "\n", 549 | "#----------------------------------------------------------------------------------\n", 550 | "\n", 551 | "\n", 552 | "A_Train, L_Train, A_Test, L_Test, model_state_dict = training(train_dataset, batch_size=batch_size, n=n, L=L,\n", 553 | " lr_=1e-3, weight_decay_=1e-8, epochs=3)\n", 554 | "\n", 555 | "\n", 556 | "print(f' ~~~~~ training is done ~~~~~\\n')" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 9, 562 | "id": "9f9a6522", 563 | "metadata": {}, 564 | "outputs": [ 565 | { 566 | "data": { 567 | "image/png": "", 568 | "text/plain": [ 569 | "
" 570 | ] 571 | }, 572 | "metadata": { 573 | "needs_background": "light" 574 | }, 575 | "output_type": "display_data" 576 | }, 577 | { 578 | "data": { 579 | "image/png": "", 580 | "text/plain": [ 581 | "
" 582 | ] 583 | }, 584 | "metadata": { 585 | "needs_background": "light" 586 | }, 587 | "output_type": "display_data" 588 | } 589 | ], 590 | "source": [ 591 | "plt.plot(A_Train, label='train set')\n", 592 | "plt.plot(A_Test, label='test set')\n", 593 | "plt.ylabel('accuracy')\n", 594 | "plt.legend()\n", 595 | "plt.show()\n", 596 | "\n", 597 | "plt.plot(L_Train, label='train set')\n", 598 | "plt.plot(L_Test, label='test set')\n", 599 | "plt.ylabel('loss')\n", 600 | "plt.legend()\n", 601 | "plt.show()" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "id": "416b30e5", 607 | "metadata": {}, 608 | "source": [ 609 | "## save model" 610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "execution_count": 10, 615 | "id": "988bcfc3", 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [ 619 | "torch.save(model_state_dict, \"model_MNIST_QNN2.pth\")" 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "id": "bdf78969", 625 | "metadata": {}, 626 | "source": [ 627 | "## load model" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 11, 633 | "id": "3eb4ede1", 634 | "metadata": {}, 635 | "outputs": [ 636 | { 637 | "data": { 638 | "text/plain": [ 639 | "" 640 | ] 641 | }, 642 | "execution_count": 11, 643 | "metadata": {}, 644 | "output_type": "execute_result" 645 | } 646 | ], 647 | "source": [ 648 | "model = QNN(n=n, L=L).to(device)\n", 649 | "model.load_state_dict(model_state_dict)\n", 650 | "\n", 651 | "#model.load_state_dict(torch.load(\"model_MNIST_QNN.pth\"))" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "id": "52a30c42", 657 | "metadata": {}, 658 | "source": [ 659 | "## predict test examples" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 12, 665 | "id": "b5be82d3", 666 | "metadata": { 667 | "scrolled": false 668 | }, 669 | "outputs": [ 670 | { 671 | "name": "stdout", 672 | "output_type": "stream", 673 | "text": [ 674 | "x of torch.Size([1, 16, 16]) :\n" 675 | ] 676 | 
}, 677 | { 678 | "data": { 679 | "image/png": "", 680 | "text/plain": [ 681 | "
" 682 | ] 683 | }, 684 | "metadata": { 685 | "needs_background": "light" 686 | }, 687 | "output_type": "display_data" 688 | }, 689 | { 690 | "name": "stdout", 691 | "output_type": "stream", 692 | "text": [ 693 | "true label = y = 1.0\n", 694 | "\n", 695 | "predicted label = 1\n", 696 | "\n" 697 | ] 698 | }, 699 | { 700 | "data": { 701 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAV8klEQVR4nO3df7RdZX3n8feHQGr8iUIcMAkl2ojFAoJXhFlW0amLAFODP4vasvw1mOkgdnVJQVvtdHRGXUxbS0UzWSyWVkejo2madqJZjqNgi2gSA8HAik1jgUt0uFgRf2QVAt/54+zo8XKSe5Lcfa737vdrrbvu3s9+zj7fhxvO5+x99nl2qgpJUncdMdMFSJJmlkEgSR1nEEhSxxkEktRxBoEkddyRM13AwTr22GPrxBNPnOkyJGlW2bJly71VtXDQtlkXBCeeeCKbN2+e6TIkaVZJcsf+tnlqSJI6ziCQpI4zCCSp41oNgiTLk+xIsjPJlfvpc06Sm5NsT3J9m/VIkh6ptQ+Lk8wDrgFeDIwDm5Ksr6rb+vocDXwIWF5VdyZ5clv1SJIGa/OqoTOBnVW1CyDJGmAFcFtfn9cAa6vqToCquqfFeiRpVlq39W6u2riD3fft4SlHL+Dyc0/iwtMXTdv+2zw1tAi4q299vGnr93TgiUm+nGRLkosH7SjJJUk2J9k8MTHRUrmS9Itn3da7efvaW7n7vj0UcPd9e3j72ltZt/XuaXuONoMgA9omz3l9JPBs4ALgXOCdSZ7+iAdVra6qsaoaW7hw4PchJGlOumrjDvY8+NDPte158CGu2rhj2p6jzVND48CSvvXFwO4Bfe6tqh8DP05yA3Aa8K0W65KkWWP3fXsOqv1QtHlEsAlYlmRpkvnARcD6SX3+Bvj1JEcmeTTwXOD2FmuSpFnlKUcvOKj2Q9FaEFTVXuBSYCO9F/dPV9X2JCuTrGz63A58HtgGfB24tqq+2VZNkjTbXH7uSSw4at7PtS04ah6Xn3vStD1HZtutKsfGxsq5hiR1ybqtd/MHn9nGAw89zKJDvGooyZaqGhu0bdZNOidJXXPh6Yv45NfvBOBTbz572vfvFBOS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHddqECRZnmRHkp1Jrhyw/ZwkP0hyc/PzrjbrkSQ90pFt7TjJPOAa4MXAOLApyfqqum1S169U1b9vqw5J0oG1eURwJrCzqnZV1QPAGmBFi88nSToEbQbBIuCuvvXxpm2ys5PckuRzSZ45aEdJLkmyOcnmiYmJNmqVpM5qMwgyoK0mrX8D+OWqOg34S2DdoB1V1eqqGquqsYULF05vlZLUcW0GwTiwpG99MbC7v0NV3V9VP2qWNwBHJTm2xZokSZO0GQSbgGVJliaZD1wErO/vkOS4JGmWz2zq+V6LNUmSJmntqqGq2pvkUmAjMA+4rqq2J1nZbF8FvAL4j0n2AnuAi6pq8ukjSVKLWgsC+Onpng2T2lb1LX8Q+GCbNUiSDsxvFktSxxkEktRxBoEkdZxBIEkdZxBIUsc
ZBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HFHznQBo7Bu691ctXEHu+/bw1OOXsDl557EhacvmumyJOkXwpwPgnVb7+bta29lz4MPAXD3fXt4+9pbAQwDSaLlU0NJlifZkWRnkisP0O85SR5K8orpruGqjTt+GgL77HnwIa7auGO6n0qSZqXWgiDJPOAa4DzgZODVSU7eT7/3AxvbqGP3fXsOql2SuqbNI4IzgZ1VtauqHgDWACsG9HsL8FngnjaKeMrRCw6qXZK6ps0gWATc1bc+3rT9VJJFwEuBVQfaUZJLkmxOsnliYuKgirj83JNYcNS8n2tbcNQ8Lj/3pIPajyTNVW0GQQa01aT1DwBXVNVDA/r+7EFVq6tqrKrGFi5ceFBFXHj6It77slOYP6831EVHL+C9LzvFD4olqdHmVUPjwJK+9cXA7kl9xoA1SQCOBc5Psreq1k1nIReevohPfv1OAD715rOnc9eSNOu1GQSbgGVJlgJ3AxcBr+nvUFVL9y0n+Qjwd9MdApKkA2stCKpqb5JL6V0NNA+4rqq2J1nZbD/g5wKSpNFo9QtlVbUB2DCpbWAAVNXr2qxFkjSYcw1JUscZBJLUcQaBJHWcQSBJHWcQSFLHDRUEST6b5IIkBockzTHDvrB/mN6Xwf4xyfuSPKPFmiRJIzRUEFTV/6mq1wJnAP8MfCHJjUlen+SoNguUJLVr6FM9SY4BXge8CdgK/AW9YPhCK5VJkkZiqG8WJ1kLPAP4GPCbVfWdZtOnkmxuqzhJUvuGnWLi2ma6iJ9K8ktV9a9VNdZCXZKkERn21NB7BrR9dToLkSTNjAMeESQ5jt5dxRYkOZ2f3Wzm8cCjW65NkjQCU50aOpfeB8SLgT/ra/8h8I6WapIkjdABg6CqPgp8NMnLq+qzI6pJkjRCU50a+u2q+jhwYpLfn7y9qv5swMMkSbPIVKeGHtP8fmzbhUiSZsZUp4b+R/P7T0ZTjiRp1KY6NXT1gbZX1WXTW44kadSmOjW0ZSRVSJJmzDBXDUmS5rCpTg19oKp+L8nfAjV5e1W9pLXKJEkjMdWpoY81v/9724VIkmbGVKeGtjS/r08yn94MpAXsqKoHRlCfJKllw05DfQGwCvgnevMNLU3y5qr6XJvFSZLaN+zso38KvLCqzqmqFwAvBP58qgclWZ5kR5KdSa4csH1Fkm1Jbk6yOcnzDq58SdLhGvZ+BPdU1c6+9V3APQd6QJJ5wDXAi4FxYFOS9VV1W1+3LwLrq6qSnAp8mt7pJ0nSiEx11dDLmsXtSTbQe6Eu4JXApin2fSaws6p2NftaA6wAfhoEVfWjvv6PYcCVSZKkdk11RPCbfcv/D3hBszwBPHGKxy4C7upbHweeO7lTkpcC7wWeDFwwaEdJLgEuATjhhBOmeFpJ0sGY6qqh1x/GvjOgbdB3Ef4a+OskzwfeDfzGgD6rgdUAY2NjHjVI0jQa9qqhRwFvBJ4JPGpfe1W94QAPGweW9K0vBnbvr3NV3ZDkaUmOrap7h6lLknT4hr1q6GPAcfTuWHY9vRf1H07xmE3AsiRLm+8gXASs7++Q5FeSpFk+A5gPfG/48iVJh2vYq4Z+papemWRFVX00ySeAjQd6QFXtTXJp028ecF1VbU+ystm+Cng5cHGSB4E9wG9Vlad+JGmEhg2CB5vf9yX5NeC7wIlTPaiqNgAbJrWt6lt+P/D+IWuQJLVg2CBYneSJwDvpnd55bLMsSZrlhgqCqrq2WbweeGp75UiSRm2oD4uTHJPkL5N8I8mWJB9IckzbxUmS2jfsVUNr6E0p8XLgFcC9wKfaKkqSNDrDfkbwpKp6d9/6e5Jc2EI9kqQRG/aI4EtJLkpyRPPzKuB/t1mYJGk0ppp07of0poUI8PvAx5tNRwA/Av641eokSa2baq6hx42qEEnSzBj
2MwKSvAR4frP65ar6u3ZKkiSN0rCXj74PeCu9ewncBry1aZMkzXLDHhGcDzyrqh4GSPJRYCvwiNtPSpJml2GvGgI4um/5CdNchyRphgx7RPDfgK1JvkTvCqLnA29vrSpJ0shMGQRJjgAeBs4CnkMvCK6oqu+2XJskaQSmDIKqejjJpVX1aSbdWEaSNPsN+xnBF5K8LcmSJE/a99NqZZKkkRj2M4I30PuG8e9OandKakma5YYNgpPphcDz6AXCV4BVB3yEJGlWGDYIPgrcD1zdrL+6aXtVG0VJkkZn2CA4qapO61v/UpJb2ihIkjRaw35YvDXJWftWkjwX+Id2SpIkjdKwRwTPBS5OcmezfgJwe5JbgaqqU1upTpLUumGDYHmrVUiSZsxQQVBVd7RdiCRpZhzMpHMHLcnyJDuS7EzyiJlKk7w2ybbm58Ykpw3ajySpPa0FQZJ5wDXAefS+h/DqJCdP6vZt4AXNZwzvBla3VY8kabA2jwjOBHZW1a6qegBYA6zo71BVN1bV95vVm4DFLdYjSRqgzSBYBNzVtz7etO3PG4HPDdqQ5JIkm5NsnpiYmMYSJUltBkEGtNXAjskL6QXBFYO2V9XqqhqrqrGFCxdOY4mSpKFvXn8IxoElfeuLgd2TOyU5FbgWOK+qvtdiPZKkAdo8ItgELEuyNMl84CIm3c8gyQnAWuB3qupbLdYiSdqP1o4IqmpvkkuBjcA84Lqq2p5kZbN9FfAu4BjgQ0kA9lbVWFs1SZIeqc1TQ1TVBmDDpLZVfctvAt7UZg2SpANr9QtlkqRffAaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUscZBJLUcQaBJHWcQSBJHWcQSFLHGQSS1HEGgSR1nEEgSR1nEEhSxxkEktRxBoEkdZxBIEkdZxBIUse1GgRJlifZkWRnkisHbH9Gkq8m+dckb2uzFknSYEe2teMk84BrgBcD48CmJOur6ra+bv8CXAZc2FYdkqQDa/OI4ExgZ1XtqqoHgDXAiv4OVXVPVW0CHmyxDknSAbQZBIuAu/rWx5s2SdIvkDaDIAPa6pB2lFySZHOSzRMTE4dZliSpX5tBMA4s6VtfDOw+lB1V1eqqGquqsYULF05LcZKknjaDYBOwLMnSJPOBi4D1LT6fJOkQtHbVUFXtTXIpsBGYB1xXVduTrGy2r0pyHLAZeDzwcJLfA06uqvvbqkuS9PNaCwKAqtoAbJjUtqpv+bv0ThlJkmaI3yyWpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeMMAknqOINAkjrOIJCkjjMIJKnjDAJJ6jiDQJI6ziCQpI4zCCSp4wwCSeo4g0CSOs4gkKSOMwgkqeMMAknqOINAkjrOIJCkjjMIJKnjDAJJ6jiDQJI6ziCQpI5rNQiSLE+yI8nOJFcO2J4kVzfbtyU5o816JEmP1FoQJJkHXAOcB5wMvDrJyZO6nQcsa34uAT7cVj2SpMGObHHfZwI7q2oXQJI1wArgtr4+K4C/qqoCbkpydJLjq+o7013M8i9/guMm7uKOv3/8dO9aklr3uu/cz3cXLoE3nz3t+27z1NAi4K6+9fGm7WD7kOSSJJuTbJ6YmDikYs475XhOPt4QkDQ7nXz84znvlONb2XebRwQZ0FaH0IeqWg2sBhgbG3vE9mEc9453HMrDJGnOa/OIYBxY0re+GNh9CH0kSS1qMwg2AcuSLE0yH7gIWD+pz3rg4ubqobOAH7Tx+YAkaf9aOzVUVXuTXApsBOYB11XV9iQrm+2rgA3A+cBO4CfA69uqR5I0WJufEVBVG+i92Pe3repbLuA/tVmDJOnA/GaxJHWcQSBJHWcQSFLHGQSS1HHpfV47eySZAO44xIcfC9w7jeXMBo65GxxzNxzOmH+5qhYO2jDrguBwJNlcVWMzXccoOeZucMzd0NaYPTUkSR1nEEhSx3UtCFbPdAEzwDF3g2Puhlb
G3KnPCCRJj9S1IwJJ0iQGgSR13JwMgiTLk+xIsjPJlQO2J8nVzfZtSc6YiTqn0xBjfm0z1m1Jbkxy2kzUOZ2mGnNfv+ckeSjJK0ZZXxuGGXOSc5LcnGR7kutHXeN0G+Lf9hOS/G2SW5oxz+pZjJNcl+SeJN/cz/bpf/2qqjn1Q2/K638CngrMB24BTp7U53zgc/TukHYW8LWZrnsEY/63wBOb5fO6MOa+fv+X3iy4r5jpukfwdz6a3n3BT2jWnzzTdY9gzO8A3t8sLwT+BZg/07UfxpifD5wBfHM/26f99WsuHhGcCeysql1V9QCwBlgxqc8K4K+q5ybg6CTt3Ax0NKYcc1XdWFXfb1Zvonc3uNlsmL8zwFuAzwL3jLK4lgwz5tcAa6vqToCqmu3jHmbMBTwuSYDH0guCvaMtc/pU1Q30xrA/0/76NReDYBFwV9/6eNN2sH1mk4MdzxvpvaOYzaYcc5JFwEuBVcwNw/ydnw48McmXk2xJcvHIqmvHMGP+IPCr9G5zeyvw1qp6eDTlzYhpf/1q9cY0MyQD2iZfIztMn9lk6PEkeSG9IHheqxW1b5gxfwC4oqoe6r1ZnPWGGfORwLOBfwcsAL6a5Kaq+lbbxbVkmDGfC9wMvAh4GvCFJF+pqvtbrm2mTPvr11wMgnFgSd/6YnrvFA62z2wy1HiSnApcC5xXVd8bUW1tGWbMY8CaJgSOBc5Psreq1o2kwuk37L/te6vqx8CPk9wAnAbM1iAYZsyvB95XvRPoO5N8G3gG8PXRlDhy0/76NRdPDW0CliVZmmQ+cBGwflKf9cDFzafvZwE/qKrvjLrQaTTlmJOcAKwFfmcWvzvsN+WYq2ppVZ1YVScCnwF+dxaHAAz3b/tvgF9PcmSSRwPPBW4fcZ3TaZgx30nvCIgk/wY4Cdg10ipHa9pfv+bcEUFV7U1yKbCR3hUH11XV9iQrm+2r6F1Bcj6wE/gJvXcUs9aQY34XcAzwoeYd8t6axTM3DjnmOWWYMVfV7Uk+D2wDHgauraqBlyHOBkP+nd8NfCTJrfROm1xRVbN2euoknwTOAY5NMg78MXAUtPf65RQTktRxc/HUkCTpIBgEktRxBoEkdZxBIEkdZxBIUscZBOqcJP85ydta2vc/Jzl2ij4/Osh9tlavBAaBJHWeQaA5LcnFzZzttyT52IDt/yHJpmb7Z5tv45LklUm+2bTf0LQ9M8nXm7n+tyVZNsVzr2smftue5JJJ2/40yTeSfDHJwqbtaUk+3zzmK0meMWCflyW5rXn+NYfz30baxyDQnJXkmcAfAi+qqtOAtw7otraqntNsv53ehHzQ+yb2uU37S5q2lcBfVNWz6M1jND5FCW+oqmc3fS9LckzT/hjgG1V1BnA9vW+OQu/G5G9pHvM24EMD9nklcHpVndrUIx22OTfFhNTnRcBn9k03UFWD5nj/tSTvoXdDl8fSm8oA4B/oTVvwaXpzNAF8FfjDJIvpBcg/TvH8lyV5abO8BFgGfI/e1A+fato/DqxN8lh6Nw/6X30zpf7SgH1uA/5nknXAuimeXxqKRwSay8LU0/N+BLi0qk4B/gR4FEBVrQT+iN4L+M1JjqmqT9A7OtgDbEzyov0+cXIO8BvA2c1RxdZ9+x6g6P2/eF9VPavv51cH9L0AuIbeVNNbkvhmTofNINBc9kXgVftOySR50oA+jwO+k+Qo4LX7GpM8raq+VlXvAu4FliR5KrCrqq6mNwPkqQd47icA36+qnzTn+s/q23YEsO/+ya8B/r6ZO//bSV7ZPH8y6b7SSY4AllTVl4A/4GdHMdJh8d2E5qxmlsr/Clyf5CF678pfN6nbO4GvAXfQu7vV45r2q5oPg0MvUG6hd37+t5M8CHwX+C8HePrPAyuTbAN20Ls96D4/Bp6ZZAvwA+C3mvbXAh9O8kf0Zptc0zzvPvOAjyd5QlPXn1fVfUP8p5AOyNlHJanjPDUkSR1nEEhSxxkEktRxBoEkdZxBIEkdZxB
IUscZBJLUcf8fXbJC4txZ1hUAAAAASUVORK5CYII=", 702 | "text/plain": [ 703 | "
" 704 | ] 705 | }, 706 | "metadata": { 707 | "needs_background": "light" 708 | }, 709 | "output_type": "display_data" 710 | }, 711 | { 712 | "name": "stdout", 713 | "output_type": "stream", 714 | "text": [ 715 | "The sum of probability = 1.0\n" 716 | ] 717 | } 718 | ], 719 | "source": [ 720 | "idx = np.random.choice(len(test_dataset))\n", 721 | "\n", 722 | "x = test_dataset[idx][0]\n", 723 | "print(f'x of {x.shape} :')\n", 724 | "plt.imshow(x[0], cmap='gray')\n", 725 | "plt.show()\n", 726 | "\n", 727 | "print(f'true label = y = {test_dataset[idx][1]}\\n')\n", 728 | "\n", 729 | "\n", 730 | "out = model(x.view(1, 1, size, size)).detach().flatten()\n", 731 | "prob = torch.stack([1 - out, out])\n", 732 | "pred = prob.argmax().item()\n", 733 | "print(f'predicted label = {pred}\\n')\n", 734 | "\n", 735 | "plt.stem(np.arange(2), prob)\n", 736 | "plt.ylabel('probability')\n", 737 | "plt.xlabel('class labels')\n", 738 | "plt.show()\n", 739 | "\n", 740 | "print(f'The sum of probability = {torch.sum(prob).item()}')" 741 | ] 742 | }, 743 | { 744 | "cell_type": "markdown", 745 | "id": "8b298922", 746 | "metadata": {}, 747 | "source": [ 748 | "$ $\n", 749 | "\n", 750 | "$ $\n", 751 | "\n", 752 | "## simplified training loop" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": 13, 758 | "id": "b3dcb3aa", 759 | "metadata": {}, 760 | "outputs": [], 761 | "source": [ 762 | "def one_epoch_(model, loss_fn, optimizer, dataset, batch_size):\n", 763 | " \n", 764 | " dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)\n", 765 | " \n", 766 | " model.train()\n", 767 | " for batch, (X, y) in enumerate(dataloader):\n", 768 | " X, y = X.to(device), y.to(device)\n", 769 | "\n", 770 | " out = model(X) # Perform a single forward pass\n", 771 | " loss = loss_fn(out, y) \n", 772 | " \n", 773 | " optimizer.zero_grad() # Clear gradients\n", 774 | " loss.backward() # Derive gradients, backpropagation\n", 775 | " optimizer.step() # Update parameters 
based on gradients\n", 776 | " \n", 777 | " if batch % batch_size == 0: \n", 778 | " print(f\"train loss: {round(loss.item(),3)}\")\n", 779 | "\n", 780 | "\n", 781 | " \n", 782 | " \n", 783 | "def training_(dataset, batch_size, n, L, lr_, weight_decay_, epochs):\n", 784 | " \n", 785 | " model = QNN(n=n, L=L).to(device)\n", 786 | " loss_fn = nn.BCELoss()\n", 787 | " optimizer = torch.optim.Adam(model.parameters(), lr=lr_, weight_decay=weight_decay_)\n", 788 | " \n", 789 | " for t in range(epochs): \n", 790 | " print(f\"Epoch {t+1} ---------------------------------- \\n \")\n", 791 | " one_epoch_(model, loss_fn, optimizer, dataset, batch_size)\n", 792 | " \n", 793 | " #accuracy_train, loss_train = performance_estimate(train_dataset, model, loss_fn, 'train')\n", 794 | " #accuracy_test, loss_test = performance_estimate(test_dataset, model, loss_fn, 'test ')\n", 795 | " #print()\n", 796 | "\n", 797 | " \n", 798 | " model_state_dict = model.state_dict()\n", 799 | " return model_state_dict\n", 800 | "\n", 801 | " " 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 14, 807 | "id": "7217eb93", 808 | "metadata": { 809 | "scrolled": false 810 | }, 811 | "outputs": [ 812 | { 813 | "name": "stdout", 814 | "output_type": "stream", 815 | "text": [ 816 | "batch_size = 64\n", 817 | "\n", 818 | "Epoch 1 ---------------------------------- \n", 819 | " \n", 820 | "train loss: 0.674\n", 821 | "train loss: 0.64\n", 822 | "train loss: 0.594\n", 823 | "train loss: 0.566\n", 824 | "Epoch 2 ---------------------------------- \n", 825 | " \n", 826 | "train loss: 0.566\n", 827 | "train loss: 0.575\n", 828 | "train loss: 0.516\n", 829 | "train loss: 0.5\n", 830 | "Epoch 3 ---------------------------------- \n", 831 | " \n", 832 | "train loss: 0.52\n", 833 | "train loss: 0.511\n", 834 | "train loss: 0.488\n", 835 | "train loss: 0.493\n", 836 | "\n", 837 | " ~~~~~ training is done ~~~~~\n", 838 | "\n", 839 | "CPU times: user 1min 33s, sys: 47.8 ms, total: 1min 33s\n", 
840 | "Wall time: 23.3 s\n" 841 | ] 842 | } 843 | ], 844 | "source": [ 845 | "%%time\n", 846 | "\n", 847 | "\n", 848 | "batch_size = 64\n", 849 | "print(f'batch_size = {batch_size}\\n')\n", 850 | "\n", 851 | "#----------------------------------------------------------------------------------\n", 852 | "\n", 853 | "model_state_dict = training_(train_dataset, batch_size=batch_size, lr_=1e-3, n=n, L=L, \n", 854 | " weight_decay_=1e-8, epochs=3)\n", 855 | "\n", 856 | "print()\n", 857 | "print(f' ~~~~~ training is done ~~~~~\\n')" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": 15, 863 | "id": "6f5645b9", 864 | "metadata": {}, 865 | "outputs": [ 866 | { 867 | "name": "stdout", 868 | "output_type": "stream", 869 | "text": [ 870 | "train accuracy: 0.914, train loss: 0.493\n", 871 | "test accuracy: 0.916, test loss: 0.492\n", 872 | "\n", 873 | "CPU times: user 22.4 s, sys: 7.94 ms, total: 22.4 s\n", 874 | "Wall time: 5.62 s\n" 875 | ] 876 | } 877 | ], 878 | "source": [ 879 | "%%time\n", 880 | "\n", 881 | "\n", 882 | "model = QNN(n=n, L=L).to(device)\n", 883 | "model.load_state_dict(model_state_dict)\n", 884 | "loss_fn = nn.BCELoss()\n", 885 | "\n", 886 | "\n", 887 | "accuracy_train, loss_train = performance_estimate(train_dataset, model, loss_fn, 'train')\n", 888 | "accuracy_test, loss_test = performance_estimate(test_dataset, model, loss_fn, 'test ')\n", 889 | "print()\n", 890 | "\n" 891 | ] 892 | }, 893 | { 894 | "cell_type": "markdown", 895 | "id": "2c75a3b1", 896 | "metadata": {}, 897 | "source": [ 898 | "# " 899 | ] 900 | } 901 | ], 902 | "metadata": { 903 | "kernelspec": { 904 | "display_name": "Python 3 (ipykernel)", 905 | "language": "python", 906 | "name": "python3" 907 | }, 908 | "language_info": { 909 | "codemirror_mode": { 910 | "name": "ipython", 911 | "version": 3 912 | }, 913 | "file_extension": ".py", 914 | "mimetype": "text/x-python", 915 | "name": "python", 916 | "nbconvert_exporter": "python", 917 | "pygments_lexer": 
"ipython3", 918 | "version": "3.10.6" 919 | } 920 | }, 921 | "nbformat": 4, 922 | "nbformat_minor": 5 923 | } 924 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image classification with CNN and QNN 2 | 3 | Image classification is a computer vision task, where a machine learning algorithm has to categorize images into predefined classes. Here, we have taken the "MNIST" and "FashionMNIST" datasets (for more details, see the attached notebooks). 4 | 5 | The "MNIST" dataset contains images of handwritten digits from 0 to 9. 6 | 7 | The "FashionMNIST" dataset covers ten different clothing categories, including T-shirts/tops, trousers, pullovers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots. Each image is labeled with the corresponding clothing category, providing the ground truth for training and testing. 8 | 9 | ![pred_CNN](https://github.com/ArunSehrawat/Image_classification_with_CNN/assets/99533657/169936fb-72c7-436c-b069-2e43f0ad49ae) 10 | 11 | ----- 12 | 13 | **(1)** In the attached Jupyter notebook __Image_classification_with_CNN__, I have presented the **Convolutional Neural Networks (CNNs)** for the task. 14 | 15 | 16 | ![LeNET5](https://github.com/ArunSehrawat/Image_classification_with_CNN_and_QNN/assets/99533657/21d097f2-d958-4100-bf90-0ffdb3360017) 17 | 18 | Image (LeNET-5) source: http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf 19 | 20 | ----- 21 | 22 | **(2)** In the notebook __Image_classification_with_QNN__, I have presented the **Quantum Neural Network (QNN)** for the task. Here, the QNN consists of quantum layers followed by a classical layer at the end. 
23 | 24 | ![QNN](https://github.com/ArunSehrawat/Image_classification_with_CNN_and_QNN/assets/99533657/febb497c-1dd7-4517-8ea4-4c1871c6540b) 25 | 26 | ----- 27 | 28 | **(3)** Unlike __Image_classification_with_QNN__, 29 | the notebook __Image_classification_with_QNN_2__ uses only two classes (0 and 1) and __only quantum layers__. 30 | 31 | 32 | ![QNN2](https://github.com/ArunSehrawat/Image_classification_with_CNN_and_QNN/assets/99533657/8ee8b2fa-d71e-4850-9531-028c31ffb79e) 33 | 34 | 35 | ----- 36 | 37 | I have also uploaded: 38 | **(4)** a **quantum circuit simulator** (used in the QNN notebooks) 39 | 40 | **(5)** a notebook with **quantum circuit examples** to test the quantum circuit simulator 41 | 42 | ------- 43 | 44 | ![acc_loss_CNN](https://github.com/ArunSehrawat/Image_classification_with_CNN/assets/99533657/6c1f2025-db88-411f-b3e3-22e2779d40a4) 45 | -------------------------------------------------------------------------------- /quantum_circuit_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "0aaf95a2", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import torch\n", 19 | "#import numpy as np\n", 20 | "\n", 21 | "from quantum_circuit_simulator import quantum_circuit\n", 22 | "\n", 23 | "\n", 24 | "'''\n", 25 | "state_vector can \n", 26 | "(1) either be a vector of shape (dim,)\n", 27 | "(2) or be a matrix of shape (dim, number of examples)\n", 28 | "\n", 29 | "n is the number of qubits\n", 30 | "dim = 2**n\n", 31 | "'''\n", 32 | "\n", 33 | "print()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "2e53aba3", 39 | "metadata": {}, 40 | "source": [ 41 | "## X, Y, Z gates" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "id": "60902d84", 48 | "metadata": {}, 49 
| "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "tensor([1.+0.j, 2.+0.j]) = initial state vector\n", 55 | "\n", 56 | "1 = n\n", 57 | "2 = dim\n", 58 | "\n", 59 | "tensor([[1.+0.j, 0.+0.j],\n", 60 | " [0.+0.j, 1.+0.j]]) = I\n", 61 | "tensor([[0.+0.j, 1.+0.j],\n", 62 | " [1.+0.j, 0.+0.j]]) = X\n", 63 | "tensor([[0.+0.j, -0.-1.j],\n", 64 | " [0.+1.j, 0.+0.j]]) = Y\n", 65 | "tensor([[ 1.+0.j, 0.+0.j],\n", 66 | " [ 0.+0.j, -1.+0.j]]) = Z\n", 67 | "tensor([[ 0.7071+0.j, 0.7071+0.j],\n", 68 | " [ 0.7071+0.j, -0.7071+0.j]]) = H\n", 69 | "tensor([[1.+0.j, 0.+0.j],\n", 70 | " [0.+0.j, 0.+0.j]]) = proj_0\n", 71 | "tensor([[0.+0.j, 0.+0.j],\n", 72 | " [0.+0.j, 1.+0.j]]) = proj_1\n", 73 | "\n", 74 | "tensor([2.+0.j, 1.+0.j]) = final state vector\n", 75 | "\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "num_qubits = 1\n", 81 | "dim = 2**num_qubits\n", 82 | "\n", 83 | "state_vector = torch.arange(dim) + 1 # state_vector must be normalized\n", 84 | "#state_vector = state_vector.reshape(-1,1)\n", 85 | "qc = quantum_circuit(num_qubits = num_qubits, state_vector = state_vector)\n", 86 | "print(qc.state_vector, '= initial state vector\\n')\n", 87 | "\n", 88 | "print(qc.n, '= n')\n", 89 | "print(qc.dim, '= dim\\n')\n", 90 | "print(qc.I, '= I')\n", 91 | "print(qc.x_matrix, '= X')\n", 92 | "print(qc.y_matrix, '= Y')\n", 93 | "print(qc.z_matrix, '= Z')\n", 94 | "\n", 95 | "print(qc.h_matrix, '= H')\n", 96 | "\n", 97 | "print(qc.proj_0, '= proj_0')\n", 98 | "print(qc.proj_1, '= proj_1\\n')\n", 99 | "\n", 100 | "\n", 101 | "qc.x(0)\n", 102 | "#qc.y(0)\n", 103 | "#qc.z(2)\n", 104 | "\n", 105 | "print(qc.state_vector, '= final state vector\\n')\n", 106 | "\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "id": "b4fa3fc7", 112 | "metadata": {}, 113 | "source": [ 114 | "## H gate" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 3, 120 | "id": "49b82f13", 121 | "metadata": {}, 122 | "outputs": [ 123 | { 
124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "tensor([[1.+0.j],\n", 128 | " [0.+0.j],\n", 129 | " [0.+0.j],\n", 130 | " [0.+0.j],\n", 131 | " [0.+0.j],\n", 132 | " [0.+0.j],\n", 133 | " [0.+0.j],\n", 134 | " [0.+0.j]]) = initial state vector\n", 135 | "\n", 136 | "tensor([[0.5000+0.j],\n", 137 | " [0.0000+0.j],\n", 138 | " [0.5000+0.j],\n", 139 | " [0.0000+0.j],\n", 140 | " [0.5000+0.j],\n", 141 | " [0.0000+0.j],\n", 142 | " [0.5000+0.j],\n", 143 | " [0.0000+0.j]]) = final state vector\n", 144 | "\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "num_qubits = 3\n", 150 | "dim = 2**num_qubits\n", 151 | "\n", 152 | "qc = quantum_circuit(num_qubits = num_qubits)\n", 153 | "print(qc.state_vector, '= initial state vector\\n')\n", 154 | "\n", 155 | "qc.h(0)\n", 156 | "qc.h(1)\n", 157 | "print(qc.state_vector, '= final state vector\\n')" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "ac7da2a2", 163 | "metadata": {}, 164 | "source": [ 165 | "## Rx, Ry, Rz gates" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 4, 171 | "id": "7aeac4d1", 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "tensor([[1.+0.j],\n", 179 | " [0.+0.j],\n", 180 | " [0.+0.j],\n", 181 | " [0.+0.j]]) = initial state vector\n", 182 | "\n", 183 | "[tensor(0.9394, grad_fn=), tensor(0.3429, grad_fn=)] \n", 184 | "\n", 185 | "[tensor(0.7649, grad_fn=), tensor(0.6442, grad_fn=)] \n", 186 | "\n", 187 | "tensor([[0.9394+0.j],\n", 188 | " [0.0000+0.j],\n", 189 | " [0.3429+0.j],\n", 190 | " [0.0000+0.j]], grad_fn=) = final state vector\n", 191 | "\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "num_qubits = 2\n", 197 | "dim = 2**num_qubits\n", 198 | "\n", 199 | "\n", 200 | "qc = quantum_circuit(num_qubits = num_qubits)\n", 201 | "print(qc.state_vector, '= initial state vector\\n')\n", 202 | "\n", 203 | "\n", 204 | "\n", 205 | "ang_ = torch.rand(1, 
requires_grad=True)\n", 206 | "ang=ang_[0]\n", 207 | "\n", 208 | "print([torch.cos(ang/2), torch.sin(ang/2)], '\\n')\n", 209 | "print([torch.cos(ang), torch.sin(ang)], '\\n')\n", 210 | "\n", 211 | "\n", 212 | "\n", 213 | "#qc.Rx(0, ang)\n", 214 | "\n", 215 | "\n", 216 | "qc.Ry(0, ang)\n", 217 | "\n", 218 | "#qc.x(0)\n", 219 | "#qc.Rz(0, ang)\n", 220 | "\n", 221 | "\n", 222 | "print(qc.state_vector, '= final state vector\\n')" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "id": "e8962768", 228 | "metadata": {}, 229 | "source": [ 230 | "## general rotation R" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 5, 236 | "id": "18c0b6f6", 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "name": "stdout", 241 | "output_type": "stream", 242 | "text": [ 243 | "tensor([[1.+0.j],\n", 244 | " [0.+0.j]]) = initial state vector\n", 245 | "\n", 246 | "[tensor(0.9846, grad_fn=), tensor(0.1748, grad_fn=)] \n", 247 | "\n", 248 | "[tensor(0.9389, grad_fn=), tensor(0.3443, grad_fn=)] \n", 249 | "\n", 250 | "tensor([[0.0000+0.0000j],\n", 251 | " [0.9389+0.3443j]], grad_fn=) = final state vector\n", 252 | "\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "num_qubits = 1\n", 258 | "dim = 2**num_qubits\n", 259 | "\n", 260 | "\n", 261 | "qc = quantum_circuit(num_qubits = num_qubits)\n", 262 | "print(qc.state_vector, '= initial state vector\\n')\n", 263 | "\n", 264 | "ang_ = torch.rand(1, requires_grad=True)\n", 265 | "ang=ang_[0]\n", 266 | "\n", 267 | "\n", 268 | "print([torch.cos(ang/2), torch.sin(ang/2)], '\\n')\n", 269 | "print([torch.cos(ang), torch.sin(ang)], '\\n')\n", 270 | "zero = torch.tensor(0.0)\n", 271 | "\n", 272 | "\n", 273 | "theta = zero #ang \n", 274 | "phi = zero #ang\n", 275 | "lamda = ang #zero\n", 276 | "\n", 277 | "\n", 278 | "\n", 279 | "qc.x(0)\n", 280 | "qc.R(0, theta=theta, phi=phi, lamda=lamda)\n", 281 | "print(qc.state_vector, '= final state vector\\n')\n", 282 | "\n", 283 | "\n" 284 | ] 285 | }, 286 | { 287 | 
"cell_type": "markdown", 288 | "id": "d18a9854", 289 | "metadata": {}, 290 | "source": [ 291 | "## CX, CZ gates" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 6, 297 | "id": "6cb1b102", 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "tensor([1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j, 8.+0.j]) = initial state vector\n", 305 | "\n", 306 | "tensor([ 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 5.+0.j, 6.+0.j, -7.+0.j, -8.+0.j]) = final state vector\n", 307 | "\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "num_qubits = 3\n", 313 | "dim = 2**num_qubits\n", 314 | "\n", 315 | "\n", 316 | "state_vector = torch.arange(dim) + 1 # state_vector must be normalized\n", 317 | "#state_vector = state_vector.reshape(-1,1)\n", 318 | "qc = quantum_circuit(num_qubits = num_qubits, state_vector = state_vector)\n", 319 | "print(qc.state_vector, '= initial state vector\\n')\n", 320 | "\n", 321 | "#qc.cx(control=0, target=1)\n", 322 | "qc.cz(control=0, target=1)\n", 323 | "\n", 324 | "print(qc.state_vector, '= final state vector\\n')" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 7, 330 | "id": "4add728a", 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "tensor([[0.3959+0.j, 0.0264+0.j],\n", 338 | " [0.7602+0.j, 0.7937+0.j],\n", 339 | " [0.6448+0.j, 0.6465+0.j],\n", 340 | " [0.3658+0.j, 0.4472+0.j],\n", 341 | " [0.9422+0.j, 0.6814+0.j],\n", 342 | " [0.5886+0.j, 0.2620+0.j],\n", 343 | " [0.3211+0.j, 0.7520+0.j],\n", 344 | " [0.6750+0.j, 0.4294+0.j]]) = initial state vectors\n", 345 | "\n", 346 | "tensor([[ 0.3959+0.j, 0.0264+0.j],\n", 347 | " [ 0.7602+0.j, 0.7937+0.j],\n", 348 | " [ 0.6448+0.j, 0.6465+0.j],\n", 349 | " [ 0.3658+0.j, 0.4472+0.j],\n", 350 | " [ 0.9422+0.j, 0.6814+0.j],\n", 351 | " [ 0.5886+0.j, 0.2620+0.j],\n", 352 | " [-0.3211+0.j, -0.7520+0.j],\n", 353 
| " [-0.6750+0.j, -0.4294+0.j]]) = final state vector\n", 354 | "\n" 355 | ] 356 | } 357 | ], 358 | "source": [ 359 | "num_qubits = 3\n", 360 | "dim = 2**num_qubits\n", 361 | "\n", 362 | "state_vector = torch.rand((dim, 2)) # state_vector must be normalized\n", 363 | "qc = quantum_circuit(num_qubits = num_qubits, state_vector = state_vector)\n", 364 | "print(qc.state_vector, '= initial state vectors\\n')\n", 365 | "\n", 366 | "\n", 367 | "#qc.cx(control=0, target=1)\n", 368 | "qc.cz(control=0, target=1)\n", 369 | "\n", 370 | "print(qc.state_vector, '= final state vector\\n')" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "id": "885e524b", 376 | "metadata": {}, 377 | "source": [ 378 | "# cx_, cz_linear_layer" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 8, 384 | "id": "b8534c5c", 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "name": "stdout", 389 | "output_type": "stream", 390 | "text": [ 391 | "3 4\n", 392 | "2 3\n", 393 | "1 2\n", 394 | "0 1\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "nn = 5\n", 400 | "\n", 401 | "print(nn-2, nn-1)\n", 402 | "for i in range(nn - 3, -1, -1):\n", 403 | " print(i, i+1)" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 9, 409 | "id": "88b94daf", 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "name": "stdout", 414 | "output_type": "stream", 415 | "text": [ 416 | "tensor([1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j, 8.+0.j]) = initial state vector\n", 417 | "\n", 418 | "tensor([1.+0.j, 2.+0.j, 4.+0.j, 3.+0.j, 8.+0.j, 7.+0.j, 5.+0.j, 6.+0.j]) = final state vector\n", 419 | "\n" 420 | ] 421 | } 422 | ], 423 | "source": [ 424 | "num_qubits = 3\n", 425 | "dim = 2**num_qubits\n", 426 | "\n", 427 | "state_vector = torch.arange(dim) + 1 # state_vector must be normalized\n", 428 | "#state_vector = state_vector.reshape(-1,1)\n", 429 | "qc = quantum_circuit(num_qubits = num_qubits, state_vector = state_vector)\n", 430 | 
"print(qc.state_vector, '= initial state vector\\n')\n", 431 | "\n", 432 | "''' \n", 433 | "cx_linear_layer applies cx(n-1,n) ... cx(2,3) cx(1,2) cx(0,1) |state_vector>\n", 434 | "\n", 435 | "NOTE: First cx(0,1) will act on |state_vector>, then cx(1,2)\n", 436 | " And in the last cx(n-1,n) will act.\n", 437 | " order matter in case of cx\n", 438 | "'''\n", 439 | "\n", 440 | "qc.cx_linear_layer()\n", 441 | "\n", 442 | "#qc.cz_linear_layer()\n", 443 | "print(qc.state_vector, '= final state vector\\n')" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "id": "48954fe1", 449 | "metadata": {}, 450 | "source": [ 451 | "# Ry_layer" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 10, 457 | "id": "0ab3be16", 458 | "metadata": {}, 459 | "outputs": [ 460 | { 461 | "name": "stdout", 462 | "output_type": "stream", 463 | "text": [ 464 | "tensor([[1.+0.j],\n", 465 | " [0.+0.j],\n", 466 | " [0.+0.j],\n", 467 | " [0.+0.j]]) = initial state vector\n", 468 | "\n", 469 | "ang = tensor([0.8070, 0.8190], requires_grad=True)\n", 470 | "ang = tensor([0.8070+0.j, 0.8190+0.j], grad_fn=)\n" 471 | ] 472 | }, 473 | { 474 | "data": { 475 | "text/plain": [ 476 | "tensor([[0.4724+0.j],\n", 477 | " [0.5053+0.j],\n", 478 | " [0.4932+0.j],\n", 479 | " [0.5275+0.j]], grad_fn=)" 480 | ] 481 | }, 482 | "execution_count": 10, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "num_qubits = 2\n", 489 | "dim = 2**num_qubits\n", 490 | "\n", 491 | "\n", 492 | "qc = quantum_circuit(num_qubits = num_qubits)\n", 493 | "print(qc.state_vector, '= initial state vector\\n')\n", 494 | "\n", 495 | "\n", 496 | "\n", 497 | "ang = torch.rand(num_qubits, dtype=torch.float, requires_grad=True)\n", 498 | "print('ang =',ang)\n", 499 | "ang = ang.to(torch.cfloat) # = torch.complex(real=ang, imag=torch.zeros(num_qubits))\n", 500 | "print('ang =',ang)\n", 501 | "\n", 502 | "\n", 503 | "\n", 504 | "qc.Ry_layer(ang)" 505 | ] 506 | }, 507 
| { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "id": "779636d0", 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "id": "df53882f", 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "id": "892b33be", 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [] 530 | } 531 | ], 532 | "metadata": { 533 | "kernelspec": { 534 | "display_name": "Python 3 (ipykernel)", 535 | "language": "python", 536 | "name": "python3" 537 | }, 538 | "language_info": { 539 | "codemirror_mode": { 540 | "name": "ipython", 541 | "version": 3 542 | }, 543 | "file_extension": ".py", 544 | "mimetype": "text/x-python", 545 | "name": "python", 546 | "nbconvert_exporter": "python", 547 | "pygments_lexer": "ipython3", 548 | "version": "3.10.6" 549 | } 550 | }, 551 | "nbformat": 4, 552 | "nbformat_minor": 5 553 | } 554 | -------------------------------------------------------------------------------- /quantum_circuit_simulator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | 6 | def get_device(gpu_no): 7 | if torch.cuda.is_available(): 8 | return torch.device('cuda', gpu_no) 9 | else: 10 | return torch.device('cpu') 11 | 12 | 13 | 14 | 15 | 16 | 17 | class quantum_circuit: 18 | 19 | def __init__(self, num_qubits : int, state_vector = None, device = 'cuda', gpu_no = 0): 20 | 21 | """ 22 | Defines a quantum circuit object that stores the full state-vector (evolved through 23 | the unitary operations of a quantum circuit) of `num_qubits` number of qubits. 24 | 25 | Args: 26 | num_qubits (int): Number of qubits in the circuit. 27 | 28 | state_vector (torch.Tensor, optional): The full state vector of the quantum circuit. 29 | Defaults to None. 
                                                   If None is provided then the state
                                                   vector is automatically initialized to the ket |0000...0>.

            device (str, optional): Device on which the state vector should be stored (CPU / GPU).
                                    Defaults to 'cuda' i.e. GPU.

            gpu_no (int, optional): If there are multiple GPUs then this parameter defines which
                                    GPU to use. Defaults to 0 i.e. the first device.
        """
        #----------------------------------------------------------------------------------------

        if device != 'cuda':
            self.device = torch.device(device)
        else:
            self.device = get_device(gpu_no)


        #----------------------------------------------------------------------------------------


        self.n = num_qubits        # number of qubits
        self.dim = 2**self.n       # dimension of the n-qubit Hilbert space


        #----------------------------------------------------------------------------------------

        '''
        state_vector can be
        (1) either a vector of shape (dim,)
        (2) or a matrix of shape (dim, number of examples)
        '''

        if state_vector is None:
            ''' Initialize the state-vector to |0000...0> '''
            state_vector = torch.zeros(self.dim, device=self.device, dtype=torch.cfloat)
            state_vector[0] = 1
            self.state_vector = state_vector.reshape(-1,1)
        else:
            if state_vector.shape[0] == self.dim:
                ''' state_vector must be normalized '''
                # keep the state on the same device as the gate matrices
                self.state_vector = state_vector.to(device=self.device, dtype=torch.cfloat)
            else:
                print('The dimension 2**n does NOT match the shape of the state vector.
n is the number of qubits.')


        #----------------------------------------------------------------------------------------

        # single qubit Pauli gates (matrices) :
        self.I = torch.tensor([[1, 0], [0, 1]], device=self.device, dtype=torch.cfloat)
        self.x_matrix = torch.tensor([[0., 1], [1, 0]], device=self.device, dtype=torch.cfloat)
        self.y_matrix = torch.tensor([[0, -1j], [1j, 0]], device=self.device, dtype=torch.cfloat)
        self.z_matrix = torch.tensor([[1, 0], [0, -1]], device=self.device, dtype=torch.cfloat)

        self.h_matrix = (1 / math.sqrt(2)) * torch.tensor([[1, 1], [1, -1]], device=self.device, dtype=torch.cfloat)


        # single qubit projectors :
        self.proj_0 = torch.tensor([[1, 0], [0, 0]], device=self.device, dtype=torch.cfloat)
        self.proj_1 = torch.tensor([[0, 0], [0, 1]], device=self.device, dtype=torch.cfloat)



    #======================================================================================================



    def single_qubit_gate(self, target : int, gate : torch.Tensor):
        """
        Applies a single qubit gate = I ⊗ I ⊗ ... ⊗ gate ⊗ ... ⊗ I

        Args:
            target (int): The qubit index on which the gate will be applied
            gate (torch.Tensor): The matrix representation of a single qubit gate

        Returns:
            The state vector of the full quantum circuit after applying the single qubit gate.

        """

        if target < 0 or self.n <= target:
            print('0 <= target <= num_qubits - 1 is NOT satisfied!')

        else:
            single_q_gate = torch.tensor(1, device=self.device, dtype=torch.cfloat)   # initialize

            # build I ⊗ ... ⊗ gate ⊗ ... ⊗ I, with qubit 0 as the leftmost factor
            for k in range(self.n):
                if k == target:
                    single_q_gate = torch.kron(single_q_gate, gate)
                else:
                    single_q_gate = torch.kron(single_q_gate, self.I)

            #------------------------------------------------------

            self.state_vector = torch.matmul(single_q_gate, self.state_vector)
            return self.state_vector




    def controlled_gate(self, control: int, target: int, gate : torch.Tensor):
        """
        Applies a two-qubit controlled gate between the `control` and `target` qubits.

        control_gate_part_0 = I ⊗ |0><0| ⊗ ... ⊗ I ⊗ ... ⊗ I
        control_gate_part_1 = I ⊗ |1><1| ⊗ ... ⊗ gate ⊗ ... ⊗ I     SEE: the control is set to 1

        control_gate = control_gate_part_0 + control_gate_part_1


        Args:
            control (int): Control qubit index
            target (int): Target qubit index
            gate (torch.Tensor): The matrix representation of a single qubit gate

        Returns:
            The state vector of the full quantum circuit after applying the two-qubit gate.
        """

        if control < 0 or self.n <= control:
            print('0 <= control <= num_qubits - 1 is NOT satisfied!')
        elif target < 0 or self.n <= target:
            print('0 <= target <= num_qubits - 1 is NOT satisfied!')
        elif control == target:
            print('control and target qubits must be different!')
        else:
            control_gate_part_0 = torch.tensor(1, device=self.device, dtype=torch.cfloat)   # initialize
            control_gate_part_1 = torch.tensor(1, device=self.device, dtype=torch.cfloat)

            for k in range(self.n):
                if k == control:
                    control_gate_part_0 = torch.kron(control_gate_part_0, self.proj_0)
                    control_gate_part_1 = torch.kron(control_gate_part_1, self.proj_1)
                elif k == target:
                    control_gate_part_0 = torch.kron(control_gate_part_0, self.I)
                    control_gate_part_1 = torch.kron(control_gate_part_1, gate)
                else:
                    control_gate_part_0 = torch.kron(control_gate_part_0, self.I)
                    control_gate_part_1 = torch.kron(control_gate_part_1, self.I)

            control_gate = control_gate_part_0 + control_gate_part_1

            self.state_vector = torch.matmul(control_gate, self.state_vector)
            return self.state_vector


    #======================================================================================================

    def x(self, target : int):     # Applies X gate (matrix) on the target qubit
        'NOTE: 0 <= target <= num_qubits - 1'
        self.single_qubit_gate(target, self.x_matrix)


    def y(self, target : int):     # Applies Y gate (matrix) on the target qubit
        self.single_qubit_gate(target, self.y_matrix)


    def z(self, target : int):     # Applies Z gate (matrix) on the target qubit
        self.single_qubit_gate(target, self.z_matrix)


    def h(self, target : int):     # Applies Hadamard gate (matrix) on the target qubit
        self.single_qubit_gate(target, self.h_matrix)


    #======================================================================================================


    def
Rx(self, target : int, theta):

        """
        Applies Rx gate (rotation around x axis) on the target qubit

        Args:
            theta (torch.Tensor): Angle by which the qubit should be rotated around X axis.
                                  Usually a tunable parameter is passed.

            target (int): Qubit index on which the Rx gate will be applied.
                          NOTE: 0 <= target <= num_qubits - 1
        """

        co = torch.cos(theta / 2)
        si = torch.sin(theta / 2)
        self.Rx_matrix = torch.stack([torch.stack([co, -1j*si]), torch.stack([-1j*si, co])])

        self.single_qubit_gate(target, self.Rx_matrix)




    def Ry(self, target : int, theta):    # like Rx: applies the Ry gate (rotation around y axis) on the target qubit

        co = torch.cos(theta / 2)
        si = torch.sin(theta / 2)
        self.Ry_matrix = torch.stack([torch.stack([co, -si]), torch.stack([si, co])])

        self.single_qubit_gate(target, self.Ry_matrix)





    def Rz(self, target : int, theta):    # like Rx: applies the Rz gate (rotation around z axis) on the target qubit

        # diag(1, exp(i*theta)): equal to the standard Rz(theta) up to a global phase
        exp_theta = torch.exp( 1j*theta )
        zero = torch.zeros_like(exp_theta)    # same dtype and device as exp_theta
        one = torch.ones_like(exp_theta)
        self.Rz_matrix = torch.stack([torch.stack([one, zero]), torch.stack([zero, exp_theta])])

        self.single_qubit_gate(target, self.Rz_matrix)




    def R(self, target : int, theta, phi, lamda):
        """
        Applies a general rotation to the target qubit

        Args:
            theta, phi and lamda (torch.Tensor): The Euler angles which define a general rotation on the Bloch sphere.

            target (int): Qubit index on which the gate will be applied.
        """

        a = torch.cos(theta / 2)
        b = - torch.exp(1j * lamda) * torch.sin(theta / 2)
        c = torch.exp(1j * phi) * torch.sin(theta / 2)
        d = torch.exp(1j * (phi + lamda)) * torch.cos(theta / 2)
        self.R_matrix = torch.stack([torch.stack([a, b]), torch.stack([c, d])])

        self.single_qubit_gate(target, self.R_matrix)


    #======================================================================================================


    def Ry_layer(self, angs: torch.Tensor):
        '''
        Applies tensor-product of single-qubit rotations around y-axis

        NOTE: the full angle is used here (not angs / 2 as in Ry),
              so each factor equals the Ry matrix for the angle 2 * angs[i].
        '''

        cos, sin = torch.cos(angs[0]), torch.sin(angs[0])
        '''
        Use torch.stack otherwise computation graph will be broken (or will not begin).
        And, grad will be gone (will not be stored).
        '''
        rot = torch.stack([torch.stack([cos, -sin]), torch.stack([sin, cos])])

        for i in range(1, len(angs)):     # one angle for each qubit
            cos, sin = torch.cos(angs[i]), torch.sin(angs[i])
            rot = torch.kron(rot, torch.stack([torch.stack([cos, -sin]), torch.stack([sin, cos])]))

        #--------------------------------------------------------------------------

        self.state_vector = torch.matmul(rot, self.state_vector)        # rotated state vector
        return self.state_vector





    def Rz_layer(self, angs: torch.Tensor):     # like Ry_layer, but with rotations around z-axis

        exp_ang = torch.exp( 1j*angs[0] )
        zero = torch.zeros_like(exp_ang)    # same dtype and device as exp_ang
        one = torch.ones_like(exp_ang)

        rot = torch.stack([torch.stack([one, zero]), torch.stack([zero, exp_ang])])

        for i in range(1, len(angs)):     # one angle for each qubit
            exp_ang = torch.exp( 1j*angs[i] )
            rot = torch.kron(rot, torch.stack([torch.stack([one, zero]), torch.stack([zero, exp_ang])]) )

        #--------------------------------------------------------------------------

        self.state_vector = torch.matmul(rot,
                                         self.state_vector)        # rotated state vector
        return self.state_vector



    #======================================================================================================

    def cx(self, control: int, target: int):
        """
        Applies controlled-X gate = I ⊗ |0><0| ⊗ ... ⊗ I ⊗ ... ⊗ I +
                                    I ⊗ |1><1| ⊗ ... ⊗ X ⊗ ... ⊗ I

        Args:
            control (int): Control qubit index
            target (int): Target qubit index
        """
        self.controlled_gate(control, target, self.x_matrix)




    def cz(self, control: int, target: int):      # like cx, but applies the controlled-Z gate
        self.controlled_gate(control, target, self.z_matrix)



    #======================================================================================================


    def cx_linear_layer(self):
        '''
        Applies cx(n-2,n-1) ... cx(2,3) cx(1,2) cx(0,1) |state_vector>

        NOTE: cx(0,1) acts on |state_vector> first, then cx(1,2),
              and cx(n-2,n-1) acts last.
              The order matters in the case of cx.
        '''

        # each call updates the state immediately, so cx(0,1) must be called first
        for i in range(self.n - 1):
            self.controlled_gate(i, i+1, self.x_matrix)



    def cz_linear_layer(self):      # like cx_linear_layer, but with cz gates
        # cz gates are diagonal and commute, so their order does not matter
        for i in range(self.n - 1):
            self.controlled_gate(i, i+1, self.z_matrix)


    #======================================================================================================

    def probabilities(self):
        """
        Probabilities obtained in the z-measurement (computational basis) on the state vector

        Returns: A torch.Tensor of the same size as the state vector
                 (complex dtype, with zero imaginary part)
        """

        return self.state_vector.conj() * self.state_vector





--------------------------------------------------------------------------------
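
The note in `cx_linear_layer` that the order of cx gates matters can be checked without torch. The following is a minimal pure-Python sketch and is not part of the repository: `apply_cx` and `cx_chain` are hypothetical helper names, and qubit 0 is assumed to be the most significant bit, which matches the kron ordering used in `single_qubit_gate`. It applies the same chain of neighbouring cx gates to |100> in both orders and shows that the final basis states differ.

```python
# Hypothetical helpers for illustration only; they do not exist in
# quantum_circuit_simulator.py. Amplitudes are stored in a plain list and
# qubit 0 is assumed to be the most significant bit of the basis index.

def apply_cx(amplitudes, control, target, num_qubits):
    """Return a new amplitude list after applying cx(control, target)."""
    control_mask = 1 << (num_qubits - 1 - control)
    target_mask = 1 << (num_qubits - 1 - target)
    result = [0j] * len(amplitudes)
    for index, amplitude in enumerate(amplitudes):
        # Flip the target bit only when the control bit is 1.
        new_index = index ^ target_mask if index & control_mask else index
        result[new_index] = amplitude
    return result


def cx_chain(amplitudes, num_qubits, ascending=True):
    """Apply cx(i, i+1) on every neighbouring pair, in the requested order."""
    first_qubits = list(range(num_qubits - 1))
    if not ascending:
        first_qubits.reverse()
    for i in first_qubits:
        amplitudes = apply_cx(amplitudes, i, i + 1, num_qubits)
    return amplitudes


num_qubits = 3
state = [0j] * (2 ** num_qubits)
state[0b100] = 1 + 0j                      # start in |100>

ascending_result = cx_chain(state, num_qubits, ascending=True)    # cx(0,1) first
descending_result = cx_chain(state, num_qubits, ascending=False)  # cx(1,2) first

print(format(ascending_result.index(1 + 0j), '03b'))   # 111
print(format(descending_result.index(1 + 0j), '03b'))  # 110
```

With cx(0,1) first, the flipped qubit 1 goes on to trigger cx(1,2), so |100> ends in |111>. In the reverse order cx(1,2) sees qubit 1 still at 0 and does nothing, so the result is |110>.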