├── .gitignore ├── CNN1.ipynb ├── CNN4.ipynb ├── HQNN-Quanv.ipynb ├── ImgClass_Classical.ipynb ├── ImgClass_Hybrid.ipynb ├── README.md ├── dataset_indices_500.pt └── hqcnn quanv output.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /CNN1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "\n", 21 | "\n", 22 | "#CNN1" 23 | ], 24 | "metadata": { 25 | "id": "b56c8rM5cN-U" 26 | } 27 | }, 28 | { 29 | "cell_type": "code", 30 | "source": [ 31 | "import numpy as np\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "from pathlib import Path\n", 35 | "\n", 36 | "import torch\n", 37 | "from torch.autograd import Function\n", 38 | "from torchvision import datasets, transforms\n", 39 | "import torch.optim as optim\n", 40 | "from torch.optim import lr_scheduler\n", 41 | "import torch.nn as nn\n", 42 | "import torch.nn.functional as F\n", 43 | "\n" 44 | ], 45 | "metadata": { 46 | "id": "-RM98MtEeTww" 47 | }, 48 | "execution_count": 68, 49 | "outputs": [] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "source": [ 54 | "# Loading of the MNIST dataset\n", 55 | "train_data = datasets.MNIST(\n", 56 | " root = 'data', # root: The root directory where the dataset should be stored. In this case, it is set to 'data'. If the 'data' directory doesn't exist, the dataset will be downloaded to this location.\n", 57 | " train = True,\n", 58 | " transform = transforms.ToTensor(), # transform: This parameter applies transformations to the data. In this case, transforms.ToTensor() is used to convert the images to PyTorch tensors.\n", 59 | " download = True,\n", 60 | ")\n", 61 | "test_data = datasets.MNIST(\n", 62 | " root = 'data',\n", 63 | " train = False,\n", 64 | " transform = transforms.ToTensor()\n", 65 | ")" 66 | ], 67 | "metadata": { 68 | "id": "_I845zHUeoXV" 69 | }, 70 | "execution_count": 69, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "source": [ 76 | "#Setting of the main hyper-parameters of the model\n", 77 | "batch_size = 4 # The number of samples in each mini-batch used during training. 
Smaller batch sizes can lead to faster convergence but may introduce more noise into the training process.\n", 78 | "n_train = batch_size * 125 # The total size of the training dataset. It's calculated as the product of batch_size and the number of batches (125 in this case). Adjusting the training dataset size can impact the model's ability to generalize.\n", 79 | "n_test = batch_size * 25 # The total size of the test dataset. Similar to n_train, it's calculated as the product of batch_size and the number of test batches (25 in this case). The test dataset is used to evaluate the model's performance on unseen data.\n", 80 | "n_channels = 4 # The number of channels in the output of the quantum convolution layer. In your model, you have set it to 4. This parameter determines the depth of the feature maps produced by the convolutional layer.\n", 81 | "initial_lr = 0.005 # The initial learning rate for the stochastic gradient descent (SGD) optimizer." 82 | ], 83 | "metadata": { 84 | "id": "yFlOU3zGecUq" 85 | }, 86 | "execution_count": 70, 87 | "outputs": [] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "source": [ 92 | "class Net(nn.Module):\n", 93 | "    def __init__(self):\n", 94 | "        super(Net, self).__init__()\n", 95 | "        # Convolutional layer with 1 input channel, 1 output channel, and a 4x4 kernel\n", 96 | "        self.conv = nn.Conv2d(1, 1, 4, stride=4)\n", 97 | "        self.fc = nn.Linear(1 * 7 * 7, 10)\n", 98 | "\n", 99 | "    def forward(self, x):\n", 100 | "        # Propagate the input through the CNN layers\n", 101 | "        x = self.conv(x)\n", 102 | "        # Flatten the output from the convolutional layer\n", 103 | "        x = torch.flatten(x, start_dim=1)\n", 104 | "        x = F.relu(self.fc(x))\n", 105 | "        return x\n" 106 | ], 107 | "metadata": { 108 | "id": "-3CmM5H-BOh5" 109 | }, 110 | "execution_count": 71, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "source": [ 116 | "cnn = Net()\n", 117 | "dataset = train_data\n", 118 | "train_size = n_train\n", 119 | "train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])\n", 120 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", 121 | "for data in train_loader:\n", 122 | "    inputs, labels = data\n", 123 | "    print(f\"{inputs.shape=}\")\n", 124 | "    print(f\"{labels=}\")\n", 125 | "    outputs = cnn(inputs)\n", 126 | "    print(f\"{outputs.shape=}\")\n", 127 | "    print(f\"{outputs=}\")\n", 128 | "    break" 129 | ], 130 | "metadata": { 131 | "colab": { 132 | "base_uri": "https://localhost:8080/" 133 | }, 134 | "id": "QlHwR2hdRzb2", 135 | "outputId": "e3a911b9-44d7-4d1f-ff6a-1f0f9f271f23" 136 | }, 137 | "execution_count": 72, 138 | "outputs": [ 139 | { 140 | "output_type": "stream", 141 | "name": "stdout", 142 | "text": [ 143 | "inputs.shape=torch.Size([4, 1, 28, 28])\n", 144 | "labels=tensor([5, 8, 7, 0])\n", 145 | "outputs.shape=torch.Size([4, 10])\n", 146 | "outputs=tensor([[0.0000, 0.0967, 0.0000, 0.0376, 0.0762, 0.0462, 0.0965, 0.0399, 0.0115,\n", 147 | "         0.0000],\n", 148 | "        [0.0145, 0.0863, 0.0000, 0.0000, 0.0000, 0.0170, 0.1100, 0.0616, 0.0000,\n", 149 | "         0.0000],\n", 150 | "        [0.0330, 0.0799, 0.0000, 0.0884, 0.0000, 0.0714, 0.0843, 0.1359, 0.0000,\n", 151 | "         0.0000],\n", 152 | "        [0.0000, 0.1155, 0.0000, 0.0000, 0.0000, 0.0264, 0.1649, 0.0000, 0.0000,\n", 153 | "         0.0000]], grad_fn=)\n" 154 | ] 155 | } 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "source": [ 161 | "# Train the model\n", 162 | "\n", 163 | "import datetime\n", 164 | "import os\n", 165 | "\n",
166 | "dataset = train_data\n", 167 | "\n", 168 | "# Initialize your QCNN model\n", 169 | "cnn = Net()\n", 170 | "\n", 171 | "# Define loss function and optimizer\n", 172 | "criterion = nn.CrossEntropyLoss() # Cross-entropy loss for classification\n", 173 | "optimizer = optim.SGD(cnn.parameters(), lr=initial_lr, momentum=0.90) # Stochastic Gradient Descent optimizer\n", 174 | "# Create a learning rate scheduler\n", 175 | "# Here, we use StepLR which reduces the learning rate by a factor every step_size epochs\n", 176 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.0)\n", 177 | "# Split your data into training and validation sets\n", 178 | "train_size = n_train\n", 179 | "train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])\n", 180 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", 181 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)\n", 182 | "\n", 183 | "MODEL_PATH = Path(\"models\")\n", 184 | "MODEL_PATH.mkdir(parents=True, exist_ok=True)\n", 185 | "\n", 186 | "MODEL_NAME = \"ImgClass-Quanvolv.pth\"\n", 187 | "MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n", 188 | "\n", 189 | "RESUME_TRAINING = True\n", 190 | "\n", 191 | "# Training loop\n", 192 | "num_epochs = 20\n", 193 | "loss_list = []\n", 194 | "cnn.train()\n", 195 | "\n", 196 | "if RESUME_TRAINING is False:\n", 197 | " print(f\"Restore model state from {MODEL_SAVE_PATH}\")\n", 198 | " if os.path.exists(MODEL_SAVE_PATH):\n", 199 | " model_dict = torch.load(MODEL_SAVE_PATH)\n", 200 | " initial_epoch = model_dict['epoch'] + 1\n", 201 | " cnn.load_state_dict(model_dict['model_state_dict'])\n", 202 | " optimizer.load_state_dict(model_dict['optimizer_state_dict'])\n", 203 | " loss_list = model_dict['loss'].copy()\n", 204 | " else:\n", 205 | " print(f\"No saved model state found. Training from scratch.\")\n", 206 | " initial_epoch = 0\n", 207 | " loss_list = []\n", 208 | "else:\n", 209 | " initial_epoch = 0\n", 210 | " loss_list = []\n", 211 | "\n", 212 | "for epoch in range(num_epochs):\n", 213 | " ct = datetime.datetime.now()\n", 214 | " # Decay Learning Rate\n", 215 | " optimizer.step()\n", 216 | " scheduler.step()\n", 217 | " lr = scheduler.get_last_lr()\n", 218 | " print(f\"{epoch=}, {lr=}, {ct}\")\n", 219 | " running_loss = []\n", 220 | " for i, data in enumerate(train_loader, 0):\n", 221 | " inputs, labels = data\n", 222 | " optimizer.zero_grad() # Zero the parameter gradients to avoid accumulation\n", 223 | " outputs = cnn(inputs) # Forward pass\n", 224 | " loss = criterion(outputs, labels) # Compute the loss\n", 225 | " loss.backward() # Backpropagation\n", 226 | " running_loss.append(loss.item())\n", 227 | " optimizer.step() # Update the model parameters\n", 228 | " loss_list.append(sum(running_loss) / len(running_loss))\n", 229 | " print('Training [{:.0f}%]\\tLoss: {:.4f}'.format(100. 
* (epoch + 1) / num_epochs, loss_list[-1]))\n", 230 | " torch.save({\n", 231 | " 'epoch': epoch,\n", 232 | " 'model_state_dict': cnn.state_dict(),\n", 233 | " 'optimizer_state_dict': optimizer.state_dict(),\n", 234 | " 'loss': loss_list,\n", 235 | " }, MODEL_SAVE_PATH)\n", 236 | " print(f\"Saving model state to {MODEL_SAVE_PATH}\")\n", 237 | "\n", 238 | "print('Finished Training')" 239 | ], 240 | "metadata": { 241 | "id": "RCxAXutdvoRg", 242 | "colab": { 243 | "base_uri": "https://localhost:8080/" 244 | }, 245 | "outputId": "7a6d30f5-db73-493f-e2d2-7f9b5842128d" 246 | }, 247 | "execution_count": 73, 248 | "outputs": [ 249 | { 250 | "output_type": "stream", 251 | "name": "stdout", 252 | "text": [ 253 | "epoch=0, lr=[0.005], 2024-01-06 04:47:25.174867\n", 254 | "Training [5%]\tLoss: 2.2966\n", 255 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 256 | "epoch=1, lr=[0.005], 2024-01-06 04:47:25.448558\n", 257 | "Training [10%]\tLoss: 2.2018\n", 258 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 259 | "epoch=2, lr=[0.005], 2024-01-06 04:47:25.687253\n", 260 | "Training [15%]\tLoss: 1.5487\n", 261 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 262 | "epoch=3, lr=[0.005], 2024-01-06 04:47:25.974222\n", 263 | "Training [20%]\tLoss: 1.0240\n", 264 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 265 | "epoch=4, lr=[0.005], 2024-01-06 04:47:26.233153\n", 266 | "Training [25%]\tLoss: 0.8959\n", 267 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 268 | "epoch=5, lr=[0.005], 2024-01-06 04:47:26.618003\n", 269 | "Training [30%]\tLoss: 0.7970\n", 270 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 271 | "epoch=6, lr=[0.005], 2024-01-06 04:47:26.983339\n", 272 | "Training [35%]\tLoss: 0.7453\n", 273 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 274 | "epoch=7, lr=[0.005], 2024-01-06 04:47:27.339387\n", 275 | "Training [40%]\tLoss: 0.7120\n", 276 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 277 | "epoch=8, lr=[0.005], 2024-01-06 04:47:27.750256\n", 278 | "Training [45%]\tLoss: 0.6890\n", 279 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 280 | "epoch=9, lr=[0.005], 2024-01-06 04:47:28.095177\n", 281 | "Training [50%]\tLoss: 0.6882\n", 282 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 283 | "epoch=10, lr=[0.005], 2024-01-06 04:47:28.431352\n", 284 | "Training [55%]\tLoss: 0.6109\n", 285 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 286 | "epoch=11, lr=[0.005], 2024-01-06 04:47:28.827147\n", 287 | "Training [60%]\tLoss: 0.6215\n", 288 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 289 | "epoch=12, lr=[0.005], 2024-01-06 04:47:29.189826\n", 290 | "Training [65%]\tLoss: 0.6004\n", 291 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 292 | "epoch=13, lr=[0.005], 2024-01-06 04:47:29.535841\n", 293 | "Training [70%]\tLoss: 0.5856\n", 294 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 295 | "epoch=14, lr=[0.005], 2024-01-06 04:47:29.801140\n", 296 | "Training [75%]\tLoss: 0.5956\n", 297 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 298 | "epoch=15, lr=[0.005], 2024-01-06 04:47:30.041473\n", 299 | "Training [80%]\tLoss: 0.5829\n", 300 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 301 | "epoch=16, lr=[0.005], 2024-01-06 04:47:30.319492\n", 302 | "Training [85%]\tLoss: 0.5417\n", 303 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 304 | "epoch=17, lr=[0.005], 2024-01-06 04:47:30.593614\n", 305 | "Training [90%]\tLoss: 0.5126\n", 306 | 
"Saving model state to models/ImgClass-Quanvolv.pth\n", 307 | "epoch=18, lr=[0.005], 2024-01-06 04:47:30.860810\n", 308 | "Training [95%]\tLoss: 0.5446\n", 309 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 310 | "epoch=19, lr=[0.005], 2024-01-06 04:47:31.128450\n", 311 | "Training [100%]\tLoss: 0.5426\n", 312 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 313 | "Finished Training\n" 314 | ] 315 | } 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "source": [ 321 | "cnn.state_dict()" 322 | ], 323 | "metadata": { 324 | "id": "nUl3J1rhtfvQ", 325 | "colab": { 326 | "base_uri": "https://localhost:8080/" 327 | }, 328 | "outputId": "92e7e31b-6fa1-4e25-cf1f-376c3ba5ab77" 329 | }, 330 | "execution_count": 74, 331 | "outputs": [ 332 | { 333 | "output_type": "execute_result", 334 | "data": { 335 | "text/plain": [ 336 | "OrderedDict([('conv.weight',\n", 337 | " tensor([[[[ 1.4048, 0.3264, 1.0383, 1.2971],\n", 338 | " [ 1.1972, 0.3413, 0.8934, 0.2260],\n", 339 | " [ 0.3179, 0.7019, 0.4972, -0.0530],\n", 340 | " [ 0.4917, -0.2844, 0.6161, 0.6083]]]])),\n", 341 | " ('conv.bias', tensor([-1.2372])),\n", 342 | " ('fc.weight',\n", 343 | " tensor([[ 3.9516e-02, -8.9391e-03, 2.9735e-03, -7.2266e-02, -1.0897e-01,\n", 344 | " 4.0562e-02, 1.3330e-02, -1.3625e-02, 2.5555e-01, 8.4740e-02,\n", 345 | " 2.7625e-01, 3.2129e-01, -3.1033e-02, -2.0671e-01, 7.9392e-02,\n", 346 | " -5.4600e-01, 2.4809e-01, 4.2205e-01, 5.9236e-02, 2.6793e-01,\n", 347 | " 1.6021e-02, 1.3742e-01, 9.3207e-02, 6.1311e-01, -1.1547e+00,\n", 348 | " -6.1099e-01, 6.1775e-01, 2.1080e-01, -2.0552e-02, 2.5187e-01,\n", 349 | " 7.2714e-01, -1.0613e+00, 3.2208e-02, 2.1923e-01, 3.7639e-01,\n", 350 | " -7.7008e-02, -4.6974e-02, 9.1232e-01, -1.2347e-01, -1.8920e-01,\n", 351 | " 2.4733e-01, 1.2517e-01, -1.0238e-01, 1.6751e-02, 5.9525e-02,\n", 352 | " 1.1069e-01, -1.8601e-01, -5.3204e-03, 1.6881e-02],\n", 353 | " [-3.6192e-02, 5.4493e-02, 3.1680e-02, 4.1406e-02, -1.2708e-01,\n", 354 | " 2.2807e-02, 2.1444e-02, 8.4755e-02, 7.0894e-02, 5.6158e-02,\n", 355 | " -4.7497e-02, -1.0310e-01, 7.9952e-02, 7.7607e-02, 5.6697e-02,\n", 356 | " -9.2312e-02, -5.0002e-02, -3.8068e-02, -1.5994e-02, -1.1188e-01,\n", 357 | " 9.6931e-02, -1.8965e-02, -2.7969e-02, 3.6021e-02, -1.0572e-01,\n", 358 | " 1.4151e-02, -6.0890e-02, -1.0724e-02, -5.8725e-02, -7.4264e-02,\n", 359 | " -1.2739e-01, -9.4618e-02, -6.0154e-02, 6.5751e-02, -2.7802e-02,\n", 360 | " 1.2717e-01, 8.0848e-02, -8.4595e-02, 2.0592e-02, 1.0631e-01,\n", 361 | " -1.0776e-01, -8.3935e-02, 1.6426e-02, 1.0841e-01, 2.3631e-02,\n", 362 | " -1.0286e-01, 7.8699e-02, 9.4374e-02, 2.8440e-02],\n", 363 | " [-1.5881e-01, 4.4943e-02, 1.3225e-01, 4.3745e-03, -5.1543e-02,\n", 364 | " 5.6042e-02, -1.7321e-01, -1.7050e-01, 3.2444e-03, 3.4765e-01,\n", 365 | " 6.3836e-01, -4.3278e-01, -3.7222e-01, 2.2531e-01, 1.0259e-01,\n", 366 | " 5.3528e-02, 4.2847e-01, -2.2931e-02, 3.0123e-01, -3.5106e-01,\n", 367 | " -4.4605e-02, -7.3440e-02, -6.5674e-01, -6.5953e-01, -1.0008e+00,\n", 368 | " -1.4382e-01, -2.0706e-01, 1.8900e-02, -1.6820e-01, 9.6760e-01,\n", 369 | " 4.8350e-01, 2.9994e-01, 1.8156e-01, -1.4609e-01, 1.4241e-01,\n", 370 | " -5.9694e-02, 7.6400e-01, 5.0697e-01, -3.1419e-02, 1.0828e-01,\n", 371 | " 9.6191e-01, 2.5402e-01, -9.7012e-02, -4.9915e-02, -3.0773e-01,\n", 372 | " 1.9505e-02, -1.3685e-01, -1.5747e-01, 2.2056e-02],\n", 373 | " [-1.3179e-01, -2.5032e-02, -1.5924e-01, -1.6590e-01, 1.0881e-01,\n", 374 | " -4.9108e-02, 8.6894e-03, -4.4048e-02, 2.1537e-01, 6.9329e-01,\n", 375 | " 2.8167e-01, 4.6076e-01, 
2.6266e-01, -1.2818e-03, -1.4891e-01,\n", 376 | " -2.0506e-01, -6.7152e-01, -8.1662e-02, 3.1654e-01, 3.0378e-01,\n", 377 | " -1.8028e-01, -3.0421e-02, -3.6418e-01, -5.1703e-01, 1.4932e-01,\n", 378 | " 2.3619e-01, -5.0433e-01, -2.5283e-02, -1.9715e-01, 2.9533e-01,\n", 379 | " -4.4303e-01, -1.2122e+00, 8.7635e-01, 6.4722e-01, -1.3637e-01,\n", 380 | " 3.9300e-02, 2.3581e-01, -5.8410e-01, 1.1449e-01, 4.1370e-01,\n", 381 | " -1.4592e-01, -1.6116e-01, -1.3589e-02, 9.1764e-01, 7.5564e-01,\n", 382 | " 7.6356e-01, -2.6992e-02, 2.9534e-02, 4.7044e-02],\n", 383 | " [ 1.7485e-01, -2.6919e-02, 1.4817e-01, 1.4692e-01, -3.0932e-02,\n", 384 | " 1.1175e-01, 1.0811e-01, -4.7629e-02, -7.6028e-02, -3.7016e-01,\n", 385 | " -9.0677e-01, 3.1150e-02, 2.8712e-01, 1.1951e-01, -1.5880e-02,\n", 386 | " -9.3328e-03, -3.7832e-01, -1.0750e+00, -3.2539e-01, 2.2106e-01,\n", 387 | " 4.4113e-02, 2.3231e-03, 2.7959e-01, 1.7489e+00, -2.6656e-01,\n", 388 | " 6.3533e-01, 1.0301e-01, 5.6956e-02, -3.9100e-02, -1.6969e-03,\n", 389 | " 4.9899e-01, 1.1352e-01, 4.8352e-01, 1.4766e-01, -1.5708e-01,\n", 390 | " 1.0279e-01, -3.6780e-01, -1.1283e+00, -4.8725e-01, -2.9683e-01,\n", 391 | " -1.9176e-01, 3.0236e-02, 4.5318e-02, 1.0482e-02, -2.7236e-01,\n", 392 | " -3.4472e-01, -4.5023e-01, 1.4275e-02, 1.1712e-01],\n", 393 | " [-1.7610e-01, -2.2336e-01, -9.2900e-02, -2.0519e-01, -5.9762e-02,\n", 394 | " -1.9503e-01, -1.2590e-01, -2.9694e-01, -1.4031e-01, -4.4186e-01,\n", 395 | " -1.4337e-01, -6.6036e-02, 6.7138e-01, 8.2501e-02, -1.6842e-01,\n", 396 | " -1.7413e-01, 3.1763e-01, -2.0236e-02, -4.4402e-01, 6.3977e-01,\n", 397 | " 7.4786e-01, -4.0373e-02, -2.9013e-02, 8.0607e-01, 2.6007e-01,\n", 398 | " -6.2151e-01, -4.1412e-01, -2.9010e-01, -4.5549e-02, 5.5743e-04,\n", 399 | " -5.1935e-01, -1.1232e+00, -3.1870e-01, 1.6905e-01, 2.6397e-01,\n", 400 | " -1.2840e-01, -2.8200e-01, 3.7611e-01, 2.2851e-01, 3.6634e-01,\n", 401 | " 3.8307e-02, 2.2412e-02, -1.1541e-01, -2.9553e-02, 2.3179e-01,\n", 402 | " 6.6799e-01, -3.3803e-01, -4.4550e-02, -2.7134e-01],\n", 403 | " [ 1.3825e-01, 2.4101e-02, 2.3982e-01, 1.7226e-01, 2.6642e-01,\n", 404 | " 6.6927e-02, -4.3618e-03, 7.6473e-02, -4.7733e-03, 1.5782e-01,\n", 405 | " -6.1078e-01, -1.4987e-01, 2.8631e-02, 1.8857e-01, 1.1530e-01,\n", 406 | " -2.3441e-01, -8.0758e-02, -5.1254e-02, -1.3241e+00, -8.2284e-01,\n", 407 | " -1.2278e-01, 1.9420e-01, -1.9553e-02, 3.7280e-01, -1.6782e-01,\n", 408 | " 2.3166e-02, 3.7654e-01, 3.2446e-01, 1.6999e-01, 1.3028e-01,\n", 409 | " 9.9995e-01, 3.8728e-01, 3.5863e-01, 5.7643e-01, 5.1182e-02,\n", 410 | " -1.1926e-03, -2.9085e-03, 1.6399e-01, 5.5114e-01, 4.6883e-01,\n", 411 | " -2.6006e-01, -9.9516e-03, 1.8048e-01, 5.1141e-03, -3.0927e-01,\n", 412 | " -2.2117e-01, -1.8719e-02, 2.0913e-01, 2.0808e-01],\n", 413 | " [ 4.0271e-02, -9.0818e-02, 4.0012e-02, 6.0681e-02, 8.4752e-02,\n", 414 | " -2.5054e-02, 1.1475e-01, -8.8231e-02, 2.9376e-01, -6.5522e-02,\n", 415 | " -5.3953e-01, -5.8608e-01, -1.4010e-01, -1.2081e-01, -5.4410e-02,\n", 416 | " 2.5739e-01, 4.1932e-01, 4.0477e-01, 5.7905e-01, 4.7881e-01,\n", 417 | " -3.7136e-02, 1.4731e-01, -1.7544e-01, -6.6157e-01, -1.0747e+00,\n", 418 | " 4.5838e-01, 3.6600e-02, -1.1188e-01, 4.4505e-02, -7.8272e-02,\n", 419 | " -7.3755e-01, 4.0217e-01, 4.4128e-01, -1.5443e-01, 3.9592e-03,\n", 420 | " 5.5246e-02, -2.4772e-01, -6.1794e-01, -3.0037e-02, 4.6874e-02,\n", 421 | " -4.9648e-01, -3.5576e-03, 3.0834e-02, 1.4282e-01, -1.8585e-01,\n", 422 | " -8.7112e-02, 2.0395e-01, 1.3959e-01, -7.6560e-02],\n", 423 | " [ 1.2779e-01, 1.5208e-01, 4.0023e-02, 8.8653e-03, 
-3.5405e-02,\n", 424 | " 1.9084e-01, -3.9200e-02, -6.1964e-02, -1.7048e-01, -9.1533e-02,\n", 425 | " 4.8280e-01, 8.4810e-01, -4.5259e-01, 6.4811e-02, -1.9177e-02,\n", 426 | " 4.9078e-01, 6.9568e-01, -7.6542e-02, -3.4126e-01, 1.0035e+00,\n", 427 | " 7.9551e-01, -8.3518e-02, 1.1030e-02, -1.4257e-01, 2.3750e-01,\n", 428 | " 4.8162e-01, -1.3487e-02, -6.3713e-02, -6.7853e-02, -4.9875e-01,\n", 429 | " 3.0822e-01, 8.9436e-03, -2.3644e-01, -1.0793e-01, -1.5945e-01,\n", 430 | " 8.4852e-03, -4.3368e-01, -9.3560e-02, -2.8644e-01, 9.3931e-02,\n", 431 | " -4.5864e-01, -7.2238e-02, 1.8250e-01, -3.8495e-01, -2.3065e-01,\n", 432 | " 4.3112e-01, 2.0016e-01, 1.5119e-01, 2.7250e-02],\n", 433 | " [-6.7383e-02, -7.6093e-02, -3.0245e-02, -5.7215e-02, 2.5727e-02,\n", 434 | " -8.6782e-02, -5.4652e-02, -1.3550e-01, -1.2150e-01, -6.3947e-01,\n", 435 | " 2.3127e-01, -2.0515e-01, -4.5082e-01, -2.5826e-02, -3.2389e-02,\n", 436 | " 1.1141e-01, 2.4593e-02, 5.0594e-02, -1.9517e-01, 1.7442e-01,\n", 437 | " -2.8398e-01, 1.7936e-04, 5.1665e-01, 4.9924e-01, -5.1164e-01,\n", 438 | " 5.6993e-01, -8.9654e-02, -1.3680e-01, -1.2141e-01, -3.7746e-01,\n", 439 | " 5.3835e-01, -4.1166e-01, 4.2543e-01, -5.7350e-01, -2.1468e-02,\n", 440 | " 8.0114e-02, -2.0817e-01, -1.2528e+00, -2.6281e-01, -6.4290e-01,\n", 441 | " 2.4594e-01, -2.8212e-02, -5.7839e-02, -2.3948e-01, 6.6199e-01,\n", 442 | " 3.3856e-01, 9.2445e-01, 1.0376e-01, 1.2107e-02]])),\n", 443 | " ('fc.bias',\n", 444 | " tensor([-0.0784, -0.1995, -0.0360, 0.1387, 0.1450, 0.3877, -0.1174, 0.0136,\n", 445 | " -0.1624, 0.1925]))])" 446 | ] 447 | }, 448 | "metadata": {}, 449 | "execution_count": 74 450 | } 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "source": [ 456 | "#accuracy\n", 457 | "\n", 458 | "# Use a small subset of the full validation dataset\n", 459 | "from torch.utils.data import SubsetRandomSampler\n", 460 | "\n", 461 | "K = n_test # enter your length here\n", 462 | "subsample_train_indices = torch.randperm(len(val_set))[:K]\n", 463 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, sampler=SubsetRandomSampler(subsample_train_indices))\n", 464 | "\n", 465 | "correct = 0\n", 466 | "total = 0\n", 467 | "# Set the model to evaluation mode\n", 468 | "cnn.eval()\n", 469 | "with torch.inference_mode():\n", 470 | " for data in val_loader:\n", 471 | " images, labels = data\n", 472 | " outputs = cnn(images)\n", 473 | " _, predicted = torch.max(outputs.data, 1)\n", 474 | " total += labels.size(0)\n", 475 | " correct += (predicted == labels).sum().item()\n", 476 | "print(f'Accuracy on the validation set: {100 * correct / total:.2f}%')" 477 | ], 478 | "metadata": { 479 | "id": "EZCJdLCoqc10", 480 | "colab": { 481 | "base_uri": "https://localhost:8080/" 482 | }, 483 | "outputId": "78c296c0-4ccf-4066-be59-58f2375eaf3b" 484 | }, 485 | "execution_count": 81, 486 | "outputs": [ 487 | { 488 | "output_type": "stream", 489 | "name": "stdout", 490 | "text": [ 491 | "Accuracy on the validation set: 79.00%\n" 492 | ] 493 | } 494 | ] 495 | } 496 | ] 497 | } -------------------------------------------------------------------------------- /CNN4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | 
"#CNN4" 21 | ], 22 | "metadata": { 23 | "id": "b56c8rM5cN-U" 24 | } 25 | }, 26 | { 27 | "cell_type": "code", 28 | "source": [ 29 | "import numpy as np\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "\n", 32 | "from pathlib import Path\n", 33 | "\n", 34 | "import torch\n", 35 | "from torch.autograd import Function\n", 36 | "from torchvision import datasets, transforms\n", 37 | "import torch.optim as optim\n", 38 | "from torch.optim import lr_scheduler\n", 39 | "import torch.nn as nn\n", 40 | "import torch.nn.functional as F\n", 41 | "\n" 42 | ], 43 | "metadata": { 44 | "id": "-RM98MtEeTww" 45 | }, 46 | "execution_count": 153, 47 | "outputs": [] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "source": [ 52 | "# Loading of the MNIST dataset\n", 53 | "train_data = datasets.MNIST(\n", 54 | " root = 'data', # root: The root directory where the dataset should be stored. In this case, it is set to 'data'. If the 'data' directory doesn't exist, the dataset will be downloaded to this location.\n", 55 | " train = True,\n", 56 | " transform = transforms.ToTensor(), # transform: This parameter applies transformations to the data. In this case, transforms.ToTensor() is used to convert the images to PyTorch tensors.\n", 57 | " download = True,\n", 58 | ")\n", 59 | "test_data = datasets.MNIST(\n", 60 | " root = 'data',\n", 61 | " train = False,\n", 62 | " transform = transforms.ToTensor()\n", 63 | ")" 64 | ], 65 | "metadata": { 66 | "id": "_I845zHUeoXV" 67 | }, 68 | "execution_count": 154, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "source": [ 74 | "#Setting of the main hyper-parameters of the model\n", 75 | "batch_size = 4 # The number of samples in each mini-batch used during training. Smaller batch sizes can lead to faster convergence but may introduce more noise into the training process.\n", 76 | "n_train = batch_size * 125 # The total size of the training dataset. It's calculated as the product of batch_size and the number of batches (125 in this case). Adjusting the training dataset size can impact the model's ability to generalize.\n", 77 | "n_test = batch_size * 25 # The total size of the test dataset. Similar to n_train, it's calculated as the product of batch_size and the number of test batches (25 in this case). The test dataset is used to evaluate the model's performance on unseen data.\n", 78 | "n_channels = 4 # The number of channels in the output of the quantum convolution layer. In your model, you have set it to 4. This parameter determines the depth of the feature maps produced by the convolutional layer.\n", 79 | "initial_lr = 0.005 # The initial learning rate for the stochastic gradient descent (SGD) optimizer." 
80 | ], 81 | "metadata": { 82 | "id": "yFlOU3zGecUq" 83 | }, 84 | "execution_count": 155, 85 | "outputs": [] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "source": [ 90 | "class Net(nn.Module):\n", 91 | " def __init__(self):\n", 92 | " super(Net, self).__init__()\n", 93 | " # Convolutional layer 1 with 1 input channels, 4 output channels, and 4x4 kernel\n", 94 | " self.conv = nn.Conv2d(1, 4, 4, stride=4)\n", 95 | " self.fc = nn.Linear(4 * 7 * 7, 10)\n", 96 | "\n", 97 | " def forward(self, x):\n", 98 | " # Propagate the input through the CNN layers\n", 99 | " x = self.conv(x)\n", 100 | " # Flatten the output from the convolutional layer\n", 101 | " x = torch.flatten(x, start_dim=1)\n", 102 | " x = F.relu(self.fc(x))\n", 103 | " return x\n", 104 | "cnn=Net()" 105 | ], 106 | "metadata": { 107 | "id": "-3CmM5H-BOh5" 108 | }, 109 | "execution_count": 156, 110 | "outputs": [] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "source": [ 115 | "dataset = train_data\n", 116 | "train_size = n_train\n", 117 | "train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])\n", 118 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", 119 | "for data in train_loader:\n", 120 | " inputs, labels = data\n", 121 | " print(f\"{inputs.shape=}\")\n", 122 | " print(f\"{labels=}\")\n", 123 | " outputs = cnn(inputs)\n", 124 | " print(f\"{outputs.shape=}\")\n", 125 | " print(f\"{outputs=}\")\n", 126 | " break" 127 | ], 128 | "metadata": { 129 | "colab": { 130 | "base_uri": "https://localhost:8080/" 131 | }, 132 | "id": "QlHwR2hdRzb2", 133 | "outputId": "d4510d66-8f4f-4807-f13f-30d9c0402f10" 134 | }, 135 | "execution_count": 157, 136 | "outputs": [ 137 | { 138 | "output_type": "stream", 139 | "name": "stdout", 140 | "text": [ 141 | "inputs.shape=torch.Size([4, 1, 28, 28])\n", 142 | "labels=tensor([4, 0, 9, 2])\n", 143 | "outputs.shape=torch.Size([4, 10])\n", 144 | "outputs=tensor([[0.0000, 0.2931, 0.0000, 0.0509, 0.1659, 0.0686, 0.0000, 0.0630, 0.0587,\n", 145 | " 0.2090],\n", 146 | " [0.0681, 0.1966, 0.0000, 0.0923, 0.0608, 0.0000, 0.0000, 0.0133, 0.0989,\n", 147 | " 0.0934],\n", 148 | " [0.0000, 0.2260, 0.0000, 0.0000, 0.1664, 0.0000, 0.0000, 0.1725, 0.0000,\n", 149 | " 0.4029],\n", 150 | " [0.0000, 0.2944, 0.0000, 0.0733, 0.1128, 0.0000, 0.0000, 0.0473, 0.0000,\n", 151 | " 0.1634]], grad_fn=)\n" 152 | ] 153 | } 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "source": [ 159 | "# Train the model\n", 160 | "\n", 161 | "import datetime\n", 162 | "import os\n", 163 | "\n", 164 | "dataset = train_data\n", 165 | "\n", 166 | "# Initialize your QCNN model\n", 167 | "cnn = Net()\n", 168 | "\n", 169 | "# Define loss function and optimizer\n", 170 | "criterion = nn.CrossEntropyLoss() # Cross-entropy loss for classification\n", 171 | "optimizer = optim.SGD(cnn.parameters(), lr=initial_lr, momentum=0.90) # Stochastic Gradient Descent optimizer\n", 172 | "# Create a learning rate scheduler\n", 173 | "# Here, we use StepLR which reduces the learning rate by a factor every step_size epochs\n", 174 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.0)\n", 175 | "# Split your data into training and validation sets\n", 176 | "train_size = n_train\n", 177 | "train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])\n", 178 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", 179 | "val_loader = torch.utils.data.DataLoader(val_set, 
batch_size=batch_size, shuffle=False)\n", 180 | "\n", 181 | "MODEL_PATH = Path(\"models\")\n", 182 | "MODEL_PATH.mkdir(parents=True, exist_ok=True)\n", 183 | "\n", 184 | "MODEL_NAME = \"ImgClass-Quanvolv.pth\"\n", 185 | "MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n", 186 | "\n", 187 | "RESUME_TRAINING = True\n", 188 | "\n", 189 | "# Training loop\n", 190 | "num_epochs = 20\n", 191 | "loss_list = []\n", 192 | "cnn.train()\n", 193 | "\n", 194 | "if RESUME_TRAINING is False:\n", 195 | " print(f\"Restore model state from {MODEL_SAVE_PATH}\")\n", 196 | " if os.path.exists(MODEL_SAVE_PATH):\n", 197 | " model_dict = torch.load(MODEL_SAVE_PATH)\n", 198 | " initial_epoch = model_dict['epoch'] + 1\n", 199 | " cnn.load_state_dict(model_dict['model_state_dict'])\n", 200 | " optimizer.load_state_dict(model_dict['optimizer_state_dict'])\n", 201 | " loss_list = model_dict['loss'].copy()\n", 202 | " else:\n", 203 | " print(f\"No saved model state found. Training from scratch.\")\n", 204 | " initial_epoch = 0\n", 205 | " loss_list = []\n", 206 | "else:\n", 207 | " initial_epoch = 0\n", 208 | " loss_list = []\n", 209 | "\n", 210 | "for epoch in range(num_epochs):\n", 211 | " ct = datetime.datetime.now()\n", 212 | " # Decay Learning Rate\n", 213 | " optimizer.step()\n", 214 | " scheduler.step()\n", 215 | " lr = scheduler.get_last_lr()\n", 216 | " print(f\"{epoch=}, {lr=}, {ct}\")\n", 217 | " running_loss = []\n", 218 | " for i, data in enumerate(train_loader, 0):\n", 219 | " inputs, labels = data\n", 220 | " optimizer.zero_grad() # Zero the parameter gradients to avoid accumulation\n", 221 | " outputs = cnn(inputs) # Forward pass\n", 222 | " loss = criterion(outputs, labels) # Compute the loss\n", 223 | " loss.backward() # Backpropagation\n", 224 | " running_loss.append(loss.item())\n", 225 | " optimizer.step() # Update the model parameters\n", 226 | " loss_list.append(sum(running_loss) / len(running_loss))\n", 227 | " print('Training [{:.0f}%]\\tLoss: {:.4f}'.format(100. 
* (epoch + 1) / num_epochs, loss_list[-1]))\n", 228 | " torch.save({\n", 229 | " 'epoch': epoch,\n", 230 | " 'model_state_dict': cnn.state_dict(),\n", 231 | " 'optimizer_state_dict': optimizer.state_dict(),\n", 232 | " 'loss': loss_list,\n", 233 | " }, MODEL_SAVE_PATH)\n", 234 | " print(f\"Saving model state to {MODEL_SAVE_PATH}\")\n", 235 | "\n", 236 | "print('Finished Training')" 237 | ], 238 | "metadata": { 239 | "id": "RCxAXutdvoRg", 240 | "colab": { 241 | "base_uri": "https://localhost:8080/" 242 | }, 243 | "outputId": "4668e779-7c7e-4410-9b25-695de58a33b8" 244 | }, 245 | "execution_count": 158, 246 | "outputs": [ 247 | { 248 | "output_type": "stream", 249 | "name": "stdout", 250 | "text": [ 251 | "epoch=0, lr=[0.005], 2024-01-06 04:49:22.671176\n", 252 | "Training [5%]\tLoss: 2.2268\n", 253 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 254 | "epoch=1, lr=[0.005], 2024-01-06 04:49:22.895774\n", 255 | "Training [10%]\tLoss: 1.6212\n", 256 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 257 | "epoch=2, lr=[0.005], 2024-01-06 04:49:23.117166\n", 258 | "Training [15%]\tLoss: 1.2250\n", 259 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 260 | "epoch=3, lr=[0.005], 2024-01-06 04:49:23.357329\n", 261 | "Training [20%]\tLoss: 1.1190\n", 262 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 263 | "epoch=4, lr=[0.005], 2024-01-06 04:49:23.570795\n", 264 | "Training [25%]\tLoss: 1.0462\n", 265 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 266 | "epoch=5, lr=[0.005], 2024-01-06 04:49:23.784763\n", 267 | "Training [30%]\tLoss: 1.0016\n", 268 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 269 | "epoch=6, lr=[0.005], 2024-01-06 04:49:23.999354\n", 270 | "Training [35%]\tLoss: 0.9918\n", 271 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 272 | "epoch=7, lr=[0.005], 2024-01-06 04:49:24.208826\n", 273 | "Training [40%]\tLoss: 0.9384\n", 274 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 275 | "epoch=8, lr=[0.005], 2024-01-06 04:49:24.427452\n", 276 | "Training [45%]\tLoss: 0.9241\n", 277 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 278 | "epoch=9, lr=[0.005], 2024-01-06 04:49:24.643629\n", 279 | "Training [50%]\tLoss: 0.8962\n", 280 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 281 | "epoch=10, lr=[0.005], 2024-01-06 04:49:24.859728\n", 282 | "Training [55%]\tLoss: 0.8939\n", 283 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 284 | "epoch=11, lr=[0.005], 2024-01-06 04:49:25.069977\n", 285 | "Training [60%]\tLoss: 0.8646\n", 286 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 287 | "epoch=12, lr=[0.005], 2024-01-06 04:49:25.273394\n", 288 | "Training [65%]\tLoss: 0.8554\n", 289 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 290 | "epoch=13, lr=[0.005], 2024-01-06 04:49:25.497585\n", 291 | "Training [70%]\tLoss: 0.8444\n", 292 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 293 | "epoch=14, lr=[0.005], 2024-01-06 04:49:25.707545\n", 294 | "Training [75%]\tLoss: 0.8379\n", 295 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 296 | "epoch=15, lr=[0.005], 2024-01-06 04:49:25.921570\n", 297 | "Training [80%]\tLoss: 0.8275\n", 298 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 299 | "epoch=16, lr=[0.005], 2024-01-06 04:49:26.133493\n", 300 | "Training [85%]\tLoss: 0.8217\n", 301 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 302 | "epoch=17, lr=[0.005], 2024-01-06 04:49:26.350431\n", 303 | "Training [90%]\tLoss: 0.8122\n", 304 | 
"Saving model state to models/ImgClass-Quanvolv.pth\n", 305 | "epoch=18, lr=[0.005], 2024-01-06 04:49:26.589532\n", 306 | "Training [95%]\tLoss: 0.8084\n", 307 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 308 | "epoch=19, lr=[0.005], 2024-01-06 04:49:26.886995\n", 309 | "Training [100%]\tLoss: 0.8084\n", 310 | "Saving model state to models/ImgClass-Quanvolv.pth\n", 311 | "Finished Training\n" 312 | ] 313 | } 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "source": [ 319 | "cnn.state_dict()" 320 | ], 321 | "metadata": { 322 | "id": "nUl3J1rhtfvQ", 323 | "colab": { 324 | "base_uri": "https://localhost:8080/" 325 | }, 326 | "outputId": "be827968-f89b-4514-8e5a-d9efb0532f0a" 327 | }, 328 | "execution_count": 159, 329 | "outputs": [ 330 | { 331 | "output_type": "execute_result", 332 | "data": { 333 | "text/plain": [ 334 | "OrderedDict([('conv.weight',\n", 335 | " tensor([[[[-9.2388e-01, -5.8929e-01, -6.7031e-01, -5.2077e-01],\n", 336 | " [-1.0605e+00, -1.0203e+00, -8.6686e-01, -3.9097e-01],\n", 337 | " [-1.0201e+00, -6.6626e-01, -2.3081e-01, -3.9852e-01],\n", 338 | " [-8.1234e-01, -2.5749e-01, 8.2552e-04, 2.5906e-02]]],\n", 339 | " \n", 340 | " \n", 341 | " [[[-1.7290e-01, -1.0526e-01, 2.5084e-02, 5.1990e-01],\n", 342 | " [ 4.0717e-02, -6.1931e-02, 2.3308e-01, 1.1402e-01],\n", 343 | " [-1.9381e-01, -2.5186e-01, -1.3225e-01, -7.5126e-01],\n", 344 | " [-5.5205e-01, -6.6428e-01, -2.5852e-01, -9.1830e-01]]],\n", 345 | " \n", 346 | " \n", 347 | " [[[ 1.7420e-01, 4.1984e-01, 9.8764e-01, 5.3831e-01],\n", 348 | " [-1.9157e-02, -6.4329e-02, 3.8352e-01, 5.9392e-01],\n", 349 | " [ 6.6614e-02, 1.3925e-01, 8.1434e-01, 1.3035e+00],\n", 350 | " [-4.7035e-02, 5.9577e-01, 1.0783e+00, 1.1108e+00]]],\n", 351 | " \n", 352 | " \n", 353 | " [[[ 8.2520e-01, 6.2562e-01, 1.9157e-01, -2.6445e-01],\n", 354 | " [ 6.4359e-01, 5.0645e-01, -1.4111e-01, -4.5830e-01],\n", 355 | " [ 2.8926e-01, -2.6936e-02, -5.6172e-01, -6.2375e-01],\n", 356 | " [-4.8164e-01, -2.0072e-01, -6.4104e-01, -8.8802e-01]]]])),\n", 357 | " ('conv.bias', tensor([ 0.6154, 0.2659, -0.6529, -0.0022])),\n", 358 | " ('fc.weight',\n", 359 | " tensor([[-0.0908, -0.0592, 0.0221, ..., -0.0667, 0.0262, -0.0283],\n", 360 | " [-0.0855, -0.0359, -0.0147, ..., 0.0358, 0.0183, -0.0358],\n", 361 | " [ 0.0523, 0.0282, -0.0601, ..., -0.1361, -0.0186, -0.0759],\n", 362 | " ...,\n", 363 | " [-0.0512, 0.0172, 0.0214, ..., 0.0456, 0.0002, -0.0607],\n", 364 | " [-0.0275, -0.0169, 0.0243, ..., -0.0080, -0.0436, -0.0173],\n", 365 | " [ 0.0692, -0.0282, 0.0634, ..., 0.0649, -0.0513, 0.0403]])),\n", 366 | " ('fc.bias',\n", 367 | " tensor([-0.1888, 0.0599, 0.1772, -0.0362, -0.0030, 0.0218, -0.0206, 0.1077,\n", 368 | " -0.0287, -0.0747]))])" 369 | ] 370 | }, 371 | "metadata": {}, 372 | "execution_count": 159 373 | } 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "source": [ 379 | "#accuracy\n", 380 | "\n", 381 | "# Use a small subset of the full validation dataset\n", 382 | "from torch.utils.data import SubsetRandomSampler\n", 383 | "\n", 384 | "K = n_test # enter your length here\n", 385 | "subsample_train_indices = torch.randperm(len(val_set))[:K]\n", 386 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, sampler=SubsetRandomSampler(subsample_train_indices))\n", 387 | "\n", 388 | "correct = 0\n", 389 | "total = 0\n", 390 | "# Set the model to evaluation mode\n", 391 | "cnn.eval()\n", 392 | "with torch.inference_mode():\n", 393 | " for data in val_loader:\n", 394 | " images, labels = data\n", 395 | " outputs = 
cnn(images)\n", 396 | " _, predicted = torch.max(outputs.data, 1)\n", 397 | " total += labels.size(0)\n", 398 | " correct += (predicted == labels).sum().item()\n", 399 | "print(f'Accuracy on the validation set: {100 * correct / total:.2f}%')" 400 | ], 401 | "metadata": { 402 | "id": "EZCJdLCoqc10", 403 | "colab": { 404 | "base_uri": "https://localhost:8080/" 405 | }, 406 | "outputId": "25720c9e-cc37-484b-ccbf-fe25eba068c6" 407 | }, 408 | "execution_count": 160, 409 | "outputs": [ 410 | { 411 | "output_type": "stream", 412 | "name": "stdout", 413 | "text": [ 414 | "Accuracy on the validation set: 67.00%\n" 415 | ] 416 | } 417 | ] 418 | } 419 | ] 420 | } -------------------------------------------------------------------------------- /ImgClass_Classical.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "source": [ 21 | "The code in this notebook is a reimplementation of the classical computing version of the code described in the paper: https://arxiv.org/abs/2304.09224" 22 | ], 23 | "metadata": { 24 | "id": "QZDjxguivu6n" 25 | } 26 | }, 27 | { 28 | "cell_type": "code", 29 | "source": [ 30 | "!pip install torch torchvision\n", 31 | "!pip install pennylane" 32 | ], 33 | "metadata": { 34 | "colab": { 35 | "base_uri": "https://localhost:8080/" 36 | }, 37 | "id": "B-9mL7lXz0mo", 38 | "outputId": "726af243-d37e-48a5-a424-1f290408cfc4" 39 | }, 40 | "execution_count": 2, 41 | "outputs": [ 42 | { 43 | "output_type": "stream", 44 | "name": "stdout", 45 | "text": [ 46 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)\n", 47 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.16.0+cu121)\n", 48 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n", 49 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", 50 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", 51 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", 52 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", 53 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", 54 | "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n", 55 | "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.23.5)\n", 56 | "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision) (2.31.0)\n", 57 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n", 58 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", 59 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.3.2)\n", 60 | "Requirement already 
satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.6)\n", 61 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2.0.7)\n", 62 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2023.11.17)\n", 63 | "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", 64 | "Collecting pennylane\n", 65 | " Downloading PennyLane-0.33.1-py3-none-any.whl (1.5 MB)\n", 66 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 67 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.23.5)\n", 68 | "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.11.4)\n", 69 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from pennylane) (3.2.1)\n", 70 | "Collecting rustworkx (from pennylane)\n", 71 | " Downloading rustworkx-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", 72 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 73 | "\u001b[?25hRequirement already satisfied: autograd in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.6.2)\n", 74 | "Requirement already satisfied: toml in /usr/local/lib/python3.10/dist-packages (from pennylane) (0.10.2)\n", 75 | "Requirement already satisfied: appdirs in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.4.4)\n", 76 | "Collecting semantic-version>=2.7 (from pennylane)\n", 77 | " Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n", 78 | "Collecting autoray>=0.6.1 (from pennylane)\n", 79 | " Downloading autoray-0.6.7-py3-none-any.whl (49 kB)\n", 80 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.9/49.9 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 81 | "\u001b[?25hRequirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from pennylane) (5.3.2)\n", 82 | "Collecting pennylane-lightning>=0.33 (from pennylane)\n", 83 | " Downloading PennyLane_Lightning-0.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)\n", 84 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.0/14.0 MB\u001b[0m \u001b[31m30.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 85 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from pennylane) (2.31.0)\n", 86 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pennylane) (4.5.0)\n", 87 | "Requirement already satisfied: future>=0.15.2 in /usr/local/lib/python3.10/dist-packages (from autograd->pennylane) (0.18.3)\n", 88 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (3.3.2)\n", 89 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (3.6)\n", 90 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (2.0.7)\n", 91 | 
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (2023.11.17)\n", 92 | "Installing collected packages: semantic-version, rustworkx, autoray, pennylane-lightning, pennylane\n", 93 | "Successfully installed autoray-0.6.7 pennylane-0.33.1 pennylane-lightning-0.33.1 rustworkx-0.13.2 semantic-version-2.10.0\n" 94 | ] 95 | } 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "source": [ 101 | "import torch\n", 102 | "import torchvision\n", 103 | "from torchvision import transforms, datasets\n", 104 | "from torchvision.transforms import ToTensor\n", 105 | "import torch.optim as optim\n", 106 | "import torch.nn as nn\n", 107 | "import pennylane as qml" 108 | ], 109 | "metadata": { 110 | "id": "0VpMjoE3v-nE" 111 | }, 112 | "execution_count": 3, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "source": [ 118 | "# Data Preparation and Loading\n", 119 | "\n", 120 | "Use the MNIST dataset: https://en.wikipedia.org/wiki/CIFAR-10" 121 | ], 122 | "metadata": { 123 | "id": "FPX9br7jwbmH" 124 | } 125 | }, 126 | { 127 | "cell_type": "code", 128 | "source": [ 129 | "# Download and load MNIST dataset\n", 130 | "\n", 131 | "train_data = datasets.MNIST(\n", 132 | " root = 'data',\n", 133 | " train = True,\n", 134 | " transform = ToTensor(),\n", 135 | " download = True,\n", 136 | ")\n", 137 | "test_data = datasets.MNIST(\n", 138 | " root = 'data',\n", 139 | " train = False,\n", 140 | " transform = ToTensor()\n", 141 | ")" 142 | ], 143 | "metadata": { 144 | "id": "9Tc7SzbpwVgP", 145 | "colab": { 146 | "base_uri": "https://localhost:8080/" 147 | }, 148 | "outputId": "855ddb6a-f9c8-4acb-dc23-0612b538956e" 149 | }, 150 | "execution_count": 4, 151 | "outputs": [ 152 | { 153 | "output_type": "stream", 154 | "name": "stdout", 155 | "text": [ 156 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 157 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n" 158 | ] 159 | }, 160 | { 161 | "output_type": "stream", 162 | "name": "stderr", 163 | "text": [ 164 | "100%|██████████| 9912422/9912422 [00:00<00:00, 109135205.23it/s]\n" 165 | ] 166 | }, 167 | { 168 | "output_type": "stream", 169 | "name": "stdout", 170 | "text": [ 171 | "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n", 172 | "\n", 173 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 174 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n" 175 | ] 176 | }, 177 | { 178 | "output_type": "stream", 179 | "name": "stderr", 180 | "text": [ 181 | "100%|██████████| 28881/28881 [00:00<00:00, 22825644.21it/s]\n" 182 | ] 183 | }, 184 | { 185 | "output_type": "stream", 186 | "name": "stdout", 187 | "text": [ 188 | "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n", 189 | "\n", 190 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", 191 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" 192 | ] 193 | }, 194 | { 195 | "output_type": "stream", 196 | "name": "stderr", 197 | "text": [ 198 | "100%|██████████| 1648877/1648877 [00:00<00:00, 32306380.08it/s]\n" 199 | ] 200 | }, 201 | { 202 | "output_type": "stream", 203 | "name": "stdout", 204 | "text": [ 205 | "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n", 206 | "\n", 207 | 
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", 208 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" 209 | ] 210 | }, 211 | { 212 | "output_type": "stream", 213 | "name": "stderr", 214 | "text": [ 215 | "100%|██████████| 4542/4542 [00:00<00:00, 14848424.60it/s]\n" 216 | ] 217 | }, 218 | { 219 | "output_type": "stream", 220 | "name": "stdout", 221 | "text": [ 222 | "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n", 223 | "\n" 224 | ] 225 | } 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "source": [ 231 | "# Creating the CNN Architecture\n", 232 | "\n", 233 | "The CNN architecture is defined with two convolutional layers, max pooling layers, and fully connected layers. The forward method specifies how data flows through the network. This architecture is suitable for image classification tasks and can be modified to suit your specific project requirements." 234 | ], 235 | "metadata": { 236 | "id": "oRwUYemCyrYU" 237 | } 238 | }, 239 | { 240 | "cell_type": "code", 241 | "source": [ 242 | "import torch.nn as nn\n", 243 | "import torch.nn.functional as F\n", 244 | "\n", 245 | "# Define a simple CNN architecture\n", 246 | "class Net(nn.Module):\n", 247 | " def __init__(self):\n", 248 | " super(Net, self).__init__()\n", 249 | " # Convolutional layer 1 with 1 input channels (for greyscale images), 16 output channels, and 5x5 kernel\n", 250 | " self.conv1 = nn.Conv2d(1, 16, 5, stride=1, padding=2)\n", 251 | " # Batch normalization after convolutional layer 1\n", 252 | " self.bn1 = nn.BatchNorm2d(16)\n", 253 | " # Max pooling layer with a 2x2 window\n", 254 | " self.pool = nn.MaxPool2d(2, 2)\n", 255 | " # Convolutional layer 2 with 16 input channels (from the previous layer), 32 output channels, and 5x5 kernel\n", 256 | " self.conv2 = nn.Conv2d(16, 32, 5, stride=1, padding=2)\n", 257 | " # Batch normalization after convolutional layer 2\n", 258 | " self.bn2 = nn.BatchNorm2d(32)\n", 259 | " # Fully connected layers\n", 260 | " self.fc1 = nn.Linear(32 * 7 * 7, 120)\n", 261 | " self.fc2 = nn.Linear(120, 20)\n", 262 | " self.fc3 = nn.Linear(20, 10)\n", 263 | "\n", 264 | " def forward(self, x):\n", 265 | " # Propagate the input through the CNN layers\n", 266 | " x = self.pool(F.relu(self.bn1(self.conv1(x))))\n", 267 | " x = self.pool(F.relu(self.bn2(self.conv2(x))))\n", 268 | " # Flatten the output from the convolutional layers\n", 269 | " x = x.view(-1, 32 * 7 * 7)\n", 270 | " # Pass the output to the quantum layer\n", 271 | " x = F.relu(self.fc1(x))\n", 272 | " x = F.relu(self.fc2(x))\n", 273 | " x = self.fc3(x)\n", 274 | " return x" 275 | ], 276 | "metadata": { 277 | "id": "uhFmkPj-yvHw" 278 | }, 279 | "execution_count": 5, 280 | "outputs": [] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "source": [ 285 | "# Train the CNN\n", 286 | "\n", 287 | "The CNN model is initialised, the loss function and optimizer are set up, and data loaders for training and validation data are created. The training loop iterates through the dataset for a specified number of epochs, performing forward and backward passes to update the model’s parameters." 
288 | ], 289 | "metadata": { 290 | "id": "-7i8G2szzSxz" 291 | } 292 | }, 293 | { 294 | "cell_type": "code", 295 | "source": [ 296 | "dataset = train_data\n", 297 | "\n", 298 | "# Initialize your CNN model\n", 299 | "cnn = Net()\n", 300 | "# Define loss function and optimizer\n", 301 | "criterion = nn.CrossEntropyLoss() # Cross-entropy loss for classification\n", 302 | "optimizer = torch.optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9) # Stochastic Gradient Descent optimizer\n", 303 | "# Split your data into training and validation sets\n", 304 | "train_size = int(0.8 * len(dataset))\n", 305 | "train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])\n", 306 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)\n", 307 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=4, shuffle=False)\n", 308 | "# Training loop\n", 309 | "num_epochs = 10\n", 310 | "for epoch in range(num_epochs):\n", 311 | " running_loss = 0.0\n", 312 | " for i, data in enumerate(train_loader, 0):\n", 313 | " inputs, labels = data\n", 314 | " optimizer.zero_grad() # Zero the parameter gradients to avoid accumulation\n", 315 | " outputs = cnn(inputs) # Forward pass\n", 316 | " loss = criterion(outputs, labels) # Compute the loss\n", 317 | " loss.backward() # Backpropagation\n", 318 | " optimizer.step() # Update the model parameters\n", 319 | "print('Finished Training')" 320 | ], 321 | "metadata": { 322 | "id": "A8a6qGcOzV_1", 323 | "colab": { 324 | "base_uri": "https://localhost:8080/" 325 | }, 326 | "outputId": "f03210dc-a102-4f7a-847b-95f35dad499a" 327 | }, 328 | "execution_count": 6, 329 | "outputs": [ 330 | { 331 | "output_type": "stream", 332 | "name": "stdout", 333 | "text": [ 334 | "Finished Training\n" 335 | ] 336 | } 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "source": [ 342 | "# Evaluating the Model\n", 343 | "\n", 344 | "Set the model to evaluation mode, use it to make predictions on the validation dataset, and calculate the accuracy of the model." 
345 | ], 346 | "metadata": { 347 | "id": "6dAn7ec12RgV" 348 | } 349 | }, 350 | { 351 | "cell_type": "code", 352 | "source": [ 353 | "correct = 0\n", 354 | "total = 0\n", 355 | "# Set the model to evaluation mode\n", 356 | "cnn.eval()\n", 357 | "with torch.no_grad():\n", 358 | " for data in val_loader:\n", 359 | " images, labels = data\n", 360 | " outputs = cnn(images)\n", 361 | " _, predicted = torch.max(outputs.data, 1)\n", 362 | " total += labels.size(0)\n", 363 | " correct += (predicted == labels).sum().item()\n", 364 | "print(f'Accuracy on the validation set: {100 * correct / total:.2f}%')" 365 | ], 366 | "metadata": { 367 | "id": "F6eYQIjr2erR", 368 | "colab": { 369 | "base_uri": "https://localhost:8080/" 370 | }, 371 | "outputId": "a763b223-2eca-4a7c-f241-8ec5f947b04b" 372 | }, 373 | "execution_count": 7, 374 | "outputs": [ 375 | { 376 | "output_type": "stream", 377 | "name": "stdout", 378 | "text": [ 379 | "Accuracy on the validation set: 99.20%\n" 380 | ] 381 | } 382 | ] 383 | } 384 | ] 385 | } -------------------------------------------------------------------------------- /ImgClass_Hybrid.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "The code in this notebook is a reimplementation of the classical computing version of the code described in the paper: https://arxiv.org/abs/2304.09224" 23 | ], 24 | "metadata": { 25 | "id": "QZDjxguivu6n" 26 | } 27 | }, 28 | { 29 | "cell_type": "code", 30 | "source": [ 31 | "!pip install torch torchvision\n", 32 | "!pip install pennylane" 33 | ], 34 | "metadata": { 35 | "colab": { 36 | "base_uri": "https://localhost:8080/" 37 | }, 38 | "id": "B-9mL7lXz0mo", 39 | "outputId": "141835d1-fab1-45fa-85df-32f532ece3da" 40 | }, 41 | "execution_count": null, 42 | "outputs": [ 43 | { 44 | "output_type": "stream", 45 | "name": "stdout", 46 | "text": [ 47 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)\n", 48 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.16.0+cu121)\n", 49 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n", 50 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", 51 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", 52 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", 53 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", 54 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", 55 | "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n", 56 | "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.23.5)\n", 57 | "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision) (2.31.0)\n", 58 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 
in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n", 59 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", 60 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.3.2)\n", 61 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.6)\n", 62 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2.0.7)\n", 63 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2023.11.17)\n", 64 | "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", 65 | "Collecting pennylane\n", 66 | " Downloading PennyLane-0.33.1-py3-none-any.whl (1.5 MB)\n", 67 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 68 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.23.5)\n", 69 | "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.11.4)\n", 70 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from pennylane) (3.2.1)\n", 71 | "Collecting rustworkx (from pennylane)\n", 72 | " Downloading rustworkx-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", 73 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m54.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 74 | "\u001b[?25hRequirement already satisfied: autograd in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.6.2)\n", 75 | "Requirement already satisfied: toml in /usr/local/lib/python3.10/dist-packages (from pennylane) (0.10.2)\n", 76 | "Requirement already satisfied: appdirs in /usr/local/lib/python3.10/dist-packages (from pennylane) (1.4.4)\n", 77 | "Collecting semantic-version>=2.7 (from pennylane)\n", 78 | " Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n", 79 | "Collecting autoray>=0.6.1 (from pennylane)\n", 80 | " Downloading autoray-0.6.7-py3-none-any.whl (49 kB)\n", 81 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.9/49.9 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 82 | "\u001b[?25hRequirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from pennylane) (5.3.2)\n", 83 | "Collecting pennylane-lightning>=0.33 (from pennylane)\n", 84 | " Downloading PennyLane_Lightning-0.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)\n", 85 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.0/14.0 MB\u001b[0m \u001b[31m23.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 86 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from pennylane) (2.31.0)\n", 87 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pennylane) (4.5.0)\n", 88 | "Requirement already satisfied: future>=0.15.2 in /usr/local/lib/python3.10/dist-packages (from autograd->pennylane) (0.18.3)\n", 89 | "Requirement already satisfied: 
charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (3.3.2)\n", 90 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (3.6)\n", 91 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (2.0.7)\n", 92 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->pennylane) (2023.11.17)\n", 93 | "Installing collected packages: semantic-version, rustworkx, autoray, pennylane-lightning, pennylane\n", 94 | "Successfully installed autoray-0.6.7 pennylane-0.33.1 pennylane-lightning-0.33.1 rustworkx-0.13.2 semantic-version-2.10.0\n" 95 | ] 96 | } 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "source": [ 102 | "import torch\n", 103 | "import torchvision\n", 104 | "from torchvision import transforms, datasets\n", 105 | "from torchvision.transforms import ToTensor\n", 106 | "import torch.optim as optim\n", 107 | "import torch.nn as nn\n", 108 | "import pennylane as qml\n", 109 | "from pennylane import numpy as np\n", 110 | "\n", 111 | "use_cuda = torch.cuda.is_available()\n", 112 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")" 113 | ], 114 | "metadata": { 115 | "id": "0VpMjoE3v-nE" 116 | }, 117 | "execution_count": null, 118 | "outputs": [] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "source": [ 123 | "# Data Preparation and Loading\n", 124 | "\n", 125 | "Use the MNIST dataset of handwritten digits: https://en.wikipedia.org/wiki/MNIST_database" 126 | ], 127 | "metadata": { 128 | "id": "FPX9br7jwbmH" 129 | } 130 | }, 131 | { 132 | "cell_type": "code", 133 | "source": [ 134 | "# Download and load MNIST dataset\n", 135 | "\n", 136 | "train_data = datasets.MNIST(\n", 137 | " root = 'data',\n", 138 | " train = True,\n", 139 | " transform = ToTensor(),\n", 140 | " download = True,\n", 141 | ")\n", 142 | "test_data = datasets.MNIST(\n", 143 | " root = 'data',\n", 144 | " train = False,\n", 145 | " transform = ToTensor()\n", 146 | ")" 147 | ], 148 | "metadata": { 149 | "id": "9Tc7SzbpwVgP", 150 | "colab": { 151 | "base_uri": "https://localhost:8080/" 152 | }, 153 | "outputId": "3d7ba899-b976-4668-82f7-4990f2cb25df" 154 | }, 155 | "execution_count": null, 156 | "outputs": [ 157 | { 158 | "output_type": "stream", 159 | "name": "stdout", 160 | "text": [ 161 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 162 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n" 163 | ] 164 | }, 165 | { 166 | "output_type": "stream", 167 | "name": "stderr", 168 | "text": [ 169 | "100%|██████████| 9912422/9912422 [00:00<00:00, 374620081.31it/s]" 170 | ] 171 | }, 172 | { 173 | "output_type": "stream", 174 | "name": "stdout", 175 | "text": [ 176 | "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n" 177 | ] 178 | }, 179 | { 180 | "output_type": "stream", 181 | "name": "stderr", 182 | "text": [ 183 | "\n" 184 | ] 185 | }, 186 | { 187 | "output_type": "stream", 188 | "name": "stdout", 189 | "text": [ 190 | "\n", 191 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 192 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n" 193 | ] 194 | }, 195 | { 196 | "output_type": "stream", 197 | "name": "stderr", 198 | "text": [ 199 | "100%|██████████| 28881/28881 [00:00<00:00, 99047991.68it/s]\n" 200 |
] 201 | }, 202 | { 203 | "output_type": "stream", 204 | "name": "stdout", 205 | "text": [ 206 | "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n", 207 | "\n", 208 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", 209 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" 210 | ] 211 | }, 212 | { 213 | "output_type": "stream", 214 | "name": "stderr", 215 | "text": [ 216 | "100%|██████████| 1648877/1648877 [00:00<00:00, 165345145.40it/s]\n" 217 | ] 218 | }, 219 | { 220 | "output_type": "stream", 221 | "name": "stdout", 222 | "text": [ 223 | "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n", 224 | "\n", 225 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", 226 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" 227 | ] 228 | }, 229 | { 230 | "output_type": "stream", 231 | "name": "stderr", 232 | "text": [ 233 | "100%|██████████| 4542/4542 [00:00<00:00, 21897159.50it/s]" 234 | ] 235 | }, 236 | { 237 | "output_type": "stream", 238 | "name": "stdout", 239 | "text": [ 240 | "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n", 241 | "\n" 242 | ] 243 | }, 244 | { 245 | "output_type": "stream", 246 | "name": "stderr", 247 | "text": [ 248 | "\n" 249 | ] 250 | } 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "source": [ 256 | "# Creating the CNN Architecture\n", 257 | "\n", 258 | "The hybrid architecture is defined with two convolutional layers, max pooling layers, parallel quantum dense layers, and fully connected layers. The forward method specifies how data flows through the network. This architecture is suitable for image classification tasks and can be modified to suit your specific project requirements."
259 | ], 260 | "metadata": { 261 | "id": "oRwUYemCyrYU" 262 | } 263 | }, 264 | { 265 | "cell_type": "code", 266 | "source": [ 267 | "import torch.nn as nn\n", 268 | "import torch.nn.functional as F\n", 269 | "import pennylane as qml\n", 270 | "\n", 271 | "# Define the quantum circuit using PennyLane\n", 272 | "n_qubits = 5\n", 273 | "dev = qml.device(\"default.qubit\", wires=n_qubits)\n", 274 | "\n", 275 | "@qml.qnode(dev)\n", 276 | "def qnode(inputs, weights):\n", 277 | " qml.AngleEmbedding(inputs, wires=range(n_qubits))\n", 278 | " qml.BasicEntanglerLayers(weights, wires=range(n_qubits))\n", 279 | " return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]\n", 280 | "\n", 281 | "# Define the QLayer\n", 282 | "n_layers = 3\n", 283 | "weight_shapes = {\"weights\": (n_layers, n_qubits)}\n", 284 | "\n", 285 | "\n", 286 | "# Define a simple CNN architecture\n", 287 | "class Net(nn.Module):\n", 288 | " def __init__(self):\n", 289 | " super(Net, self).__init__()\n", 290 | " # Convolutional layer 1 with 1 input channels (for greyscale images), 16 output channels, and 5x5 kernel\n", 291 | " self.conv1 = nn.Conv2d(1, 16, 5, stride=1, padding=2)\n", 292 | " # Batch normalization after convolutional layer 1\n", 293 | " self.bn1 = nn.BatchNorm2d(16)\n", 294 | " # Max pooling layer with a 2x2 window\n", 295 | " self.pool = nn.MaxPool2d(2, 2)\n", 296 | " # Convolutional layer 2 with 16 input channels (from the previous layer), 32 output channels, and 5x5 kernel\n", 297 | " self.conv2 = nn.Conv2d(16, 32, 5, stride=1, padding=2)\n", 298 | " # Batch normalization after convolutional layer 2\n", 299 | " self.bn2 = nn.BatchNorm2d(32)\n", 300 | " # Quantum layer\n", 301 | " self.qlayer1 = qml.qnn.TorchLayer(qnode, weight_shapes)\n", 302 | " self.qlayer2 = qml.qnn.TorchLayer(qnode, weight_shapes)\n", 303 | " self.qlayer3 = qml.qnn.TorchLayer(qnode, weight_shapes)\n", 304 | " self.qlayer4 = qml.qnn.TorchLayer(qnode, weight_shapes)\n", 305 | " # Fully connected layers\n", 306 | " self.fc1 = nn.Linear(32 * 7 * 7, 120)\n", 307 | " self.fc2 = nn.Linear(120, 20)\n", 308 | " self.fc3 = nn.Linear(20, 10)\n", 309 | "\n", 310 | " def forward(self, x):\n", 311 | " # Propagate the input through the CNN layers\n", 312 | " x = self.pool(F.relu(self.bn1(self.conv1(x))))\n", 313 | " x = self.pool(F.relu(self.bn2(self.conv2(x))))\n", 314 | " # Flatten the output from the convolutional layers\n", 315 | " x = x.view(-1, 32 * 7 * 7)\n", 316 | " # Pass the output to the quantum layer\n", 317 | " x = F.relu(self.fc1(x))\n", 318 | " x = F.relu(self.fc2(x))\n", 319 | " x_1, x_2, x_3, x_4 = torch.split(x, 5, dim=1)\n", 320 | " x_1 = self.qlayer1(x_1)\n", 321 | " x_2 = self.qlayer2(x_2)\n", 322 | " x_3 = self.qlayer3(x_3)\n", 323 | " x_4 = self.qlayer4(x_4)\n", 324 | " x = torch.cat([x_1, x_2, x_3, x_4], axis=1)\n", 325 | " x = self.fc3(x)\n", 326 | " return x" 327 | ], 328 | "metadata": { 329 | "id": "uhFmkPj-yvHw" 330 | }, 331 | "execution_count": null, 332 | "outputs": [] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "source": [ 337 | "# Train the CNN\n", 338 | "\n", 339 | "The CNN model is initialised, the loss function and optimizer are set up, and data loaders for training and validation data are created. The training loop iterates through the dataset for a specified number of epochs, performing forward and backward passes to update the model’s parameters." 
340 | ], 341 | "metadata": { 342 | "id": "-7i8G2szzSxz" 343 | } 344 | }, 345 | { 346 | "cell_type": "code", 347 | "source": [ 348 | "import datetime\n", 349 | "\n", 350 | "dataset = train_data\n", 351 | "\n", 352 | "# Initialize your CNN model\n", 353 | "cnn = Net()\n", 354 | "\n", 355 | "# Define loss function and optimizer\n", 356 | "criterion = nn.CrossEntropyLoss() # Cross-entropy loss for classification\n", 357 | "optimizer = torch.optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9) # Stochastic Gradient Descent optimizer\n", 358 | "# Split your data into training and validation sets\n", 359 | "train_size = int(0.8 * len(dataset))\n", 360 | "train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])\n", 361 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)\n", 362 | "val_loader = torch.utils.data.DataLoader(val_set, batch_size=4, shuffle=False) # needed by the evaluation cell below\n", 363 | "# Training loop\n", 364 | "num_epochs = 10\n", 365 | "for epoch in range(num_epochs):\n", 366 | " ct = datetime.datetime.now()\n", 367 | " print(f\"{epoch=}, {ct}\")\n", 368 | " running_loss = 0.0\n", 369 | " for i, data in enumerate(train_loader, 0):\n", 370 | " inputs, labels = data\n", 371 | " optimizer.zero_grad() # Zero the parameter gradients to avoid accumulation\n", 372 | " outputs = cnn(inputs) # Forward pass\n", 373 | " loss = criterion(outputs, labels) # Compute the loss\n", 374 | " loss.backward() # Backpropagation\n", 375 | " optimizer.step() # Update the model parameters\n", 376 | "print('Finished Training')" 377 | ], 378 | "metadata": { 379 | "id": "A8a6qGcOzV_1" 380 | }, 381 | "execution_count": null, 382 | "outputs": [] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "source": [ 387 | "# Evaluating the Model\n", 388 | "\n", 389 | "Set the model to evaluation mode, use it to make predictions on the validation dataset, and calculate the accuracy of the model." 390 | ], 391 | "metadata": { 392 | "id": "6dAn7ec12RgV" 393 | } 394 | }, 395 | { 396 | "cell_type": "code", 397 | "source": [ 398 | "correct = 0\n", 399 | "total = 0\n", 400 | "# Set the model to evaluation mode\n", 401 | "cnn.eval()\n", 402 | "with torch.no_grad():\n", 403 | " for data in val_loader:\n", 404 | " images, labels = data\n", 405 | " outputs = cnn(images)\n", 406 | " _, predicted = torch.max(outputs.data, 1)\n", 407 | " total += labels.size(0)\n", 408 | " correct += (predicted == labels).sum().item()\n", 409 | "print(f'Accuracy on the validation set: {100 * correct / total:.2f}%')" 410 | ], 411 | "metadata": { 412 | "id": "F6eYQIjr2erR" 413 | }, 414 | "execution_count": null, 415 | "outputs": [] 416 | } 417 | ] 418 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Hybrid quantum-classical models for image classification 2 | 3 | The Python code implements the hybrid quantum-classical models from the paper "Quantum machine learning for image classification" by Arsenii Senokosov et al. (https://arxiv.org/pdf/2304.09224.pdf). 4 | 5 | The paper proposes two hybrid quantum-classical models for image classification: 6 | 7 | 1. **Hybrid Quantum Neural Network with parallel quantum dense layers, HQNN-Parallel**: 8 | 9 | HQNN-Parallel is a hybrid quantum-classical model that utilizes multiple parallel quantum dense layers for image classification tasks.
The classical convolutional block reduces the dimensionality of the input image, and the parallel quantum dense layers extract and process features from the reduced representation. 10 | 11 | The model is evaluated on the MNIST dataset of handwritten digits and compared with the performance of classical convolutional neural networks (CNNs) with similar architectures. 12 | 13 | The classical convolutional neural network is implemented in `` ImgClass_Classical.ipynb `` and the HQNN-Parallel is implemented in `` ImgClass_Hybrid.ipynb ``. 14 | 15 | **Result**: The classical convolutional neural network gives an accuracy of 99.20% and the HQNN-Parallel (with ``n_layers = 1`` instead of ``n_layers = 3`` as used in the paper) gives an accuracy of 99.17%. 16 | 17 | 18 | 2. **Hybrid Quantum Neural Network with quanvolutional layer, HQNN-Quanv**: 19 | 20 | HQNN-Quanv is a hybrid quantum-classical model that combines a quanvolutional layer with classical fully connected layers to address the image classification task. The quanvolutional layer utilizes quantum mechanics to extract features from the input image, while the classical fully connected layers process and classify these features. 21 | 22 | The model is evaluated on the MNIST dataset of handwritten digits and compared with the performance of classical convolutional neural networks with similar architectures, in particular CNN1 (a convolutional kernel with 1 input channel and 1 output channel) and CNN4 (a convolutional kernel with 1 input channel and 4 output channels). 4 x 4 kernels are used. 23 | 24 | The classical convolutional neural networks are implemented in `` CNN1.ipynb `` and `` CNN4.ipynb ``, and the HQNN-Quanv is implemented in `` HQNN-Quanv.ipynb ``. 25 | 26 | **Result**: CNN1 gives an accuracy of 79%, CNN4 gives an accuracy of 67%, and the HQNN-Quanv gives an accuracy of 70% (see ``dataset_indices_500.pt``). 27 | 28 | Note that the HQNN-Quanv model uses ``initial_lr = 0.003``, and its loss function seems to be very sensitive to small changes (+/- 0.0005) in ``initial_lr``, which significantly impacts the accuracy (for example, it drops to 58% for ``initial_lr = 0.0035`` and to 45% for ``initial_lr = 0.0025``). See ``hqcnn quanv output.txt`` for sample outputs from HQNN-Quanv training runs.
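
For reference, below is a condensed sketch of the parallel quantum dense block that `` ImgClass_Hybrid.ipynb `` builds with PennyLane's ``qml.qnn.TorchLayer``. It is not a verbatim copy of the notebook: the wrapper class name ``ParallelQuantumDense`` is introduced here purely for illustration; in the notebook the four quantum layers are plain attributes of the main ``Net`` module, placed between the 20-unit fully connected layer and the final 10-way classifier.

```python
import torch
import torch.nn as nn
import pennylane as qml

n_qubits = 5   # each parallel quantum dense layer acts on 5 features
n_layers = 1   # entangling layers per circuit (the paper uses 3)
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def qnode(inputs, weights):
    # Encode 5 classical features as rotation angles, entangle them,
    # and read out one Pauli-Z expectation value per qubit
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

weight_shapes = {"weights": (n_layers, n_qubits)}

class ParallelQuantumDense(nn.Module):
    """Split a 20-dim feature vector into four 5-dim chunks, run each chunk
    through its own quantum circuit, and concatenate the four outputs."""
    def __init__(self):
        super().__init__()
        self.qlayers = nn.ModuleList(
            [qml.qnn.TorchLayer(qnode, weight_shapes) for _ in range(4)]
        )

    def forward(self, x):
        chunks = torch.split(x, n_qubits, dim=1)
        return torch.cat([q(c) for q, c in zip(self.qlayers, chunks)], dim=1)
```

Because the circuits are simulated on ``default.qubit``, each training epoch of the hybrid model can be expected to take longer than for the purely classical CNN of comparable size.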
29 | -------------------------------------------------------------------------------- /dataset_indices_500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AishSweety/hybrid-quantum-classical-models-for-image-classification/1c6bd31a02600f1a8a09d888004bd77a21b0ce19/dataset_indices_500.pt -------------------------------------------------------------------------------- /hqcnn quanv output.txt: -------------------------------------------------------------------------------- 1 | n_epochs=40, n_train=500, initial_lr=0.0025, gamma=1.0 2 | momentum=0.9, nesterov=False 3 | 4 | epoch=0, lr=[0.0025], 2024-01-06 21:48:36.488483 5 | Training [2%] Loss: 2.2788 6 | Saving model state to models/ImgClass-Quanvolv.pth 7 | Accuracy on the validation set: 17.00% 8 | epoch=1, lr=[0.0025], 2024-01-06 21:50:37.598273 9 | Training [5%] Loss: 2.2062 10 | Saving model state to models/ImgClass-Quanvolv.pth 11 | Accuracy on the validation set: 23.00% 12 | epoch=2, lr=[0.0025], 2024-01-06 21:52:33.046508 13 | Training [8%] Loss: 2.1370 14 | Saving model state to models/ImgClass-Quanvolv.pth 15 | Accuracy on the validation set: 30.00% 16 | epoch=3, lr=[0.0025], 2024-01-06 21:54:28.273496 17 | Training [10%] Loss: 2.0798 18 | Saving model state to models/ImgClass-Quanvolv.pth 19 | Accuracy on the validation set: 34.00% 20 | epoch=4, lr=[0.0025], 2024-01-06 21:56:23.834537 21 | Training [12%] Loss: 2.0313 22 | Saving model state to models/ImgClass-Quanvolv.pth 23 | Accuracy on the validation set: 39.00% 24 | epoch=5, lr=[0.0025], 2024-01-06 21:58:19.401408 25 | Training [15%] Loss: 1.9895 26 | Saving model state to models/ImgClass-Quanvolv.pth 27 | Accuracy on the validation set: 39.00% 28 | epoch=6, lr=[0.0025], 2024-01-06 22:00:14.385288 29 | Training [18%] Loss: 1.9529 30 | Saving model state to models/ImgClass-Quanvolv.pth 31 | Accuracy on the validation set: 39.00% 32 | epoch=7, lr=[0.0025], 2024-01-06 22:02:10.724110 33 | Training [20%] Loss: 1.9209 34 | Saving model state to models/ImgClass-Quanvolv.pth 35 | Accuracy on the validation set: 41.00% 36 | epoch=8, lr=[0.0025], 2024-01-06 22:04:15.538323 37 | Training [22%] Loss: 1.8924 38 | Saving model state to models/ImgClass-Quanvolv.pth 39 | Accuracy on the validation set: 41.00% 40 | epoch=9, lr=[0.0025], 2024-01-06 22:06:20.006308 41 | Training [25%] Loss: 1.8671 42 | Saving model state to models/ImgClass-Quanvolv.pth 43 | Accuracy on the validation set: 41.00% 44 | epoch=10, lr=[0.0025], 2024-01-06 22:08:24.668136 45 | Training [28%] Loss: 1.8443 46 | Saving model state to models/ImgClass-Quanvolv.pth 47 | Accuracy on the validation set: 42.00% 48 | epoch=11, lr=[0.0025], 2024-01-06 22:10:28.868716 49 | Training [30%] Loss: 1.8237 50 | Saving model state to models/ImgClass-Quanvolv.pth 51 | Accuracy on the validation set: 43.00% 52 | epoch=12, lr=[0.0025], 2024-01-06 22:12:33.583416 53 | Training [32%] Loss: 1.8049 54 | Saving model state to models/ImgClass-Quanvolv.pth 55 | Accuracy on the validation set: 44.00% 56 | epoch=13, lr=[0.0025], 2024-01-06 22:14:36.732528 57 | Training [35%] Loss: 1.7879 58 | Saving model state to models/ImgClass-Quanvolv.pth 59 | Accuracy on the validation set: 43.00% 60 | epoch=14, lr=[0.0025], 2024-01-06 22:16:41.035188 61 | Training [38%] Loss: 1.7722 62 | Saving model state to models/ImgClass-Quanvolv.pth 63 | Accuracy on the validation set: 44.00% 64 | epoch=15, lr=[0.0025], 2024-01-06 22:18:44.407122 65 | Training [40%] Loss: 1.7578 66 | Saving model state 
to models/ImgClass-Quanvolv.pth 67 | Accuracy on the validation set: 44.00% 68 | epoch=16, lr=[0.0025], 2024-01-06 22:20:47.805007 69 | Training [42%] Loss: 1.7446 70 | Saving model state to models/ImgClass-Quanvolv.pth 71 | Accuracy on the validation set: 44.00% 72 | epoch=17, lr=[0.0025], 2024-01-06 22:22:52.118397 73 | Training [45%] Loss: 1.7322 74 | Saving model state to models/ImgClass-Quanvolv.pth 75 | Accuracy on the validation set: 44.00% 76 | epoch=18, lr=[0.0025], 2024-01-06 22:24:53.781146 77 | Training [48%] Loss: 1.7208 78 | Saving model state to models/ImgClass-Quanvolv.pth 79 | Accuracy on the validation set: 43.00% 80 | epoch=19, lr=[0.0025], 2024-01-06 22:26:53.749506 81 | Training [50%] Loss: 1.7100 82 | Saving model state to models/ImgClass-Quanvolv.pth 83 | Accuracy on the validation set: 45.00% 84 | epoch=20, lr=[0.0025], 2024-01-06 22:28:57.504474 85 | Training [52%] Loss: 1.6998 86 | Saving model state to models/ImgClass-Quanvolv.pth 87 | Accuracy on the validation set: 44.00% 88 | epoch=21, lr=[0.0025], 2024-01-06 22:31:02.007135 89 | Training [55%] Loss: 1.6904 90 | Saving model state to models/ImgClass-Quanvolv.pth 91 | Accuracy on the validation set: 45.00% 92 | epoch=22, lr=[0.0025], 2024-01-06 22:33:06.742635 93 | 94 | ------------------------------------------- 95 | 96 | n_epochs=40, n_train=500, initial_lr=0.0035, gamma=1.0 97 | momentum=0.9, nesterov=False 98 | 99 | epoch=0, lr=[0.004], 2024-01-06 20:25:42.576554 100 | Training [2%] Loss: 2.2541 101 | Saving model state to models/ImgClass-Quanvolv.pth 102 | Accuracy on the validation set: 31.00% 103 | epoch=1, lr=[0.004], 2024-01-06 20:27:38.515313 104 | Training [5%] Loss: 2.1345 105 | Saving model state to models/ImgClass-Quanvolv.pth 106 | Accuracy on the validation set: 31.00% 107 | epoch=2, lr=[0.004], 2024-01-06 20:29:29.783364 108 | Training [8%] Loss: 2.0354 109 | Saving model state to models/ImgClass-Quanvolv.pth 110 | Accuracy on the validation set: 37.00% 111 | epoch=3, lr=[0.004], 2024-01-06 20:31:19.877972 112 | Training [10%] Loss: 1.9549 113 | Saving model state to models/ImgClass-Quanvolv.pth 114 | Accuracy on the validation set: 39.00% 115 | epoch=4, lr=[0.004], 2024-01-06 20:33:10.711311 116 | Training [12%] Loss: 1.8882 117 | Saving model state to models/ImgClass-Quanvolv.pth 118 | Accuracy on the validation set: 40.00% 119 | epoch=5, lr=[0.004], 2024-01-06 20:35:00.742746 120 | Training [15%] Loss: 1.8324 121 | Saving model state to models/ImgClass-Quanvolv.pth 122 | Accuracy on the validation set: 42.00% 123 | epoch=6, lr=[0.004], 2024-01-06 20:36:51.431967 124 | Training [18%] Loss: 1.7844 125 | Saving model state to models/ImgClass-Quanvolv.pth 126 | Accuracy on the validation set: 43.00% 127 | epoch=7, lr=[0.004], 2024-01-06 20:38:42.039556 128 | Training [20%] Loss: 1.7429 129 | Saving model state to models/ImgClass-Quanvolv.pth 130 | Accuracy on the validation set: 46.00% 131 | epoch=8, lr=[0.004], 2024-01-06 20:40:33.206508 132 | Training [22%] Loss: 1.7068 133 | Saving model state to models/ImgClass-Quanvolv.pth 134 | Accuracy on the validation set: 46.00% 135 | epoch=9, lr=[0.004], 2024-01-06 20:42:24.363130 136 | Training [25%] Loss: 1.6749 137 | Saving model state to models/ImgClass-Quanvolv.pth 138 | Accuracy on the validation set: 48.00% 139 | epoch=10, lr=[0.004], 2024-01-06 20:44:15.559977 140 | Training [28%] Loss: 1.6463 141 | Saving model state to models/ImgClass-Quanvolv.pth 142 | Accuracy on the validation set: 49.00% 143 | epoch=11, lr=[0.004], 2024-01-06 
20:46:07.154369 144 | Training [30%] Loss: 1.6208 145 | Saving model state to models/ImgClass-Quanvolv.pth 146 | Accuracy on the validation set: 51.00% 147 | epoch=12, lr=[0.004], 2024-01-06 20:47:58.180498 148 | Training [32%] Loss: 1.5975 149 | Saving model state to models/ImgClass-Quanvolv.pth 150 | Accuracy on the validation set: 51.00% 151 | epoch=13, lr=[0.004], 2024-01-06 20:49:49.662972 152 | Training [35%] Loss: 1.5763 153 | Saving model state to models/ImgClass-Quanvolv.pth 154 | Accuracy on the validation set: 51.00% 155 | epoch=14, lr=[0.004], 2024-01-06 20:51:41.539777 156 | Training [38%] Loss: 1.5571 157 | Saving model state to models/ImgClass-Quanvolv.pth 158 | Accuracy on the validation set: 51.00% 159 | epoch=15, lr=[0.004], 2024-01-06 20:53:32.104577 160 | Training [40%] Loss: 1.5392 161 | Saving model state to models/ImgClass-Quanvolv.pth 162 | Accuracy on the validation set: 51.00% 163 | epoch=16, lr=[0.004], 2024-01-06 20:55:24.053864 164 | Training [42%] Loss: 1.5228 165 | Saving model state to models/ImgClass-Quanvolv.pth 166 | Accuracy on the validation set: 51.00% 167 | epoch=17, lr=[0.004], 2024-01-06 20:57:16.142895 168 | Training [45%] Loss: 1.5075 169 | Saving model state to models/ImgClass-Quanvolv.pth 170 | Accuracy on the validation set: 52.00% 171 | epoch=18, lr=[0.004], 2024-01-06 20:59:06.873999 172 | Training [48%] Loss: 1.4933 173 | Saving model state to models/ImgClass-Quanvolv.pth 174 | Accuracy on the validation set: 52.00% 175 | epoch=19, lr=[0.004], 2024-01-06 21:00:58.734462 176 | Training [50%] Loss: 1.4802 177 | Saving model state to models/ImgClass-Quanvolv.pth 178 | Accuracy on the validation set: 52.00% 179 | epoch=20, lr=[0.004], 2024-01-06 21:02:51.020655 180 | Training [52%] Loss: 1.4679 181 | Saving model state to models/ImgClass-Quanvolv.pth 182 | Accuracy on the validation set: 52.00% 183 | epoch=21, lr=[0.004], 2024-01-06 21:04:43.478853 184 | Training [55%] Loss: 1.4562 185 | Saving model state to models/ImgClass-Quanvolv.pth 186 | Accuracy on the validation set: 53.00% 187 | epoch=22, lr=[0.004], 2024-01-06 21:06:35.138327 188 | Training [58%] Loss: 1.4453 189 | Saving model state to models/ImgClass-Quanvolv.pth 190 | Accuracy on the validation set: 54.00% 191 | epoch=23, lr=[0.004], 2024-01-06 21:08:25.864664 192 | Training [60%] Loss: 1.4349 193 | Saving model state to models/ImgClass-Quanvolv.pth 194 | Accuracy on the validation set: 54.00% 195 | epoch=24, lr=[0.004], 2024-01-06 21:10:16.841934 196 | Training [62%] Loss: 1.4254 197 | Saving model state to models/ImgClass-Quanvolv.pth 198 | Accuracy on the validation set: 54.00% 199 | epoch=25, lr=[0.004], 2024-01-06 21:12:09.015262 200 | Training [65%] Loss: 1.4161 201 | Saving model state to models/ImgClass-Quanvolv.pth 202 | Accuracy on the validation set: 54.00% 203 | epoch=26, lr=[0.004], 2024-01-06 21:14:00.441242 204 | Training [68%] Loss: 1.4075 205 | Saving model state to models/ImgClass-Quanvolv.pth 206 | Accuracy on the validation set: 55.00% 207 | epoch=27, lr=[0.004], 2024-01-06 21:15:53.875285 208 | Training [70%] Loss: 1.3992 209 | Saving model state to models/ImgClass-Quanvolv.pth 210 | Accuracy on the validation set: 57.00% 211 | epoch=28, lr=[0.004], 2024-01-06 21:17:51.124039 212 | Training [72%] Loss: 1.3914 213 | Saving model state to models/ImgClass-Quanvolv.pth 214 | Accuracy on the validation set: 57.00% 215 | epoch=29, lr=[0.004], 2024-01-06 21:19:47.453997 216 | Training [75%] Loss: 1.3836 217 | Saving model state to models/ImgClass-Quanvolv.pth 218 | 
Accuracy on the validation set: 57.00% 219 | epoch=30, lr=[0.004], 2024-01-06 21:21:46.837950 220 | Training [78%] Loss: 1.3763 221 | Saving model state to models/ImgClass-Quanvolv.pth 222 | Accuracy on the validation set: 58.00% 223 | epoch=31, lr=[0.004], 2024-01-06 21:23:54.313153 224 | Training [80%] Loss: 1.3693 225 | Saving model state to models/ImgClass-Quanvolv.pth 226 | Accuracy on the validation set: 58.00% 227 | epoch=32, lr=[0.004], 2024-01-06 21:25:58.073654 228 | Training [82%] Loss: 1.3627 229 | Saving model state to models/ImgClass-Quanvolv.pth 230 | Accuracy on the validation set: 58.00% 231 | epoch=33, lr=[0.004], 2024-01-06 21:27:54.205717 232 | Training [85%] Loss: 1.3562 233 | Saving model state to models/ImgClass-Quanvolv.pth 234 | Accuracy on the validation set: 58.00% 235 | epoch=34, lr=[0.004], 2024-01-06 21:29:49.826544 236 | Training [88%] Loss: 1.3501 237 | Saving model state to models/ImgClass-Quanvolv.pth 238 | Accuracy on the validation set: 58.00% 239 | epoch=35, lr=[0.004], 2024-01-06 21:31:46.567641 240 | Training [90%] Loss: 1.3441 241 | Saving model state to models/ImgClass-Quanvolv.pth 242 | Accuracy on the validation set: 58.00% 243 | epoch=36, lr=[0.004], 2024-01-06 21:33:42.126451 244 | Training [92%] Loss: 1.3384 245 | Saving model state to models/ImgClass-Quanvolv.pth 246 | Accuracy on the validation set: 58.00% 247 | epoch=37, lr=[0.004], 2024-01-06 21:35:37.756882 248 | Training [95%] Loss: 1.3329 249 | Saving model state to models/ImgClass-Quanvolv.pth 250 | Accuracy on the validation set: 58.00% 251 | epoch=38, lr=[0.004], 2024-01-06 21:37:32.694493 252 | Training [98%] Loss: 1.3274 253 | Saving model state to models/ImgClass-Quanvolv.pth 254 | Accuracy on the validation set: 58.00% 255 | epoch=39, lr=[0.004], 2024-01-06 21:39:28.277719 256 | 257 | ------------------------------------------- 258 | 259 | n_epochs=40, n_train=500, initial_lr=0.003, gamma=1.0 260 | momentum=0.9, nesterov=False 261 | 262 | epoch=0, lr=[0.003], 2024-01-06 08:52:24.689584 263 | Training [2%] Loss: 2.2525 264 | Saving model state to models/ImgClass-Quanvolv.pth 265 | Accuracy on the validation set: 29.00% 266 | epoch=1, lr=[0.003], 2024-01-06 08:54:13.441877 267 | Training [5%] Loss: 2.1411 268 | Saving model state to models/ImgClass-Quanvolv.pth 269 | Accuracy on the validation set: 35.00% 270 | epoch=2, lr=[0.003], 2024-01-06 08:56:01.874758 271 | Training [8%] Loss: 2.0432 272 | Saving model state to models/ImgClass-Quanvolv.pth 273 | Accuracy on the validation set: 38.00% 274 | epoch=3, lr=[0.003], 2024-01-06 08:57:51.428558 275 | Training [10%] Loss: 1.9578 276 | Saving model state to models/ImgClass-Quanvolv.pth 277 | Accuracy on the validation set: 44.00% 278 | epoch=4, lr=[0.003], 2024-01-06 08:59:39.916853 279 | Training [12%] Loss: 1.8841 280 | Saving model state to models/ImgClass-Quanvolv.pth 281 | Accuracy on the validation set: 47.00% 282 | epoch=5, lr=[0.003], 2024-01-06 09:01:27.895234 283 | Training [15%] Loss: 1.8204 284 | Saving model state to models/ImgClass-Quanvolv.pth 285 | Accuracy on the validation set: 46.00% 286 | epoch=6, lr=[0.003], 2024-01-06 09:03:16.802128 287 | Training [18%] Loss: 1.7650 288 | Saving model state to models/ImgClass-Quanvolv.pth 289 | Accuracy on the validation set: 48.00% 290 | epoch=7, lr=[0.003], 2024-01-06 09:05:04.810491 291 | Training [20%] Loss: 1.7162 292 | Saving model state to models/ImgClass-Quanvolv.pth 293 | Accuracy on the validation set: 48.00% 294 | epoch=8, lr=[0.003], 2024-01-06 09:06:53.920299 295 | 
Training [22%] Loss: 1.6731 296 | Saving model state to models/ImgClass-Quanvolv.pth 297 | Accuracy on the validation set: 48.00% 298 | epoch=9, lr=[0.003], 2024-01-06 09:08:45.046283 299 | Training [25%] Loss: 1.6345 300 | Saving model state to models/ImgClass-Quanvolv.pth 301 | Accuracy on the validation set: 48.00% 302 | epoch=10, lr=[0.003], 2024-01-06 09:10:33.695548 303 | Training [28%] Loss: 1.5999 304 | Saving model state to models/ImgClass-Quanvolv.pth 305 | Accuracy on the validation set: 48.00% 306 | epoch=11, lr=[0.003], 2024-01-06 09:12:21.241255 307 | Training [30%] Loss: 1.5687 308 | Saving model state to models/ImgClass-Quanvolv.pth 309 | Accuracy on the validation set: 50.00% 310 | epoch=12, lr=[0.003], 2024-01-06 09:14:08.938715 311 | Training [32%] Loss: 1.5402 312 | Saving model state to models/ImgClass-Quanvolv.pth 313 | Accuracy on the validation set: 50.00% 314 | epoch=13, lr=[0.003], 2024-01-06 09:15:57.113400 315 | Training [35%] Loss: 1.5144 316 | Saving model state to models/ImgClass-Quanvolv.pth 317 | Accuracy on the validation set: 51.00% 318 | epoch=14, lr=[0.003], 2024-01-06 09:17:44.846734 319 | Training [38%] Loss: 1.4905 320 | Saving model state to models/ImgClass-Quanvolv.pth 321 | Accuracy on the validation set: 51.00% 322 | epoch=15, lr=[0.003], 2024-01-06 09:19:33.234402 323 | Training [40%] Loss: 1.4685 324 | Saving model state to models/ImgClass-Quanvolv.pth 325 | Accuracy on the validation set: 51.00% 326 | epoch=16, lr=[0.003], 2024-01-06 09:21:20.926591 327 | Training [42%] Loss: 1.4415 328 | Saving model state to models/ImgClass-Quanvolv.pth 329 | Accuracy on the validation set: 51.00% 330 | epoch=17, lr=[0.003], 2024-01-06 09:23:07.967694 331 | Training [45%] Loss: 1.4048 332 | Saving model state to models/ImgClass-Quanvolv.pth 333 | Accuracy on the validation set: 54.00% 334 | epoch=18, lr=[0.003], 2024-01-06 09:24:55.868774 335 | Training [48%] Loss: 1.3741 336 | Saving model state to models/ImgClass-Quanvolv.pth 337 | Accuracy on the validation set: 56.00% 338 | epoch=19, lr=[0.003], 2024-01-06 09:26:43.173957 339 | Training [50%] Loss: 1.3464 340 | Saving model state to models/ImgClass-Quanvolv.pth 341 | Accuracy on the validation set: 57.00% 342 | epoch=20, lr=[0.003], 2024-01-06 09:28:30.809784 343 | Training [52%] Loss: 1.3208 344 | Saving model state to models/ImgClass-Quanvolv.pth 345 | Accuracy on the validation set: 59.00% 346 | epoch=21, lr=[0.003], 2024-01-06 09:30:19.034058 347 | Training [55%] Loss: 1.2969 348 | Saving model state to models/ImgClass-Quanvolv.pth 349 | Accuracy on the validation set: 60.00% 350 | epoch=22, lr=[0.003], 2024-01-06 09:32:06.517458 351 | Training [58%] Loss: 1.2746 352 | Saving model state to models/ImgClass-Quanvolv.pth 353 | Accuracy on the validation set: 61.00% 354 | epoch=23, lr=[0.003], 2024-01-06 09:33:54.681066 355 | Training [60%] Loss: 1.2533 356 | Saving model state to models/ImgClass-Quanvolv.pth 357 | Accuracy on the validation set: 61.00% 358 | epoch=24, lr=[0.003], 2024-01-06 09:35:41.876052 359 | Training [62%] Loss: 1.2118 360 | Saving model state to models/ImgClass-Quanvolv.pth 361 | Accuracy on the validation set: 63.00% 362 | epoch=25, lr=[0.003], 2024-01-06 09:37:29.613308 363 | Training [65%] Loss: 1.1729 364 | Saving model state to models/ImgClass-Quanvolv.pth 365 | Accuracy on the validation set: 65.00% 366 | epoch=26, lr=[0.003], 2024-01-06 09:39:18.058590 367 | Training [68%] Loss: 1.1459 368 | Saving model state to models/ImgClass-Quanvolv.pth 369 | Accuracy on the validation 
set: 65.00% 370 | epoch=27, lr=[0.003], 2024-01-06 09:41:05.589280 371 | Training [70%] Loss: 1.1220 372 | Saving model state to models/ImgClass-Quanvolv.pth 373 | Accuracy on the validation set: 68.00% 374 | epoch=28, lr=[0.003], 2024-01-06 09:42:53.002933 375 | Training [72%] Loss: 1.1003 376 | Saving model state to models/ImgClass-Quanvolv.pth 377 | Accuracy on the validation set: 68.00% 378 | epoch=29, lr=[0.003], 2024-01-06 09:44:41.033553 379 | Training [75%] Loss: 1.0803 380 | Saving model state to models/ImgClass-Quanvolv.pth 381 | Accuracy on the validation set: 68.00% 382 | epoch=30, lr=[0.003], 2024-01-06 09:46:28.405035 383 | Training [78%] Loss: 1.0615 384 | Saving model state to models/ImgClass-Quanvolv.pth 385 | Accuracy on the validation set: 68.00% 386 | epoch=31, lr=[0.003], 2024-01-06 09:48:16.492715 387 | Training [80%] Loss: 1.0439 388 | Saving model state to models/ImgClass-Quanvolv.pth 389 | Accuracy on the validation set: 68.00% 390 | epoch=32, lr=[0.003], 2024-01-06 09:50:03.597142 391 | Training [82%] Loss: 1.0275 392 | Saving model state to models/ImgClass-Quanvolv.pth 393 | Accuracy on the validation set: 69.00% 394 | epoch=33, lr=[0.003], 2024-01-06 09:51:50.538511 395 | Training [85%] Loss: 1.0120 396 | Saving model state to models/ImgClass-Quanvolv.pth 397 | Accuracy on the validation set: 69.00% 398 | epoch=34, lr=[0.003], 2024-01-06 09:53:38.600785 399 | Training [88%] Loss: 0.9973 400 | Saving model state to models/ImgClass-Quanvolv.pth 401 | Accuracy on the validation set: 69.00% 402 | epoch=35, lr=[0.003], 2024-01-06 09:55:25.615325 403 | Training [90%] Loss: 0.9835 404 | Saving model state to models/ImgClass-Quanvolv.pth 405 | Accuracy on the validation set: 69.00% 406 | epoch=36, lr=[0.003], 2024-01-06 09:57:12.726703 407 | Training [92%] Loss: 0.9705 408 | Saving model state to models/ImgClass-Quanvolv.pth 409 | Accuracy on the validation set: 70.00% 410 | epoch=37, lr=[0.003], 2024-01-06 09:59:00.840666 411 | Training [95%] Loss: 0.9581 412 | Saving model state to models/ImgClass-Quanvolv.pth 413 | Accuracy on the validation set: 70.00% 414 | epoch=38, lr=[0.003], 2024-01-06 10:00:48.244852 415 | Training [98%] Loss: 0.9462 416 | Saving model state to models/ImgClass-Quanvolv.pth 417 | Accuracy on the validation set: 70.00% 418 | epoch=39, lr=[0.003], 2024-01-06 10:02:36.607641 419 | Training [100%] Loss: 0.9350 420 | Saving model state to models/ImgClass-Quanvolv.pth 421 | Accuracy on the validation set: 70.00% 422 | Finished Training 423 | 424 | 425 | ---------------------------------- 426 | 427 | n_epochs=40, n_train=500, initial_lr=0.002, gamma=1.0 428 | momentum=0.9, nesterov=False 429 | 430 | epoch=0, lr=[0.002], 2024-01-06 07:38:42.481728 431 | Training [2%] Loss: 2.2745 432 | Saving model state to models/ImgClass-Quanvolv.pth 433 | Accuracy on the validation set: 16.00% 434 | epoch=1, lr=[0.002], 2024-01-06 07:40:30.184898 435 | Training [5%] Loss: 2.2121 436 | Saving model state to models/ImgClass-Quanvolv.pth 437 | Accuracy on the validation set: 22.00% 438 | epoch=2, lr=[0.002], 2024-01-06 07:42:17.882459 439 | Training [8%] Loss: 2.1538 440 | Saving model state to models/ImgClass-Quanvolv.pth 441 | Accuracy on the validation set: 26.00% 442 | epoch=3, lr=[0.002], 2024-01-06 07:44:05.318943 443 | Training [10%] Loss: 2.1067 444 | Saving model state to models/ImgClass-Quanvolv.pth 445 | Accuracy on the validation set: 26.00% 446 | epoch=4, lr=[0.002], 2024-01-06 07:45:52.116827 447 | Training [12%] Loss: 2.0666 448 | Saving model state to 
models/ImgClass-Quanvolv.pth 449 | Accuracy on the validation set: 28.00% 450 | epoch=5, lr=[0.002], 2024-01-06 07:47:40.236447 451 | Training [15%] Loss: 2.0317 452 | Saving model state to models/ImgClass-Quanvolv.pth 453 | Accuracy on the validation set: 29.00% 454 | epoch=6, lr=[0.002], 2024-01-06 07:49:27.469576 455 | Training [18%] Loss: 2.0008 456 | Saving model state to models/ImgClass-Quanvolv.pth 457 | Accuracy on the validation set: 33.00% 458 | epoch=7, lr=[0.002], 2024-01-06 07:51:14.456527 459 | Training [20%] Loss: 1.9610 460 | Saving model state to models/ImgClass-Quanvolv.pth 461 | Accuracy on the validation set: 35.00% 462 | epoch=8, lr=[0.002], 2024-01-06 07:53:02.538428 463 | Training [22%] Loss: 1.9259 464 | Saving model state to models/ImgClass-Quanvolv.pth 465 | Accuracy on the validation set: 35.00% 466 | epoch=9, lr=[0.002], 2024-01-06 07:54:50.853916 467 | Training [25%] Loss: 1.8955 468 | Saving model state to models/ImgClass-Quanvolv.pth 469 | Accuracy on the validation set: 39.00% 470 | epoch=10, lr=[0.002], 2024-01-06 07:56:39.229828 471 | Training [28%] Loss: 1.8684 472 | Saving model state to models/ImgClass-Quanvolv.pth 473 | Accuracy on the validation set: 39.00% 474 | epoch=11, lr=[0.002], 2024-01-06 07:58:27.233021 475 | Training [30%] Loss: 1.8444 476 | Saving model state to models/ImgClass-Quanvolv.pth 477 | Accuracy on the validation set: 41.00% 478 | epoch=12, lr=[0.002], 2024-01-06 08:00:16.880991 479 | Training [32%] Loss: 1.8226 480 | Saving model state to models/ImgClass-Quanvolv.pth 481 | Accuracy on the validation set: 41.00% 482 | epoch=13, lr=[0.002], 2024-01-06 08:02:05.691452 483 | Training [35%] Loss: 1.8029 484 | Saving model state to models/ImgClass-Quanvolv.pth 485 | Accuracy on the validation set: 41.00% 486 | epoch=14, lr=[0.002], 2024-01-06 08:03:53.700569 487 | Training [38%] Loss: 1.7848 488 | Saving model state to models/ImgClass-Quanvolv.pth 489 | Accuracy on the validation set: 41.00% 490 | epoch=15, lr=[0.002], 2024-01-06 08:05:42.433512 491 | Training [40%] Loss: 1.7683 492 | Saving model state to models/ImgClass-Quanvolv.pth 493 | Accuracy on the validation set: 42.00% 494 | epoch=16, lr=[0.002], 2024-01-06 08:07:30.234302 495 | Training [42%] Loss: 1.7531 496 | Saving model state to models/ImgClass-Quanvolv.pth 497 | Accuracy on the validation set: 43.00% 498 | epoch=17, lr=[0.002], 2024-01-06 08:09:20.167318 499 | Training [45%] Loss: 1.7392 500 | Saving model state to models/ImgClass-Quanvolv.pth 501 | Accuracy on the validation set: 43.00% 502 | epoch=18, lr=[0.002], 2024-01-06 08:11:08.830681 503 | Training [48%] Loss: 1.7265 504 | Saving model state to models/ImgClass-Quanvolv.pth 505 | Accuracy on the validation set: 43.00% 506 | epoch=19, lr=[0.002], 2024-01-06 08:12:57.270538 507 | Training [50%] Loss: 1.7144 508 | Saving model state to models/ImgClass-Quanvolv.pth 509 | Accuracy on the validation set: 43.00% 510 | epoch=20, lr=[0.002], 2024-01-06 08:14:45.972366 511 | Training [52%] Loss: 1.7025 512 | Saving model state to models/ImgClass-Quanvolv.pth 513 | Accuracy on the validation set: 43.00% 514 | epoch=21, lr=[0.002], 2024-01-06 08:16:33.348491 515 | Training [55%] Loss: 1.6920 516 | Saving model state to models/ImgClass-Quanvolv.pth 517 | Accuracy on the validation set: 43.00% 518 | epoch=22, lr=[0.002], 2024-01-06 08:18:21.871190 519 | Training [58%] Loss: 1.6820 520 | Saving model state to models/ImgClass-Quanvolv.pth 521 | Accuracy on the validation set: 44.00% 522 | epoch=23, lr=[0.002], 2024-01-06 
08:20:09.283989 523 | Training [60%] Loss: 1.6728 524 | Saving model state to models/ImgClass-Quanvolv.pth 525 | Accuracy on the validation set: 44.00% 526 | epoch=24, lr=[0.002], 2024-01-06 08:21:57.801130 527 | Training [62%] Loss: 1.6639 528 | Saving model state to models/ImgClass-Quanvolv.pth 529 | Accuracy on the validation set: 45.00% 530 | epoch=25, lr=[0.002], 2024-01-06 08:23:46.275783 531 | Training [65%] Loss: 1.6554 532 | Saving model state to models/ImgClass-Quanvolv.pth 533 | Accuracy on the validation set: 45.00% 534 | epoch=26, lr=[0.002], 2024-01-06 08:25:34.317704 535 | Training [68%] Loss: 1.6475 536 | Saving model state to models/ImgClass-Quanvolv.pth 537 | Accuracy on the validation set: 45.00% 538 | epoch=27, lr=[0.002], 2024-01-06 08:27:22.783228 539 | Training [70%] Loss: 1.6399 540 | Saving model state to models/ImgClass-Quanvolv.pth 541 | Accuracy on the validation set: 45.00% 542 | epoch=28, lr=[0.002], 2024-01-06 08:29:10.325441 543 | Training [72%] Loss: 1.6327 544 | Saving model state to models/ImgClass-Quanvolv.pth 545 | Accuracy on the validation set: 45.00% 546 | epoch=29, lr=[0.002], 2024-01-06 08:30:58.814884 547 | Training [75%] Loss: 1.6258 548 | Saving model state to models/ImgClass-Quanvolv.pth 549 | Accuracy on the validation set: 45.00% 550 | epoch=30, lr=[0.002], 2024-01-06 08:32:46.844126 551 | Training [78%] Loss: 1.6192 552 | Saving model state to models/ImgClass-Quanvolv.pth 553 | Accuracy on the validation set: 45.00% 554 | epoch=31, lr=[0.002], 2024-01-06 08:34:34.817675 555 | Training [80%] Loss: 1.6128 556 | Saving model state to models/ImgClass-Quanvolv.pth 557 | Accuracy on the validation set: 46.00% 558 | epoch=32, lr=[0.002], 2024-01-06 08:36:23.678349 559 | Training [82%] Loss: 1.6068 560 | Saving model state to models/ImgClass-Quanvolv.pth 561 | Accuracy on the validation set: 46.00% 562 | epoch=33, lr=[0.002], 2024-01-06 08:38:11.610040 563 | Training [85%] Loss: 1.6009 564 | Saving model state to models/ImgClass-Quanvolv.pth 565 | Accuracy on the validation set: 46.00% 566 | epoch=34, lr=[0.002], 2024-01-06 08:40:00.377915 567 | Training [88%] Loss: 1.5954 568 | Saving model state to models/ImgClass-Quanvolv.pth 569 | Accuracy on the validation set: 46.00% 570 | epoch=35, lr=[0.002], 2024-01-06 08:41:47.938136 571 | Training [90%] Loss: 1.5899 572 | Saving model state to models/ImgClass-Quanvolv.pth 573 | Accuracy on the validation set: 46.00% 574 | epoch=36, lr=[0.002], 2024-01-06 08:43:35.222422 575 | Training [92%] Loss: 1.5849 576 | Saving model state to models/ImgClass-Quanvolv.pth 577 | Accuracy on the validation set: 46.00% 578 | epoch=37, lr=[0.002], 2024-01-06 08:45:23.749900 579 | Training [95%] Loss: 1.5801 580 | Saving model state to models/ImgClass-Quanvolv.pth 581 | Accuracy on the validation set: 46.00% 582 | epoch=38, lr=[0.002], 2024-01-06 08:47:11.963641 583 | Training [98%] Loss: 1.5753 584 | Saving model state to models/ImgClass-Quanvolv.pth 585 | Accuracy on the validation set: 46.00% 586 | epoch=39, lr=[0.002], 2024-01-06 08:49:00.281950 587 | Training [100%] Loss: 1.5706 588 | Saving model state to models/ImgClass-Quanvolv.pth 589 | Accuracy on the validation set: 46.00% 590 | Finished Training 591 | 592 | ---------------------------------------- 593 | 594 | n_epochs=40, n_train=500, initial_lr=0.001, gamma=1.0 595 | momentum=0.9, nesterov=False 596 | 597 | epoch=0, lr=[0.001], 2024-01-06 06:25:19.389842 598 | Training [2%] Loss: 2.2858 599 | Saving model state to models/ImgClass-Quanvolv.pth 600 | Accuracy on 
the validation set: 29.00% 601 | epoch=1, lr=[0.001], 2024-01-06 06:27:08.286766 602 | Training [5%] Loss: 2.2646 603 | Saving model state to models/ImgClass-Quanvolv.pth 604 | Accuracy on the validation set: 29.00% 605 | epoch=2, lr=[0.001], 2024-01-06 06:28:55.091027 606 | Training [8%] Loss: 2.2469 607 | Saving model state to models/ImgClass-Quanvolv.pth 608 | Accuracy on the validation set: 29.00% 609 | epoch=3, lr=[0.001], 2024-01-06 06:30:42.401500 610 | Training [10%] Loss: 2.2314 611 | Saving model state to models/ImgClass-Quanvolv.pth 612 | Accuracy on the validation set: 29.00% 613 | epoch=4, lr=[0.001], 2024-01-06 06:32:30.089318 614 | Training [12%] Loss: 2.2167 615 | Saving model state to models/ImgClass-Quanvolv.pth 616 | Accuracy on the validation set: 29.00% 617 | epoch=5, lr=[0.001], 2024-01-06 06:34:17.282891 618 | Training [15%] Loss: 2.2027 619 | Saving model state to models/ImgClass-Quanvolv.pth 620 | Accuracy on the validation set: 29.00% 621 | epoch=6, lr=[0.001], 2024-01-06 06:36:04.796057 622 | Training [18%] Loss: 2.1891 623 | Saving model state to models/ImgClass-Quanvolv.pth 624 | Accuracy on the validation set: 29.00% 625 | epoch=7, lr=[0.001], 2024-01-06 06:37:52.939476 626 | Training [20%] Loss: 2.1759 627 | Saving model state to models/ImgClass-Quanvolv.pth 628 | Accuracy on the validation set: 29.00% 629 | epoch=8, lr=[0.001], 2024-01-06 06:39:39.829150 630 | Training [22%] Loss: 2.1629 631 | Saving model state to models/ImgClass-Quanvolv.pth 632 | Accuracy on the validation set: 29.00% 633 | epoch=9, lr=[0.001], 2024-01-06 06:41:29.778578 634 | Training [25%] Loss: 2.1502 635 | Saving model state to models/ImgClass-Quanvolv.pth 636 | Accuracy on the validation set: 30.00% 637 | epoch=10, lr=[0.001], 2024-01-06 06:43:16.498598 638 | Training [28%] Loss: 2.1378 639 | Saving model state to models/ImgClass-Quanvolv.pth 640 | Accuracy on the validation set: 32.00% 641 | epoch=11, lr=[0.001], 2024-01-06 06:45:03.482567 642 | Training [30%] Loss: 2.1255 643 | Saving model state to models/ImgClass-Quanvolv.pth 644 | Accuracy on the validation set: 33.00% 645 | epoch=12, lr=[0.001], 2024-01-06 06:46:51.556702 646 | Training [32%] Loss: 2.1136 647 | Saving model state to models/ImgClass-Quanvolv.pth 648 | Accuracy on the validation set: 33.00% 649 | epoch=0, lr=[0.001], 2024-01-06 06:25:19.389842 650 | Training [2%] Loss: 2.2858 651 | Saving model state to models/ImgClass-Quanvolv.pth 652 | Accuracy on the validation set: 29.00% 653 | epoch=1, lr=[0.001], 2024-01-06 06:27:08.286766 654 | Training [5%] Loss: 2.2646 655 | Saving model state to models/ImgClass-Quanvolv.pth 656 | Accuracy on the validation set: 29.00% 657 | epoch=2, lr=[0.001], 2024-01-06 06:28:55.091027 658 | Training [8%] Loss: 2.2469 659 | Saving model state to models/ImgClass-Quanvolv.pth 660 | Accuracy on the validation set: 29.00% 661 | epoch=3, lr=[0.001], 2024-01-06 06:30:42.401500 662 | Training [10%] Loss: 2.2314 663 | Saving model state to models/ImgClass-Quanvolv.pth 664 | Accuracy on the validation set: 29.00% 665 | epoch=4, lr=[0.001], 2024-01-06 06:32:30.089318 666 | Training [12%] Loss: 2.2167 667 | Saving model state to models/ImgClass-Quanvolv.pth 668 | Accuracy on the validation set: 29.00% 669 | epoch=5, lr=[0.001], 2024-01-06 06:34:17.282891 670 | Training [15%] Loss: 2.2027 671 | Saving model state to models/ImgClass-Quanvolv.pth 672 | Accuracy on the validation set: 29.00% 673 | epoch=6, lr=[0.001], 2024-01-06 06:36:04.796057 674 | Training [18%] Loss: 2.1891 675 | Saving model 
state to models/ImgClass-Quanvolv.pth 676 | Accuracy on the validation set: 29.00% 677 | epoch=7, lr=[0.001], 2024-01-06 06:37:52.939476 678 | Training [20%] Loss: 2.1759 679 | Saving model state to models/ImgClass-Quanvolv.pth 680 | Accuracy on the validation set: 29.00% 681 | epoch=8, lr=[0.001], 2024-01-06 06:39:39.829150 682 | Training [22%] Loss: 2.1629 683 | Saving model state to models/ImgClass-Quanvolv.pth 684 | Accuracy on the validation set: 29.00% 685 | epoch=9, lr=[0.001], 2024-01-06 06:41:29.778578 686 | Training [25%] Loss: 2.1502 687 | Saving model state to models/ImgClass-Quanvolv.pth 688 | Accuracy on the validation set: 30.00% 689 | epoch=10, lr=[0.001], 2024-01-06 06:43:16.498598 690 | Training [28%] Loss: 2.1378 691 | Saving model state to models/ImgClass-Quanvolv.pth 692 | Accuracy on the validation set: 32.00% 693 | epoch=11, lr=[0.001], 2024-01-06 06:45:03.482567 694 | Training [30%] Loss: 2.1255 695 | Saving model state to models/ImgClass-Quanvolv.pth 696 | Accuracy on the validation set: 33.00% 697 | epoch=12, lr=[0.001], 2024-01-06 06:46:51.556702 698 | Training [32%] Loss: 2.1136 699 | Saving model state to models/ImgClass-Quanvolv.pth 700 | Accuracy on the validation set: 33.00% 701 | epoch=13, lr=[0.001], 2024-01-06 06:48:39.262166 702 | Training [35%] Loss: 2.1018 703 | Saving model state to models/ImgClass-Quanvolv.pth 704 | Accuracy on the validation set: 33.00% 705 | epoch=14, lr=[0.001], 2024-01-06 06:50:27.007334 706 | Training [38%] Loss: 2.0902 707 | Saving model state to models/ImgClass-Quanvolv.pth 708 | Accuracy on the validation set: 34.00% 709 | epoch=15, lr=[0.001], 2024-01-06 06:52:15.210870 710 | Training [40%] Loss: 2.0789 711 | Saving model state to models/ImgClass-Quanvolv.pth 712 | Accuracy on the validation set: 36.00% 713 | epoch=16, lr=[0.001], 2024-01-06 06:54:02.832253 714 | Training [42%] Loss: 2.0678 715 | Saving model state to models/ImgClass-Quanvolv.pth 716 | Accuracy on the validation set: 37.00% 717 | epoch=17, lr=[0.001], 2024-01-06 06:55:51.082397 718 | Training [45%] Loss: 2.0569 719 | Saving model state to models/ImgClass-Quanvolv.pth 720 | Accuracy on the validation set: 38.00% 721 | epoch=18, lr=[0.001], 2024-01-06 06:57:38.469320 722 | Training [48%] Loss: 2.0462 723 | Saving model state to models/ImgClass-Quanvolv.pth 724 | Accuracy on the validation set: 39.00% 725 | epoch=19, lr=[0.001], 2024-01-06 06:59:25.938056 726 | Training [50%] Loss: 2.0357 727 | Saving model state to models/ImgClass-Quanvolv.pth 728 | Accuracy on the validation set: 39.00% 729 | epoch=20, lr=[0.001], 2024-01-06 07:01:14.082041 730 | Training [52%] Loss: 2.0255 731 | Saving model state to models/ImgClass-Quanvolv.pth 732 | Accuracy on the validation set: 39.00% 733 | epoch=21, lr=[0.001], 2024-01-06 07:03:01.319095 734 | Training [55%] Loss: 2.0154 735 | Saving model state to models/ImgClass-Quanvolv.pth 736 | Accuracy on the validation set: 39.00% 737 | epoch=22, lr=[0.001], 2024-01-06 07:04:49.850050 738 | Training [58%] Loss: 2.0056 739 | Saving model state to models/ImgClass-Quanvolv.pth 740 | Accuracy on the validation set: 39.00% 741 | epoch=23, lr=[0.001], 2024-01-06 07:06:37.045023 742 | Training [60%] Loss: 1.9960 743 | Saving model state to models/ImgClass-Quanvolv.pth 744 | Accuracy on the validation set: 39.00% 745 | epoch=24, lr=[0.001], 2024-01-06 07:08:28.426736 746 | Training [62%] Loss: 1.9865 747 | Saving model state to models/ImgClass-Quanvolv.pth 748 | Accuracy on the validation set: 40.00% 749 | epoch=25, lr=[0.001], 
2024-01-06 07:10:19.406478 750 | Training [65%] Loss: 1.9773 751 | Saving model state to models/ImgClass-Quanvolv.pth 752 | Accuracy on the validation set: 41.00% 753 | epoch=26, lr=[0.001], 2024-01-06 07:12:11.712794 754 | Training [68%] Loss: 1.9682 755 | Saving model state to models/ImgClass-Quanvolv.pth 756 | Accuracy on the validation set: 43.00% 757 | epoch=27, lr=[0.001], 2024-01-06 07:14:02.508663 758 | Training [70%] Loss: 1.9593 759 | Saving model state to models/ImgClass-Quanvolv.pth 760 | Accuracy on the validation set: 43.00% 761 | epoch=28, lr=[0.001], 2024-01-06 07:15:55.097974 762 | Training [72%] Loss: 1.9505 763 | Saving model state to models/ImgClass-Quanvolv.pth 764 | Accuracy on the validation set: 43.00% 765 | epoch=29, lr=[0.001], 2024-01-06 07:17:46.743385 766 | Training [75%] Loss: 1.9420 767 | Saving model state to models/ImgClass-Quanvolv.pth 768 | Accuracy on the validation set: 43.00% 769 | epoch=30, lr=[0.001], 2024-01-06 07:19:34.836893 770 | Training [78%] Loss: 1.9335 771 | Saving model state to models/ImgClass-Quanvolv.pth 772 | Accuracy on the validation set: 43.00% 773 | epoch=31, lr=[0.001], 2024-01-06 07:21:22.332036 774 | Training [80%] Loss: 1.9253 775 | Saving model state to models/ImgClass-Quanvolv.pth 776 | Accuracy on the validation set: 44.00% 777 | epoch=32, lr=[0.001], 2024-01-06 07:23:11.239341 778 | Training [82%] Loss: 1.9172 779 | Saving model state to models/ImgClass-Quanvolv.pth 780 | Accuracy on the validation set: 45.00% 781 | epoch=33, lr=[0.001], 2024-01-06 07:24:58.803599 782 | Training [85%] Loss: 1.9092 783 | Saving model state to models/ImgClass-Quanvolv.pth 784 | Accuracy on the validation set: 46.00% 785 | epoch=34, lr=[0.001], 2024-01-06 07:26:47.120606 786 | Training [88%] Loss: 1.9014 787 | Saving model state to models/ImgClass-Quanvolv.pth 788 | Accuracy on the validation set: 46.00% 789 | epoch=35, lr=[0.001], 2024-01-06 07:28:35.871062 790 | Training [90%] Loss: 1.8938 791 | Saving model state to models/ImgClass-Quanvolv.pth 792 | Accuracy on the validation set: 48.00% 793 | epoch=36, lr=[0.001], 2024-01-06 07:30:23.380835 794 | Training [92%] Loss: 1.8862 795 | Saving model state to models/ImgClass-Quanvolv.pth 796 | Accuracy on the validation set: 48.00% 797 | epoch=37, lr=[0.001], 2024-01-06 07:32:11.834365 798 | Training [95%] Loss: 1.8788 799 | Saving model state to models/ImgClass-Quanvolv.pth 800 | Accuracy on the validation set: 48.00% 801 | epoch=38, lr=[0.001], 2024-01-06 07:34:00.175946 802 | Training [98%] Loss: 1.8715 803 | Saving model state to models/ImgClass-Quanvolv.pth 804 | Accuracy on the validation set: 48.00% 805 | epoch=39, lr=[0.001], 2024-01-06 07:35:49.059563 806 | Training [100%] Loss: 1.8644 807 | Saving model state to models/ImgClass-Quanvolv.pth 808 | Accuracy on the validation set: 48.00% 809 | Finished Training 810 | 811 | ---------------------------------------- 812 | 813 | n_epochs=40, n_train=500, initial_lr=0.0005, gamma=1.0 814 | momentum=0.9, nesterov=False 815 | 816 | epoch=0, lr=[0.0005], 2024-01-06 05:18:14.538093 817 | Training [2%] Loss: 2.2971 818 | Saving model state to models/ImgClass-Quanvolv.pth 819 | Accuracy on the validation set: 10.00% 820 | epoch=1, lr=[0.0005], 2024-01-06 05:20:04.611698 821 | Training [5%] Loss: 2.2782 822 | Saving model state to models/ImgClass-Quanvolv.pth 823 | Accuracy on the validation set: 11.00% 824 | epoch=2, lr=[0.0005], 2024-01-06 05:21:53.095138 825 | Training [8%] Loss: 2.2613 826 | Saving model state to models/ImgClass-Quanvolv.pth 827 | 
Accuracy on the validation set: 18.00% 828 | epoch=3, lr=[0.0005], 2024-01-06 05:23:42.334454 829 | Training [10%] Loss: 2.2448 830 | Saving model state to models/ImgClass-Quanvolv.pth 831 | Accuracy on the validation set: 19.00% 832 | epoch=4, lr=[0.0005], 2024-01-06 05:25:30.654572 833 | Training [12%] Loss: 2.2274 834 | Saving model state to models/ImgClass-Quanvolv.pth 835 | Accuracy on the validation set: 22.00% 836 | epoch=5, lr=[0.0005], 2024-01-06 05:27:19.715867 837 | Training [15%] Loss: 2.2083 838 | Saving model state to models/ImgClass-Quanvolv.pth 839 | Accuracy on the validation set: 26.00% 840 | epoch=6, lr=[0.0005], 2024-01-06 05:29:07.987849 841 | Training [18%] Loss: 2.1903 842 | Saving model state to models/ImgClass-Quanvolv.pth 843 | Accuracy on the validation set: 26.00% 844 | epoch=7, lr=[0.0005], 2024-01-06 05:30:57.118689 845 | Training [20%] Loss: 2.1730 846 | Saving model state to models/ImgClass-Quanvolv.pth 847 | Accuracy on the validation set: 26.00% 848 | epoch=8, lr=[0.0005], 2024-01-06 05:32:45.244608 849 | Training [22%] Loss: 2.1561 850 | Saving model state to models/ImgClass-Quanvolv.pth 851 | Accuracy on the validation set: 27.00% 852 | epoch=9, lr=[0.0005], 2024-01-06 05:34:34.061418 853 | Training [25%] Loss: 2.1394 854 | Saving model state to models/ImgClass-Quanvolv.pth 855 | Accuracy on the validation set: 28.00% 856 | epoch=10, lr=[0.0005], 2024-01-06 05:36:22.648266 857 | Training [28%] Loss: 2.1227 858 | Saving model state to models/ImgClass-Quanvolv.pth 859 | Accuracy on the validation set: 30.00% 860 | epoch=11, lr=[0.0005], 2024-01-06 05:38:11.148976 861 | Training [30%] Loss: 2.1065 862 | Saving model state to models/ImgClass-Quanvolv.pth 863 | Accuracy on the validation set: 34.00% 864 | epoch=12, lr=[0.0005], 2024-01-06 05:40:00.119992 865 | Training [32%] Loss: 2.0912 866 | Saving model state to models/ImgClass-Quanvolv.pth 867 | Accuracy on the validation set: 36.00% 868 | epoch=13, lr=[0.0005], 2024-01-06 05:41:47.945410 869 | Training [35%] Loss: 2.0764 870 | Saving model state to models/ImgClass-Quanvolv.pth 871 | Accuracy on the validation set: 36.00% 872 | epoch=14, lr=[0.0005], 2024-01-06 05:43:37.269339 873 | Training [38%] Loss: 2.0621 874 | Saving model state to models/ImgClass-Quanvolv.pth 875 | Accuracy on the validation set: 36.00% 876 | epoch=15, lr=[0.0005], 2024-01-06 05:45:25.494261 877 | Training [40%] Loss: 2.0485 878 | Saving model state to models/ImgClass-Quanvolv.pth 879 | Accuracy on the validation set: 37.00% 880 | epoch=16, lr=[0.0005], 2024-01-06 05:47:14.412310 881 | Training [42%] Loss: 2.0352 882 | Saving model state to models/ImgClass-Quanvolv.pth 883 | Accuracy on the validation set: 37.00% 884 | epoch=17, lr=[0.0005], 2024-01-06 05:49:02.573017 885 | Training [45%] Loss: 2.0224 886 | Saving model state to models/ImgClass-Quanvolv.pth 887 | Accuracy on the validation set: 37.00% 888 | epoch=18, lr=[0.0005], 2024-01-06 05:50:51.704100 889 | Training [48%] Loss: 2.0100 890 | Saving model state to models/ImgClass-Quanvolv.pth 891 | Accuracy on the validation set: 37.00% 892 | epoch=19, lr=[0.0005], 2024-01-06 05:52:39.909892 893 | Training [50%] Loss: 1.9981 894 | Saving model state to models/ImgClass-Quanvolv.pth 895 | Accuracy on the validation set: 38.00% 896 | epoch=20, lr=[0.0005], 2024-01-06 05:54:29.450723 897 | Training [52%] Loss: 1.9865 898 | Saving model state to models/ImgClass-Quanvolv.pth 899 | Accuracy on the validation set: 38.00% 900 | epoch=21, lr=[0.0005], 2024-01-06 05:56:18.036922 901 | 
Training [55%] Loss: 1.9753 902 | Saving model state to models/ImgClass-Quanvolv.pth 903 | Accuracy on the validation set: 39.00% 904 | epoch=22, lr=[0.0005], 2024-01-06 05:58:07.553020 905 | Training [58%] Loss: 1.9644 906 | Saving model state to models/ImgClass-Quanvolv.pth 907 | Accuracy on the validation set: 39.00% 908 | epoch=23, lr=[0.0005], 2024-01-06 05:59:57.134333 909 | Training [60%] Loss: 1.9539 910 | Saving model state to models/ImgClass-Quanvolv.pth 911 | Accuracy on the validation set: 40.00% 912 | epoch=24, lr=[0.0005], 2024-01-06 06:01:46.932339 913 | Training [62%] Loss: 1.9437 914 | Saving model state to models/ImgClass-Quanvolv.pth 915 | Accuracy on the validation set: 40.00% 916 | epoch=25, lr=[0.0005], 2024-01-06 06:03:35.457495 917 | Training [65%] Loss: 1.9337 918 | Saving model state to models/ImgClass-Quanvolv.pth 919 | Accuracy on the validation set: 40.00% 920 | epoch=26, lr=[0.0005], 2024-01-06 06:05:25.499702 921 | Training [68%] Loss: 1.9241 922 | Saving model state to models/ImgClass-Quanvolv.pth 923 | Accuracy on the validation set: 40.00% 924 | epoch=27, lr=[0.0005], 2024-01-06 06:07:14.204328 925 | Training [70%] Loss: 1.9147 926 | Saving model state to models/ImgClass-Quanvolv.pth 927 | Accuracy on the validation set: 40.00% 928 | epoch=28, lr=[0.0005], 2024-01-06 06:09:03.901584 929 | Training [72%] Loss: 1.9055 930 | Saving model state to models/ImgClass-Quanvolv.pth 931 | Accuracy on the validation set: 40.00% 932 | epoch=29, lr=[0.0005], 2024-01-06 06:10:53.050053 933 | Training [75%] Loss: 1.8967 934 | Saving model state to models/ImgClass-Quanvolv.pth 935 | Accuracy on the validation set: 41.00% 936 | epoch=30, lr=[0.0005], 2024-01-06 06:12:42.819788 937 | Training [78%] Loss: 1.8880 938 | Saving model state to models/ImgClass-Quanvolv.pth 939 | Accuracy on the validation set: 41.00% 940 | epoch=31, lr=[0.0005], 2024-01-06 06:14:31.633133 941 | Training [80%] Loss: 1.8797 942 | Saving model state to models/ImgClass-Quanvolv.pth 943 | Accuracy on the validation set: 42.00% 944 | epoch=32, lr=[0.0005], 2024-01-06 06:16:21.019484 945 | Training [82%] Loss: 1.8714 946 | Saving model state to models/ImgClass-Quanvolv.pth 947 | Accuracy on the validation set: 42.00% 948 | epoch=33, lr=[0.0005], 2024-01-06 06:18:09.289937 949 | Training [85%] Loss: 1.8635 950 | Saving model state to models/ImgClass-Quanvolv.pth 951 | Accuracy on the validation set: 42.00% 952 | epoch=34, lr=[0.0005], 2024-01-06 06:19:58.028244 953 | Training [88%] Loss: 1.8556 954 | Saving model state to models/ImgClass-Quanvolv.pth 955 | Accuracy on the validation set: 42.00% 956 | epoch=35, lr=[0.0005], 2024-01-06 06:21:47.050639 957 | Training [90%] Loss: 1.8479 958 | Saving model state to models/ImgClass-Quanvolv.pth 959 | Accuracy on the validation set: 42.00% 960 | 961 | --------------------------------------------------------------------------------