├── requirements.txt ├── README.md ├── src ├── img_utils.py ├── datasets.py └── model.py ├── .gitignore └── train.ipynb /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.3.1 3 | torchvision==0.4.2 4 | Cython==0.29.14 5 | tqdm==4.41.1 6 | ipywidgets==7.5.1 7 | notebook==6.4.1 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Siamese Neural Network for Object Co-Segmentation 2 | 3 | This project aims to generate predictive background/foreground masks from image pairs with similar objects. [CoSegNet's architecture](https://www.ijcai.org/proceedings/2019/0095.pdf) was used as the basis for this project. 4 | 5 | ## Getting Started 6 | 7 | First, create the "icoseg_data" directory at this project's root, then download and extract the [iCoseg dataset](http://chenlab.ece.cornell.edu/projects/touch-coseg/) into it (this is the dataset I trained on, but you are free to use another one). The notebook reads ".jpg" images from "icoseg_data/images_subset" and ".png" masks from "icoseg_data/ground_truth_subset". Each of these folders must contain one sub-folder per object class. Images are paired with masks by their sorted order, so give each mask the same file name as its image. 8 | 9 | Next, install all the python requirements: 10 | 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | To enable the tqdm progress bar in the Jupyter Notebook, please run the following command after dependencies are installed: 16 | 17 | ```bash 18 | jupyter nbextension enable --py widgetsnbextension --sys-prefix 19 | ``` 20 | 21 | If you wish to train the model and evaluate it, run: 22 | 23 | ```bash 24 | jupyter notebook 25 | ``` 26 | and open up the "train.ipynb" file. 27 | 28 | The model, datasets, and other utilities can be found in the src folder.
-------------------------------------------------------------------------------- /src/img_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torchvision.transforms.functional as F 3 | import numpy as np 4 | 5 | 6 | class ResizePad(object): 7 | def __init__(self, size): 8 | self.size = (size, size) 9 | 10 | def __call__(self, image, target): 11 | # Padding/resizing image 12 | resize_image = square_pad(image).resize(self.size, Image.BILINEAR) 13 | resize_image = F.to_tensor(resize_image) 14 | 15 | # Padding/resizing masks 16 | masks = target["masks"] 17 | new_masks = [] 18 | for i in range(len(masks)): 19 | mask = target["masks"][i] 20 | pil_mask = Image.fromarray(mask.numpy(), "L") 21 | resize_mask = square_pad(pil_mask, image_type="L").resize(self.size, Image.NEAREST) 22 | resize_mask = np.array(resize_mask, dtype=np.uint8) 23 | resize_mask = np.where(resize_mask > 0, 1, 0) 24 | new_masks.append(F.to_tensor(resize_mask).squeeze()) 25 | 26 | target["masks"] = new_masks 27 | 28 | return resize_image, target 29 | 30 | 31 | def square_pad(img, image_type="RGB"): 32 | width, height = img.size 33 | size = max(width, height) 34 | square_img = Image.new(image_type, (size, size)) 35 | 36 | width_center = int((size - width) / 2) 37 | height_center = int((size - height) / 2) 38 | square_img.paste(img, (width_center, height_center)) 39 | 40 | return square_img 41 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchvision.datasets import ImageFolder, DatasetFolder 5 | from torchvision.transforms import ToTensor, Normalize, Compose 6 | from torch.utils.data import Dataset 7 | 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class iCoSegDataset(Dataset): 13 | def __init__(self, images_path, masks_path, 
image_size=224): 14 | self.images_path = images_path 15 | self.masks_path = masks_path 16 | self.image_size = (image_size, image_size) 17 | 18 | preprocess = Compose([ 19 | ToTensor() 20 | ]) 21 | 22 | self.images = DatasetFolder( 23 | root=images_path, 24 | loader=self.iCoSegImageLoader, 25 | extensions=("jpg"), 26 | transform=preprocess 27 | ) 28 | self.masks = DatasetFolder( 29 | root=masks_path, 30 | loader=self.iCoSegMaskLoader, 31 | extensions=("png"), 32 | ) 33 | 34 | self.length = len(self.images) 35 | 36 | def iCoSegImageLoader(self, path): 37 | image = Image.open(path) 38 | image = image.resize(self.image_size) 39 | #image = np.array(image, dtype=np.float32) 40 | 41 | return image 42 | 43 | def iCoSegMaskLoader(self, path): 44 | mask = Image.open(path) 45 | mask = mask.resize(self.image_size) 46 | mask = np.array(mask, dtype=np.uint8) 47 | 48 | return mask 49 | 50 | def __getitem__(self, index): 51 | 52 | if type(index) == torch.Tensor: 53 | index = index.item() 54 | 55 | image, image_label = self.images[index] 56 | mask, mask_label = self.masks[index] 57 | 58 | sample = { 59 | "image": image, 60 | "mask": mask, 61 | "label": image_label 62 | } 63 | 64 | return sample 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Project specific 132 | .vscode/ 133 | *_data/ 134 | pycocotools 135 | weights/ -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | 6 | 7 | class CoSegNet(nn.Module): 8 | def __init__(self, input_channels=3, output_channels=2): 9 | super(CoSegNet, self).__init__() 10 | 11 | self.input_channels = input_channels # 3 = RGB 12 | self.output_channels = output_channels # 2 = Foreground + Background 13 | 14 | # Encoder 15 | # Using pretrained VGG16 as the backbone 16 | self.encoder = models.vgg16_bn(pretrained=True).features 17 | 18 | # Decoder 19 | self.decoder3 = nn.Sequential( 20 | nn.Upsample(scale_factor=2, mode="nearest"), 21 | nn.ConvTranspose2d(512, 512, kernel_size=3, padding=1), 22 | nn.ConvTranspose2d(512, 512, kernel_size=3, padding=1), 23 | nn.ConvTranspose2d(512, 512, kernel_size=3, padding=1), 24 | nn.BatchNorm2d(512), 25 | nn.ReLU() 26 | ) 27 | self.decoder6 = nn.Sequential( 28 | nn.Upsample(scale_factor=2, mode="nearest"), 29 | nn.ConvTranspose2d(512, 512, kernel_size=3, padding=1), 30 | nn.ConvTranspose2d(512, 512, kernel_size=3, padding=1), 31 | nn.ConvTranspose2d(512, 512, kernel_size=3, padding=1), 32 | 
nn.BatchNorm2d(512), 33 | nn.ReLU() 34 | ) 35 | self.decoder9 = nn.Sequential( 36 | nn.Upsample(scale_factor=2, mode="nearest"), 37 | nn.ConvTranspose2d(512, 256, kernel_size=3, padding=1), 38 | nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1), 39 | nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1), 40 | nn.BatchNorm2d(256), 41 | nn.ReLU() 42 | ) 43 | self.decoder11 = nn.Sequential( 44 | nn.Upsample(scale_factor=2, mode="nearest"), 45 | nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1), 46 | nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1), 47 | nn.BatchNorm2d(128), 48 | nn.ReLU() 49 | ) 50 | self.decoder13 = nn.Sequential( 51 | nn.Upsample(scale_factor=2, mode="nearest"), 52 | nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1), 53 | nn.ConvTranspose2d(64, 1, kernel_size=3, padding=1), 54 | nn.BatchNorm2d(1), 55 | nn.Sigmoid() 56 | ) 57 | 58 | # Metric Net 59 | self.metricnet = SiameseMetricNet() 60 | 61 | # Decision Net 62 | self.decisionnet = SiameseDecisionNet() 63 | 64 | 65 | def forward(self, imageA, imageB): 66 | # Getting VGG features for the first stage of the encoding 67 | featuresA = self.encoder(imageA) 68 | featuresB = self.encoder(imageB) 69 | 70 | # Decoding 71 | # 3 layers 72 | featuresA = self.decoder3(featuresA) 73 | featuresB = self.decoder3(featuresB) 74 | # 6 layers 75 | featuresA = self.decoder6(featuresA) 76 | featuresB = self.decoder6(featuresB) 77 | # 9 layers 78 | featuresA = self.decoder9(featuresA) 79 | featuresB = self.decoder9(featuresB) 80 | 81 | # Siamese Metric Net 82 | metric_featureA = self.metricnet(featuresA) 83 | metric_featureB = self.metricnet(featuresB) 84 | 85 | # Siamese Decision Net 86 | # Concatenating the two vectors to make a single prediction vector 87 | decision_vector = torch.cat((metric_featureA, metric_featureB), dim=0) 88 | decision = self.decisionnet(decision_vector) 89 | 90 | # 11 layers 91 | featuresA = self.decoder11(featuresA) 92 | featuresB = self.decoder11(featuresB) 93 | # 13 
layers 94 | featuresA = self.decoder13(featuresA) 95 | featuresB = self.decoder13(featuresB) 96 | 97 | return featuresA, featuresB, metric_featureA, metric_featureB, decision 98 | 99 | 100 | class SiameseMetricNet(nn.Module): 101 | def __init__(self): 102 | super(SiameseMetricNet, self).__init__() 103 | self.metricnet = nn.Sequential( 104 | nn.Linear(256, 128), 105 | nn.ReLU(), 106 | nn.Linear(128, 64) 107 | ) 108 | 109 | 110 | def forward(self, feature): 111 | # Global Average Pooling (GAP) 112 | gap_feature_vector = F.avg_pool2d(feature, kernel_size=feature.size()[2:]) 113 | gap_feature_vector = gap_feature_vector.squeeze() 114 | 115 | return self.metricnet(gap_feature_vector) 116 | 117 | 118 | class SiameseDecisionNet(nn.Module): 119 | def __init__(self): 120 | super(SiameseDecisionNet, self).__init__() 121 | self.decisionnet = nn.Sequential( 122 | nn.Linear(128, 32), 123 | nn.ReLU(), 124 | nn.Linear(32, 1), 125 | nn.Sigmoid() 126 | ) 127 | 128 | def forward(self, vector): 129 | return self.decisionnet(vector) 130 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Siamese Segmentation Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Imports" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Jupyter Notebook utils\n", 24 | "%load_ext autoreload\n", 25 | "%matplotlib inline" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 19, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import time\n", 35 | "import os\n", 36 | "\n", 37 | "from tqdm.notebook import tqdm\n", 38 | "import numpy as np\n", 39 | "\n", 40 | "import torch\n", 41 | "import torch.nn as nn\n", 42 | 
"from torch.utils.data import DataLoader\n", 43 | "import torchvision.transforms.functional as F\n", 44 | "\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "\n", 47 | "import src.img_utils as utils\n", 48 | "import src.model as siam_models\n", 49 | "from src.datasets import iCoSegDataset\n", 50 | "%autoreload 2" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## Dataset Loading" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "### Constants" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 51, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "BATCH_SIZE = 2 # TODO: Allow changing\n", 74 | "NUM_WORKERS = 1\n", 75 | "\n", 76 | "IMAGES_PATH = \"./icoseg_data/images_subset\"\n", 77 | "MASKS_PATH = \"./icoseg_data/ground_truth_subset\"\n", 78 | "\n", 79 | "VALIDATION_SPLIT = 0.2 # What % the validation set should be\n", 80 | "SHUFFLE = True" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "### Instantiation" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 55, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "dataset = iCoSegDataset(IMAGES_PATH, MASKS_PATH)\n", 97 | "\n", 98 | "# Test/train dataset split\n", 99 | "# Default is 80% training, 20% testing\n", 100 | "dataset_length = len(dataset)\n", 101 | "train_size = int((1.0 - VALIDATION_SPLIT) * dataset_length)\n", 102 | "test_size = dataset_length - train_size\n", 103 | "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])\n", 104 | "\n", 105 | "train_loader = DataLoader(\n", 106 | " train_dataset,\n", 107 | " batch_size=BATCH_SIZE,\n", 108 | " num_workers=NUM_WORKERS,\n", 109 | " shuffle=SHUFFLE\n", 110 | ")\n", 111 | "test_loader = DataLoader(\n", 112 | " test_dataset,\n", 113 | " batch_size=BATCH_SIZE,\n", 114 | " num_workers=NUM_WORKERS,\n", 115 | " 
shuffle=SHUFFLE,\n", 116 | " drop_last=True\n", 117 | ")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Training" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "### Constants" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 95, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "LEARNING_RATE = 0.00001\n", 141 | "MOMENTUM = 0.9\n", 142 | "WEIGHT_DECAY = 0.0004\n", 143 | "\n", 144 | "SAVE_WEIGHTS = True\n", 145 | "WEIGHTS_PATH = \"./weights\"\n", 146 | "\n", 147 | "MARGIN_ALPHA = 0.25 # Parameter used in loss calculation\n", 148 | "\n", 149 | "EPOCHS = 100" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "### Preparation" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 102, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "Using CUDA 0 as device\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "# Instantiation\n", 174 | "device = None\n", 175 | "if torch.cuda.is_available():\n", 176 | " print(\"Using CUDA 0 as device\")\n", 177 | " device = torch.device(\"cuda:0\")\n", 178 | "else:\n", 179 | " print(\"Using CPU as device\")\n", 180 | " device = torch.device(\"cpu\")\n", 181 | "\n", 182 | "prev_loss = float(\"inf\")\n", 183 | "model = siam_models.CoSegNet();\n", 184 | "\n", 185 | "optimizer = torch.optim.SGD(\n", 186 | " model.parameters(),\n", 187 | " lr=LEARNING_RATE,\n", 188 | " momentum=MOMENTUM,\n", 189 | " weight_decay=WEIGHT_DECAY\n", 190 | ")\n", 191 | "model.to(device);\n", 192 | "\n", 193 | "# Criteria\n", 194 | "criterion_bce = nn.BCELoss()\n", 195 | "# criterion_triplet = nn.TripletMarginLoss(margin=MARGIN_ALPHA)\n", 196 | "# Using Cosine Embedding Loss as a similar loss to triplet\n", 197 | "# due to presence of pairs\n", 198 |
"criterion_cel = nn.CosineEmbeddingLoss(margin=MARGIN_ALPHA)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "### Training Loop" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 105, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "application/vnd.jupyter.widget-view+json": { 216 | "model_id": "01f29908e30a465985ee9756b26c0852", 217 | "version_major": 2, 218 | "version_minor": 0 219 | }, 220 | "text/plain": [ 221 | "HBox(children=(FloatProgress(value=0.0, description='Epoch Progress: ', max=50.0, style=ProgressStyle(descript…" 222 | ] 223 | }, 224 | "metadata": {}, 225 | "output_type": "display_data" 226 | }, 227 | { 228 | "data": { 229 | "application/vnd.jupyter.widget-view+json": { 230 | "model_id": "860e261978874329946c379e79dfeb87", 231 | "version_major": 2, 232 | "version_minor": 0 233 | }, 234 | "text/plain": [ 235 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 236 | ] 237 | }, 238 | "metadata": {}, 239 | "output_type": "display_data" 240 | }, 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | "\n", 246 | "Total Loss: 1.6927924305200577\n", 247 | "Saving Model\n" 248 | ] 249 | }, 250 | { 251 | "data": { 252 | "application/vnd.jupyter.widget-view+json": { 253 | "model_id": "4da0402a2dd043adad9c0e9104d5ac1a", 254 | "version_major": 2, 255 | "version_minor": 0 256 | }, 257 | "text/plain": [ 258 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 259 | ] 260 | }, 261 | "metadata": {}, 262 | "output_type": "display_data" 263 | }, 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "\n", 269 | "Total Loss: 1.6602893471717834\n", 270 | "Saving Model\n" 271 | ] 272 | }, 273 | { 274 | "data": { 275 | "application/vnd.jupyter.widget-view+json": { 276 | "model_id": 
"78ebe88ebeff4f83804dd2e6168c525b", 277 | "version_major": 2, 278 | "version_minor": 0 279 | }, 280 | "text/plain": [ 281 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 282 | ] 283 | }, 284 | "metadata": {}, 285 | "output_type": "display_data" 286 | }, 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "\n", 292 | "Total Loss: 1.6928275525569916\n" 293 | ] 294 | }, 295 | { 296 | "data": { 297 | "application/vnd.jupyter.widget-view+json": { 298 | "model_id": "01cce4422d334d9483c0cb30a9687be1", 299 | "version_major": 2, 300 | "version_minor": 0 301 | }, 302 | "text/plain": [ 303 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 304 | ] 305 | }, 306 | "metadata": {}, 307 | "output_type": "display_data" 308 | }, 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "\n", 314 | "Total Loss: 1.7734732776880264\n" 315 | ] 316 | }, 317 | { 318 | "data": { 319 | "application/vnd.jupyter.widget-view+json": { 320 | "model_id": "190de91f86434c4fa6b3bcab60c76aec", 321 | "version_major": 2, 322 | "version_minor": 0 323 | }, 324 | "text/plain": [ 325 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 326 | ] 327 | }, 328 | "metadata": {}, 329 | "output_type": "display_data" 330 | }, 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "\n", 336 | "Total Loss: 1.870010793209076\n" 337 | ] 338 | }, 339 | { 340 | "data": { 341 | "application/vnd.jupyter.widget-view+json": { 342 | "model_id": "097adfc34a5246a49f59365d4ff5a8b6", 343 | "version_major": 2, 344 | "version_minor": 0 345 | }, 346 | "text/plain": [ 347 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 348 | ] 349 | }, 350 | "metadata": {}, 351 | "output_type": 
"display_data" 352 | }, 353 | { 354 | "name": "stdout", 355 | "output_type": "stream", 356 | "text": [ 357 | "\n", 358 | "Total Loss: 1.945670872926712\n" 359 | ] 360 | }, 361 | { 362 | "data": { 363 | "application/vnd.jupyter.widget-view+json": { 364 | "model_id": "036de6c71d8840af89e0eaa5b3390dfd", 365 | "version_major": 2, 366 | "version_minor": 0 367 | }, 368 | "text/plain": [ 369 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 370 | ] 371 | }, 372 | "metadata": {}, 373 | "output_type": "display_data" 374 | }, 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "\n", 380 | "Total Loss: 1.951790526509285\n" 381 | ] 382 | }, 383 | { 384 | "data": { 385 | "application/vnd.jupyter.widget-view+json": { 386 | "model_id": "2cc60928dd554736976eeea3f6a23d1d", 387 | "version_major": 2, 388 | "version_minor": 0 389 | }, 390 | "text/plain": [ 391 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 392 | ] 393 | }, 394 | "metadata": {}, 395 | "output_type": "display_data" 396 | }, 397 | { 398 | "name": "stdout", 399 | "output_type": "stream", 400 | "text": [ 401 | "\n", 402 | "Total Loss: 1.8777571022510529\n" 403 | ] 404 | }, 405 | { 406 | "data": { 407 | "application/vnd.jupyter.widget-view+json": { 408 | "model_id": "fa34d29435dd4cdaa9cc0496bdc445e8", 409 | "version_major": 2, 410 | "version_minor": 0 411 | }, 412 | "text/plain": [ 413 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 414 | ] 415 | }, 416 | "metadata": {}, 417 | "output_type": "display_data" 418 | }, 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "\n", 424 | "Total Loss: 1.7375290542840958\n" 425 | ] 426 | }, 427 | { 428 | "data": { 429 | "application/vnd.jupyter.widget-view+json": { 430 | "model_id": "110ed119eb7e493c9766c0fa4adc1bcc", 431 
| "version_major": 2, 432 | "version_minor": 0 433 | }, 434 | "text/plain": [ 435 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 436 | ] 437 | }, 438 | "metadata": {}, 439 | "output_type": "display_data" 440 | }, 441 | { 442 | "name": "stdout", 443 | "output_type": "stream", 444 | "text": [ 445 | "\n", 446 | "Total Loss: 1.566610112786293\n", 447 | "Saving Model\n" 448 | ] 449 | }, 450 | { 451 | "data": { 452 | "application/vnd.jupyter.widget-view+json": { 453 | "model_id": "6306f0446eb944149c52fbc3fb92b70a", 454 | "version_major": 2, 455 | "version_minor": 0 456 | }, 457 | "text/plain": [ 458 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 459 | ] 460 | }, 461 | "metadata": {}, 462 | "output_type": "display_data" 463 | }, 464 | { 465 | "name": "stdout", 466 | "output_type": "stream", 467 | "text": [ 468 | "\n", 469 | "Total Loss: 1.4196221083402634\n", 470 | "Saving Model\n" 471 | ] 472 | }, 473 | { 474 | "data": { 475 | "application/vnd.jupyter.widget-view+json": { 476 | "model_id": "f99c51121baa4d568909bf8824806826", 477 | "version_major": 2, 478 | "version_minor": 0 479 | }, 480 | "text/plain": [ 481 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 482 | ] 483 | }, 484 | "metadata": {}, 485 | "output_type": "display_data" 486 | }, 487 | { 488 | "name": "stdout", 489 | "output_type": "stream", 490 | "text": [ 491 | "\n", 492 | "Total Loss: 1.330006442964077\n", 493 | "Saving Model\n" 494 | ] 495 | }, 496 | { 497 | "data": { 498 | "application/vnd.jupyter.widget-view+json": { 499 | "model_id": "bb85e909174c4bba925a0ece2e905291", 500 | "version_major": 2, 501 | "version_minor": 0 502 | }, 503 | "text/plain": [ 504 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 505 | ] 506 | }, 507 | "metadata": 
{}, 508 | "output_type": "display_data" 509 | }, 510 | { 511 | "name": "stdout", 512 | "output_type": "stream", 513 | "text": [ 514 | "\n", 515 | "Total Loss: 1.3305239230394363\n" 516 | ] 517 | }, 518 | { 519 | "data": { 520 | "application/vnd.jupyter.widget-view+json": { 521 | "model_id": "ad569c2c7afe478e9a3a0c8855064303", 522 | "version_major": 2, 523 | "version_minor": 0 524 | }, 525 | "text/plain": [ 526 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 527 | ] 528 | }, 529 | "metadata": {}, 530 | "output_type": "display_data" 531 | }, 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "\n", 537 | "Total Loss: 1.408533088862896\n" 538 | ] 539 | }, 540 | { 541 | "data": { 542 | "application/vnd.jupyter.widget-view+json": { 543 | "model_id": "ee184925fa234590938232e565aa5c3d", 544 | "version_major": 2, 545 | "version_minor": 0 546 | }, 547 | "text/plain": [ 548 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 549 | ] 550 | }, 551 | "metadata": {}, 552 | "output_type": "display_data" 553 | }, 554 | { 555 | "name": "stdout", 556 | "output_type": "stream", 557 | "text": [ 558 | "\n", 559 | "Total Loss: 1.5383499562740326\n" 560 | ] 561 | }, 562 | { 563 | "data": { 564 | "application/vnd.jupyter.widget-view+json": { 565 | "model_id": "474f9578f45b4526b32801d9231b900d", 566 | "version_major": 2, 567 | "version_minor": 0 568 | }, 569 | "text/plain": [ 570 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 571 | ] 572 | }, 573 | "metadata": {}, 574 | "output_type": "display_data" 575 | }, 576 | { 577 | "name": "stdout", 578 | "output_type": "stream", 579 | "text": [ 580 | "\n", 581 | "Total Loss: 1.6761837750673294\n" 582 | ] 583 | }, 584 | { 585 | "data": { 586 | "application/vnd.jupyter.widget-view+json": { 587 | "model_id": 
"a1840089c5754e1cb6fd7de6ca82bf68", 588 | "version_major": 2, 589 | "version_minor": 0 590 | }, 591 | "text/plain": [ 592 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 593 | ] 594 | }, 595 | "metadata": {}, 596 | "output_type": "display_data" 597 | }, 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "\n", 603 | "Total Loss: 1.7656275629997253\n" 604 | ] 605 | }, 606 | { 607 | "data": { 608 | "application/vnd.jupyter.widget-view+json": { 609 | "model_id": "0df2fe5e4bb7491b9810ea11793d2e3c", 610 | "version_major": 2, 611 | "version_minor": 0 612 | }, 613 | "text/plain": [ 614 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 615 | ] 616 | }, 617 | "metadata": {}, 618 | "output_type": "display_data" 619 | }, 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "\n", 625 | "Total Loss: 1.8159170299768448\n" 626 | ] 627 | }, 628 | { 629 | "data": { 630 | "application/vnd.jupyter.widget-view+json": { 631 | "model_id": "9aaf667b0cac4541814445775b0ce861", 632 | "version_major": 2, 633 | "version_minor": 0 634 | }, 635 | "text/plain": [ 636 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 637 | ] 638 | }, 639 | "metadata": {}, 640 | "output_type": "display_data" 641 | }, 642 | { 643 | "name": "stdout", 644 | "output_type": "stream", 645 | "text": [ 646 | "\n", 647 | "Total Loss: 1.796944484114647\n" 648 | ] 649 | }, 650 | { 651 | "data": { 652 | "application/vnd.jupyter.widget-view+json": { 653 | "model_id": "8f05fc514b9a4b5d8f158def0093393d", 654 | "version_major": 2, 655 | "version_minor": 0 656 | }, 657 | "text/plain": [ 658 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 659 | ] 660 | }, 661 | "metadata": {}, 662 | "output_type": 
"display_data" 663 | }, 664 | { 665 | "name": "stdout", 666 | "output_type": "stream", 667 | "text": [ 668 | "\n", 669 | "Total Loss: 1.7244201302528381\n" 670 | ] 671 | }, 672 | { 673 | "data": { 674 | "application/vnd.jupyter.widget-view+json": { 675 | "model_id": "5a684449f8d6453eb3510be151c64c1f", 676 | "version_major": 2, 677 | "version_minor": 0 678 | }, 679 | "text/plain": [ 680 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 681 | ] 682 | }, 683 | "metadata": {}, 684 | "output_type": "display_data" 685 | }, 686 | { 687 | "name": "stdout", 688 | "output_type": "stream", 689 | "text": [ 690 | "\n", 691 | "Total Loss: 1.6041111201047897\n" 692 | ] 693 | }, 694 | { 695 | "data": { 696 | "application/vnd.jupyter.widget-view+json": { 697 | "model_id": "85e4a6ed4837476c87978c44e00f8b8e", 698 | "version_major": 2, 699 | "version_minor": 0 700 | }, 701 | "text/plain": [ 702 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 703 | ] 704 | }, 705 | "metadata": {}, 706 | "output_type": "display_data" 707 | }, 708 | { 709 | "name": "stdout", 710 | "output_type": "stream", 711 | "text": [ 712 | "\n", 713 | "Total Loss: 1.4716791659593582\n" 714 | ] 715 | }, 716 | { 717 | "data": { 718 | "application/vnd.jupyter.widget-view+json": { 719 | "model_id": "11b02936116342318584e0e80f0dc5e4", 720 | "version_major": 2, 721 | "version_minor": 0 722 | }, 723 | "text/plain": [ 724 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 725 | ] 726 | }, 727 | "metadata": {}, 728 | "output_type": "display_data" 729 | }, 730 | { 731 | "name": "stdout", 732 | "output_type": "stream", 733 | "text": [ 734 | "\n", 735 | "Total Loss: 1.345517821609974\n" 736 | ] 737 | }, 738 | { 739 | "data": { 740 | "application/vnd.jupyter.widget-view+json": { 741 | "model_id": "86deb926ec4e49fb92efcf9ef9bc4d82", 
742 | "version_major": 2, 743 | "version_minor": 0 744 | }, 745 | "text/plain": [ 746 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 747 | ] 748 | }, 749 | "metadata": {}, 750 | "output_type": "display_data" 751 | }, 752 | { 753 | "name": "stdout", 754 | "output_type": "stream", 755 | "text": [ 756 | "\n", 757 | "Total Loss: 1.256837397813797\n", 758 | "Saving Model\n" 759 | ] 760 | }, 761 | { 762 | "data": { 763 | "application/vnd.jupyter.widget-view+json": { 764 | "model_id": "239cecf957184cd8a1c35a1432c18e7a", 765 | "version_major": 2, 766 | "version_minor": 0 767 | }, 768 | "text/plain": [ 769 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 770 | ] 771 | }, 772 | "metadata": {}, 773 | "output_type": "display_data" 774 | }, 775 | { 776 | "name": "stdout", 777 | "output_type": "stream", 778 | "text": [ 779 | "\n", 780 | "Total Loss: 1.221520982682705\n", 781 | "Saving Model\n" 782 | ] 783 | }, 784 | { 785 | "data": { 786 | "application/vnd.jupyter.widget-view+json": { 787 | "model_id": "ef2e00b7def7428a92fd5564311b109d", 788 | "version_major": 2, 789 | "version_minor": 0 790 | }, 791 | "text/plain": [ 792 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 793 | ] 794 | }, 795 | "metadata": {}, 796 | "output_type": "display_data" 797 | }, 798 | { 799 | "name": "stdout", 800 | "output_type": "stream", 801 | "text": [ 802 | "\n", 803 | "Total Loss: 1.251066729426384\n" 804 | ] 805 | }, 806 | { 807 | "data": { 808 | "application/vnd.jupyter.widget-view+json": { 809 | "model_id": "4c2d40d4ede14ce5950a0b9c3bc5bb9f", 810 | "version_major": 2, 811 | "version_minor": 0 812 | }, 813 | "text/plain": [ 814 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 815 | ] 816 | }, 817 | "metadata": {}, 818 | 
"output_type": "display_data" 819 | }, 820 | { 821 | "name": "stdout", 822 | "output_type": "stream", 823 | "text": [ 824 | "\n", 825 | "Total Loss: 1.306217521429062\n" 826 | ] 827 | }, 828 | { 829 | "data": { 830 | "application/vnd.jupyter.widget-view+json": { 831 | "model_id": "3eaaabb1b0ae484fb19f96026437ef5c", 832 | "version_major": 2, 833 | "version_minor": 0 834 | }, 835 | "text/plain": [ 836 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 837 | ] 838 | }, 839 | "metadata": {}, 840 | "output_type": "display_data" 841 | }, 842 | { 843 | "name": "stdout", 844 | "output_type": "stream", 845 | "text": [ 846 | "\n", 847 | "Total Loss: 1.393511027097702\n" 848 | ] 849 | }, 850 | { 851 | "data": { 852 | "application/vnd.jupyter.widget-view+json": { 853 | "model_id": "9e831d1950124267bde89892dbfb1e38", 854 | "version_major": 2, 855 | "version_minor": 0 856 | }, 857 | "text/plain": [ 858 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 859 | ] 860 | }, 861 | "metadata": {}, 862 | "output_type": "display_data" 863 | }, 864 | { 865 | "name": "stdout", 866 | "output_type": "stream", 867 | "text": [ 868 | "\n", 869 | "Total Loss: 1.4632043689489365\n" 870 | ] 871 | }, 872 | { 873 | "data": { 874 | "application/vnd.jupyter.widget-view+json": { 875 | "model_id": "e384add40a6441f8801e30d43251e87c", 876 | "version_major": 2, 877 | "version_minor": 0 878 | }, 879 | "text/plain": [ 880 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 881 | ] 882 | }, 883 | "metadata": {}, 884 | "output_type": "display_data" 885 | }, 886 | { 887 | "name": "stdout", 888 | "output_type": "stream", 889 | "text": [ 890 | "\n", 891 | "Total Loss: 1.4949854537844658\n" 892 | ] 893 | }, 894 | { 895 | "data": { 896 | "application/vnd.jupyter.widget-view+json": { 897 | "model_id": 
"f95237f3b804431891893182386df041", 898 | "version_major": 2, 899 | "version_minor": 0 900 | }, 901 | "text/plain": [ 902 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 903 | ] 904 | }, 905 | "metadata": {}, 906 | "output_type": "display_data" 907 | }, 908 | { 909 | "name": "stdout", 910 | "output_type": "stream", 911 | "text": [ 912 | "\n", 913 | "Total Loss: 1.4919632151722908\n" 914 | ] 915 | }, 916 | { 917 | "data": { 918 | "application/vnd.jupyter.widget-view+json": { 919 | "model_id": "0e7e02c30bf140669743ee03a2d34e5d", 920 | "version_major": 2, 921 | "version_minor": 0 922 | }, 923 | "text/plain": [ 924 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 925 | ] 926 | }, 927 | "metadata": {}, 928 | "output_type": "display_data" 929 | }, 930 | { 931 | "name": "stdout", 932 | "output_type": "stream", 933 | "text": [ 934 | "\n", 935 | "Total Loss: 1.4521089643239975\n" 936 | ] 937 | }, 938 | { 939 | "data": { 940 | "application/vnd.jupyter.widget-view+json": { 941 | "model_id": "5e8777a8bf08467da4d42f917a9a7c9d", 942 | "version_major": 2, 943 | "version_minor": 0 944 | }, 945 | "text/plain": [ 946 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 947 | ] 948 | }, 949 | "metadata": {}, 950 | "output_type": "display_data" 951 | }, 952 | { 953 | "name": "stdout", 954 | "output_type": "stream", 955 | "text": [ 956 | "\n", 957 | "Total Loss: 1.3882318809628487\n" 958 | ] 959 | }, 960 | { 961 | "data": { 962 | "application/vnd.jupyter.widget-view+json": { 963 | "model_id": "055e5b81922746b392ce0f3c2d8ed41e", 964 | "version_major": 2, 965 | "version_minor": 0 966 | }, 967 | "text/plain": [ 968 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 969 | ] 970 | }, 971 | "metadata": {}, 972 | "output_type": 
"display_data" 973 | }, 974 | { 975 | "name": "stdout", 976 | "output_type": "stream", 977 | "text": [ 978 | "\n", 979 | "Total Loss: 1.321122132241726\n" 980 | ] 981 | }, 982 | { 983 | "data": { 984 | "application/vnd.jupyter.widget-view+json": { 985 | "model_id": "a317e09822ca45e5a00714a7640a4765", 986 | "version_major": 2, 987 | "version_minor": 0 988 | }, 989 | "text/plain": [ 990 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 991 | ] 992 | }, 993 | "metadata": {}, 994 | "output_type": "display_data" 995 | }, 996 | { 997 | "name": "stdout", 998 | "output_type": "stream", 999 | "text": [ 1000 | "\n", 1001 | "Total Loss: 1.2985627725720406\n" 1002 | ] 1003 | }, 1004 | { 1005 | "data": { 1006 | "application/vnd.jupyter.widget-view+json": { 1007 | "model_id": "2d6178b4faf0480685be79ccf1eeb72f", 1008 | "version_major": 2, 1009 | "version_minor": 0 1010 | }, 1011 | "text/plain": [ 1012 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1013 | ] 1014 | }, 1015 | "metadata": {}, 1016 | "output_type": "display_data" 1017 | }, 1018 | { 1019 | "name": "stdout", 1020 | "output_type": "stream", 1021 | "text": [ 1022 | "\n", 1023 | "Total Loss: 1.2938548177480698\n" 1024 | ] 1025 | }, 1026 | { 1027 | "data": { 1028 | "application/vnd.jupyter.widget-view+json": { 1029 | "model_id": "d5445ee8fb2c4b06bbeffb926f39c5d4", 1030 | "version_major": 2, 1031 | "version_minor": 0 1032 | }, 1033 | "text/plain": [ 1034 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1035 | ] 1036 | }, 1037 | "metadata": {}, 1038 | "output_type": "display_data" 1039 | }, 1040 | { 1041 | "name": "stdout", 1042 | "output_type": "stream", 1043 | "text": [ 1044 | "\n", 1045 | "Total Loss: 1.3346440717577934\n" 1046 | ] 1047 | }, 1048 | { 1049 | "data": { 1050 | "application/vnd.jupyter.widget-view+json": { 1051 
| "model_id": "970b7fd76a4c442ab17cf3814bdc41c1", 1052 | "version_major": 2, 1053 | "version_minor": 0 1054 | }, 1055 | "text/plain": [ 1056 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1057 | ] 1058 | }, 1059 | "metadata": {}, 1060 | "output_type": "display_data" 1061 | }, 1062 | { 1063 | "name": "stdout", 1064 | "output_type": "stream", 1065 | "text": [ 1066 | "\n", 1067 | "Total Loss: 1.4024389162659645\n" 1068 | ] 1069 | }, 1070 | { 1071 | "data": { 1072 | "application/vnd.jupyter.widget-view+json": { 1073 | "model_id": "a38096f085994872ad7f81e83fcbfe94", 1074 | "version_major": 2, 1075 | "version_minor": 0 1076 | }, 1077 | "text/plain": [ 1078 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1079 | ] 1080 | }, 1081 | "metadata": {}, 1082 | "output_type": "display_data" 1083 | }, 1084 | { 1085 | "name": "stdout", 1086 | "output_type": "stream", 1087 | "text": [ 1088 | "\n", 1089 | "Total Loss: 1.4722695127129555\n" 1090 | ] 1091 | }, 1092 | { 1093 | "data": { 1094 | "application/vnd.jupyter.widget-view+json": { 1095 | "model_id": "0eb2b03a707e4824a797bdc5efb793a7", 1096 | "version_major": 2, 1097 | "version_minor": 0 1098 | }, 1099 | "text/plain": [ 1100 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1101 | ] 1102 | }, 1103 | "metadata": {}, 1104 | "output_type": "display_data" 1105 | }, 1106 | { 1107 | "name": "stdout", 1108 | "output_type": "stream", 1109 | "text": [ 1110 | "\n", 1111 | "Total Loss: 1.5289373993873596\n" 1112 | ] 1113 | }, 1114 | { 1115 | "data": { 1116 | "application/vnd.jupyter.widget-view+json": { 1117 | "model_id": "6b538dc687994c11b2f4ebaa498326a4", 1118 | "version_major": 2, 1119 | "version_minor": 0 1120 | }, 1121 | "text/plain": [ 1122 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, 
style=ProgressStyle(descript…" 1123 | ] 1124 | }, 1125 | "metadata": {}, 1126 | "output_type": "display_data" 1127 | }, 1128 | { 1129 | "name": "stdout", 1130 | "output_type": "stream", 1131 | "text": [ 1132 | "\n", 1133 | "Total Loss: 1.5576649606227875\n" 1134 | ] 1135 | }, 1136 | { 1137 | "data": { 1138 | "application/vnd.jupyter.widget-view+json": { 1139 | "model_id": "2427645116d741c0915657f91fdb3ec5", 1140 | "version_major": 2, 1141 | "version_minor": 0 1142 | }, 1143 | "text/plain": [ 1144 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1145 | ] 1146 | }, 1147 | "metadata": {}, 1148 | "output_type": "display_data" 1149 | }, 1150 | { 1151 | "name": "stdout", 1152 | "output_type": "stream", 1153 | "text": [ 1154 | "\n", 1155 | "Total Loss: 1.5467615649104118\n" 1156 | ] 1157 | }, 1158 | { 1159 | "data": { 1160 | "application/vnd.jupyter.widget-view+json": { 1161 | "model_id": "e5f3a75fc8ec4f2797bb0a8d8e132dda", 1162 | "version_major": 2, 1163 | "version_minor": 0 1164 | }, 1165 | "text/plain": [ 1166 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1167 | ] 1168 | }, 1169 | "metadata": {}, 1170 | "output_type": "display_data" 1171 | }, 1172 | { 1173 | "name": "stdout", 1174 | "output_type": "stream", 1175 | "text": [ 1176 | "\n", 1177 | "Total Loss: 1.495627447962761\n" 1178 | ] 1179 | }, 1180 | { 1181 | "data": { 1182 | "application/vnd.jupyter.widget-view+json": { 1183 | "model_id": "a177cfcbf9ae497fbd8206e0a6c86c9c", 1184 | "version_major": 2, 1185 | "version_minor": 0 1186 | }, 1187 | "text/plain": [ 1188 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1189 | ] 1190 | }, 1191 | "metadata": {}, 1192 | "output_type": "display_data" 1193 | }, 1194 | { 1195 | "name": "stdout", 1196 | "output_type": "stream", 1197 | "text": [ 1198 | "\n", 1199 | "Total Loss: 
1.4099234193563461\n" 1200 | ] 1201 | }, 1202 | { 1203 | "data": { 1204 | "application/vnd.jupyter.widget-view+json": { 1205 | "model_id": "c3d6dd4755ac4bfb9dc2e5b27665d834", 1206 | "version_major": 2, 1207 | "version_minor": 0 1208 | }, 1209 | "text/plain": [ 1210 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1211 | ] 1212 | }, 1213 | "metadata": {}, 1214 | "output_type": "display_data" 1215 | }, 1216 | { 1217 | "name": "stdout", 1218 | "output_type": "stream", 1219 | "text": [ 1220 | "\n", 1221 | "Total Loss: 1.2934547439217567\n" 1222 | ] 1223 | }, 1224 | { 1225 | "data": { 1226 | "application/vnd.jupyter.widget-view+json": { 1227 | "model_id": "f38a092f8721408fa6ff68f7b93d1a53", 1228 | "version_major": 2, 1229 | "version_minor": 0 1230 | }, 1231 | "text/plain": [ 1232 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1233 | ] 1234 | }, 1235 | "metadata": {}, 1236 | "output_type": "display_data" 1237 | }, 1238 | { 1239 | "name": "stdout", 1240 | "output_type": "stream", 1241 | "text": [ 1242 | "\n", 1243 | "Total Loss: 1.1789302676916122\n", 1244 | "Saving Model\n" 1245 | ] 1246 | }, 1247 | { 1248 | "data": { 1249 | "application/vnd.jupyter.widget-view+json": { 1250 | "model_id": "bae58ab36ebb41ad95f416bc614f9e34", 1251 | "version_major": 2, 1252 | "version_minor": 0 1253 | }, 1254 | "text/plain": [ 1255 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1256 | ] 1257 | }, 1258 | "metadata": {}, 1259 | "output_type": "display_data" 1260 | }, 1261 | { 1262 | "name": "stdout", 1263 | "output_type": "stream", 1264 | "text": [ 1265 | "\n", 1266 | "Total Loss: 1.0810576602816582\n", 1267 | "Saving Model\n" 1268 | ] 1269 | }, 1270 | { 1271 | "data": { 1272 | "application/vnd.jupyter.widget-view+json": { 1273 | "model_id": "5d31379849a342579b345e37c880344c", 1274 | 
"version_major": 2, 1275 | "version_minor": 0 1276 | }, 1277 | "text/plain": [ 1278 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1279 | ] 1280 | }, 1281 | "metadata": {}, 1282 | "output_type": "display_data" 1283 | }, 1284 | { 1285 | "name": "stdout", 1286 | "output_type": "stream", 1287 | "text": [ 1288 | "\n", 1289 | "Total Loss: 1.0233119428157806\n", 1290 | "Saving Model\n" 1291 | ] 1292 | }, 1293 | { 1294 | "data": { 1295 | "application/vnd.jupyter.widget-view+json": { 1296 | "model_id": "a49518b0289245f6b7e26079a50bbe70", 1297 | "version_major": 2, 1298 | "version_minor": 0 1299 | }, 1300 | "text/plain": [ 1301 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1302 | ] 1303 | }, 1304 | "metadata": {}, 1305 | "output_type": "display_data" 1306 | }, 1307 | { 1308 | "name": "stdout", 1309 | "output_type": "stream", 1310 | "text": [ 1311 | "\n", 1312 | "Total Loss: 1.0005386620759964\n", 1313 | "Saving Model\n" 1314 | ] 1315 | }, 1316 | { 1317 | "data": { 1318 | "application/vnd.jupyter.widget-view+json": { 1319 | "model_id": "f2ce1ca6019845aaae5d46b3ee4455ba", 1320 | "version_major": 2, 1321 | "version_minor": 0 1322 | }, 1323 | "text/plain": [ 1324 | "HBox(children=(FloatProgress(value=0.0, description='Batch Progress: ', max=10.0, style=ProgressStyle(descript…" 1325 | ] 1326 | }, 1327 | "metadata": {}, 1328 | "output_type": "display_data" 1329 | }, 1330 | { 1331 | "name": "stdout", 1332 | "output_type": "stream", 1333 | "text": [ 1334 | "\n", 1335 | "Total Loss: 1.0272977575659752\n", 1336 | "\n" 1337 | ] 1338 | } 1339 | ], 1340 | "source": [ 1341 | "for epoch in tqdm(range(EPOCHS), desc=\"Epoch Progress: \"):\n", 1342 | " # Losses and Loss Weights\n", 1343 | " # Loss_final = W1*L1 + W2*L2 + W3*L3\n", 1344 | " total_loss = 0\n", 1345 | " loss_final = 0\n", 1346 | " loss1A = 0 # Pixel-wise binary cross entropy\n", 1347 | " 
loss1B = 0\n", 1348 | " weight1 = 0\n", 1349 | " loss2 = 0 # Triplet loss\n", 1350 | " weight2 = 0\n", 1351 | " loss3 = 0 # Cross Entropy\n", 1352 | " weight3 = 0\n", 1353 | " \n", 1354 | " #Statistics\n", 1355 | " predictions_correct = 0\n", 1356 | " predictions_total = 0\n", 1357 | " background_percent_correct = 0\n", 1358 | " \n", 1359 | " loader = tqdm(enumerate(train_loader), desc=\"Batch Progress: \", total=(train_size/BATCH_SIZE))\n", 1360 | " \n", 1361 | " model.train()\n", 1362 | " \n", 1363 | " time_start = time.time()\n", 1364 | " \n", 1365 | " for i, sample in loader:\n", 1366 | " # Image Tensors\n", 1367 | " imageA = sample[\"image\"][0]\n", 1368 | " imageB = sample[\"image\"][1]\n", 1369 | " # Mask Tensors\n", 1370 | " maskA = sample[\"mask\"][0].float().to(device)\n", 1371 | " maskB = sample[\"mask\"][1].float().to(device)\n", 1372 | " # Labels\n", 1373 | " labelA = sample[\"label\"][0]\n", 1374 | " labelB = sample[\"label\"][1]\n", 1375 | " \n", 1376 | " # Normalizing images before sending to the model\n", 1377 | " # Forming into batch-shape for processing\n", 1378 | " norm_imageA = F.normalize(imageA,\n", 1379 | " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).unsqueeze(0)\n", 1380 | " norm_imageB = F.normalize(imageB, \n", 1381 | " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).unsqueeze(0)\n", 1382 | " \n", 1383 | " norm_imageA = norm_imageA.to(device)\n", 1384 | " norm_imageB = norm_imageB.to(device)\n", 1385 | " \n", 1386 | " # Network Run\n", 1387 | " pmapA, pmapB, vectorA, vectorB, decision = model(norm_imageA, norm_imageB)\n", 1388 | " \n", 1389 | " # Obtaining co-segmentation masks\n", 1390 | " # based off of the decision by the decision net\n", 1391 | " # During testing, the decision net's value\n", 1392 | " # would be thresheld before multiplying\n", 1393 | " pred_maskA = pmapA*decision\n", 1394 | " pred_maskB = pmapB*decision\n", 1395 | " \n", 1396 | " # Loss Calculations/Evaluations\n", 1397 | " # Configuring loss 
weights depending on sample\n", 1398 | " # Also deciding if we produce a groundtruth mask\n", 1399 | " truth = None\n", 1400 | " pairwise = None\n", 1401 | " if labelA == labelB: # Positive Sample found\n", 1402 | " # Weighting loss evenly\n", 1403 | " w1=w2=w3 = 0.33\n", 1404 | " \n", 1405 | " truth = 1\n", 1406 | " pairwise = 1\n", 1407 | " else: # Negative sample\n", 1408 | " # Zero weight: Loss1 does not contribute to backpropagation\n", 1409 | " w1 = 0\n", 1410 | " w2=w3 = 0.5\n", 1411 | " \n", 1412 | " # Create a null mask from the groundtruths\n", 1413 | " maskA = maskA * 0\n", 1414 | " maskB = maskB * 0\n", 1415 | " \n", 1416 | " truth = 0\n", 1417 | " pairwise = -1\n", 1418 | " \n", 1419 | " # Loss 1\n", 1420 | " # Pixel-wise Binary Cross Entropy Loss\n", 1421 | " loss1A = criterion_bce(pred_maskA, maskA.unsqueeze(0).unsqueeze(0))\n", 1422 | " loss1B = criterion_bce(pred_maskB, maskB.unsqueeze(0).unsqueeze(0))\n", 1423 | " \n", 1424 | " # Loss 2\n", 1425 | " # Pairwise embedding loss: target 1 = same class, -1 = different class\n", 1426 | " pairwise = torch.tensor(pairwise).to(device)\n", 1427 | " loss2 = criterion_cel(vectorA.unsqueeze(0), vectorB.unsqueeze(0), pairwise)\n", 1428 | " \n", 1429 | " # Loss 3\n", 1430 | " # Binary Cross Entropy Loss\n", 1431 | " truth = torch.tensor(truth).float().unsqueeze(0).to(device)\n", 1432 | " loss3 = criterion_bce(decision, truth)\n", 1433 | " \n", 1434 | " loss_final = w1*(loss1A + loss1B) + w2*loss2 + w3*loss3\n", 1435 | " optimizer.zero_grad() # clear gradients left over from the previous step\n", 1436 | " loss_final.backward()\n", 1437 | " optimizer.step()\n", 1438 | " \n", 1439 | " total_loss = total_loss + loss_final.item()\n", 1440 | " \n", 1441 | " print(\"Total Loss: \" + str(total_loss))\n", 1442 | " \n", 1443 | " if total_loss < prev_loss and SAVE_WEIGHTS:\n", 1444 | " # Check for dir, create if it doesn't exist\n", 1445 | " if not os.path.exists(WEIGHTS_PATH):\n", 1446 | " os.makedirs(WEIGHTS_PATH)\n", 1447 | " \n", 1448 | " prev_loss = total_loss\n", 1449 | " print(\"Saving Model\")\n", 1450 | " torch.save(model.state_dict(),
os.path.join(WEIGHTS_PATH, \"CoSegNet_VGG16.path\"))\n", 1451 | " \n", 1452 | " \n", 1453 | " time_total = round(time.time() - time_start, 2)" 1454 | ] 1455 | }, 1456 | { 1457 | "cell_type": "markdown", 1458 | "metadata": {}, 1459 | "source": [ 1460 | "## Testing" 1461 | ] 1462 | }, 1463 | { 1464 | "cell_type": "code", 1465 | "execution_count": 112, 1466 | "metadata": {}, 1467 | "outputs": [ 1468 | { 1469 | "name": "stdout", 1470 | "output_type": "stream", 1471 | "text": [ 1472 | "tensor([1.], device='cuda:0', grad_fn=)\n" 1473 | ] 1474 | }, 1475 | { 1476 | "data": { 1477 | "image/png": "\n", 1478 | "text/plain": [ 1479 | "
" 1480 | ] 1481 | }, 1482 | "metadata": { 1483 | "needs_background": "light" 1484 | }, 1485 | "output_type": "display_data" 1486 | } 1487 | ], 1488 | "source": [ 1489 | "sample = next(iter(test_loader))\n", 1490 | "\n", 1491 | "imageA = sample[\"image\"][0].unsqueeze(0).to(device)\n", 1492 | "imageB = sample[\"image\"][1].unsqueeze(0).to(device)\n", 1493 | "\n", 1494 | "model.eval()\n", 1495 | "pmapA, pmapB, vectorA, vectorB, decision = model(imageA, imageB)\n", 1496 | "\n", 1497 | "mapA = F.to_pil_image(pmapA.detach().cpu().squeeze())\n", 1498 | "\n", 1499 | "imageA_cpu = F.to_pil_image(imageA.cpu().squeeze())\n", 1500 | "plt.imshow(imageA_cpu)\n", 1501 | "plt.imshow(mapA, cmap=\"jet\", alpha=0.4)\n", 1502 | "plt.plot()\n", 1503 | "\n", 1504 | "print(decision)\n" 1505 | ] 1506 | } 1507 | ], 1508 | "metadata": { 1509 | "kernelspec": { 1510 | "display_name": "Python 3", 1511 | "language": "python", 1512 | "name": "python3" 1513 | }, 1514 | "language_info": { 1515 | "codemirror_mode": { 1516 | "name": "ipython", 1517 | "version": 3 1518 | }, 1519 | "file_extension": ".py", 1520 | "mimetype": "text/x-python", 1521 | "name": "python", 1522 | "nbconvert_exporter": "python", 1523 | "pygments_lexer": "ipython3", 1524 | "version": "3.7.4" 1525 | } 1526 | }, 1527 | "nbformat": 4, 1528 | "nbformat_minor": 2 1529 | } 1530 | --------------------------------------------------------------------------------