├── DeepConnection.ipynb └── README.md /DeepConnection.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"DeepConnection.ipynb","version":"0.3.2","provenance":[{"file_id":"1Xf5jsiDRNdf_MqKMEfIWab7kA9aliztL","timestamp":1559385201954},{"file_id":"1CzdTemebKULND1prnWj2CLkpIRFwcMJP","timestamp":1557910039467}],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"kpGtTWjhVcdo","colab_type":"code","outputId":"c4351d14-6c24-4ac2-ffcb-cc27c37ae526","executionInfo":{"status":"ok","timestamp":1559115624099,"user_tz":-120,"elapsed":1073,"user":{"displayName":"Daniel Bojar","photoUrl":"","userId":"10339697633531698497"}},"colab":{"base_uri":"https://localhost:8080/","height":36}},"source":["import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","from torch.optim import lr_scheduler\n","import numpy as np\n","import torchvision\n","from torchvision import models, transforms, datasets\n","import matplotlib.pyplot as plt\n","import time\n","import os\n","import copy\n","plt.ion()\n","from google.colab import drive\n","drive.mount('/content/drive')\n","from PIL import Image"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"BwPUa6zspUIo","colab_type":"code","colab":{}},"source":["!pip install git+https://github.com/aleju/imgaug\n","from imgaug import augmenters as iaa\n","import imgaug as ia"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lSQshPi089fH","colab_type":"text"},"source":["# Data Augmentation\n","- apply transforms to flip, crop, blur, scale, rotate, shear images and change hue / saturation"]},{"cell_type":"code","metadata":{"id":"_eIaJ657pUlC","colab_type":"code","colab":{}},"source":["class ImgAugTransform:\n"," def __init__(self):\n"," self.aug = iaa.Sequential([\n"," iaa.Fliplr(0.5),\n"," iaa.Crop(percent=(0, 0.1)),\n"," iaa.Sometimes(0.5, iaa.GaussianBlur(sigma = (0, 0.5))),\n"," iaa.AdditiveGaussianNoise(loc = 0, scale = (0.0, 0.05*255), per_channel = 0.5),\n"," iaa.Multiply((0.8, 1.2), per_channel = 0.2),\n"," iaa.Affine(\n"," scale={\"x\": (0.8, 1.2), \"y\": (0.8, 1.2)},\n"," translate_percent={\"x\": (-0.2, 0.2), \"y\": (-0.2, 0.2)},\n"," rotate=(-25, 25),\n"," shear=(-8, 8)\n"," ),\n"," iaa.AddToHueAndSaturation(value = (-10, 10), per_channel = True),\n"," iaa.ContrastNormalization((0.75, 1.5))\n"," ], random_order = True)\n"," \n"," def __call__(self, img):\n"," img = np.array(img)\n"," img = self.aug.augment_image(img)\n"," return Image.fromarray(img)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kBH6Zepf9Rkp","colab_type":"text"},"source":["# Building Dataloaders\n","- transforms + resizing to 512 (no conceptual limit; can be increased with more memory)\n","- center-cropping didn't improve results but this may depend on the images\n","- folder structure should be train / val and then folders with the class names"]},{"cell_type":"code","metadata":{"id":"4VnNC-OpWh69","colab_type":"code","colab":{}},"source":["data_transforms = {\n"," 'train':transforms.Compose([\n"," transforms.Resize((512, 512)),\n"," ImgAugTransform(),\n"," transforms.ToTensor(),\n"," transforms.Normalize([0.485, 0.456, 0.406], 
[0.229, 0.224, 0.225])\n"," ]),\n"," 'val':transforms.Compose([\n"," transforms.Resize((512, 512)),\n"," transforms.ToTensor(),\n"," transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n"," ]),\n","}\n","\n","## you may need to change this to your data directory\n","data_dir = 'drive/My Drive/couples'\n","\n","image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n"," data_transforms[x]) for x in ['train', 'val']}\n","dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],\n"," shuffle = True, batch_size = 48,\n"," num_workers = 4) for x in ['train', 'val']}\n","\n","dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n","class_names = image_datasets['train'].classes\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BCEz5WB_-fFs","colab_type":"text"},"source":["# Visualizing Batch\n","- takes one batch (with data augmentation) from the dataloader and visualizes it"]},{"cell_type":"code","metadata":{"id":"A1_Gzaz3ZPjQ","colab_type":"code","colab":{}},"source":["def imshow(inp, title = None):\n"," inp = inp.numpy().transpose((1, 2, 0))\n"," mean = np.array([0.485, 0.456, 0.406])\n"," std = np.array([0.229, 0.224, 0.225])\n"," inp = std*inp + mean\n"," inp = np.clip(inp, 0, 1)\n"," plt.imshow(inp)\n"," if title is not None:\n"," plt.title(title)\n"," plt.pause(0.001)\n"," \n","inputs,classes = next(iter(dataloaders['train']))\n","out = torchvision.utils.make_grid(inputs)\n","imshow(out, title = [class_names[x] for x in classes])"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RlpW8QjZ-1KF","colab_type":"text"},"source":["# Mixed Precision Training\n","- uses less memory, is faster and leads to slightly better generalization"]},{"cell_type":"code","metadata":{"id":"C2bCJKEYs1gC","colab_type":"code","colab":{}},"source":["!pip install git+https://github.com/NVIDIA/apex\n","from apex import amp"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"raX2nXgO--fW","colab_type":"text"},"source":["# Early Stopping\n","- stops training if validation loss did not improve for the last three epochs\n","- regularization technique"]},{"cell_type":"code","metadata":{"id":"wqI4DGcoIr8l","colab_type":"code","colab":{}},"source":["## https://github.com/Bjarten/early-stopping-pytorch\n","\n","class EarlyStopping:\n"," \"\"\"Early stops the training if validation loss doesn't improve after a given patience.\"\"\"\n"," def __init__(self, patience = 7, verbose = False):\n"," \"\"\"\n"," Args:\n"," patience (int): How long to wait after last time validation loss improved.\n"," Default: 7\n"," verbose (bool): If True, prints a message for each validation loss improvement. 
\n"," Default: False\n"," \"\"\"\n"," self.patience = patience\n"," self.verbose = verbose\n"," self.counter = 0\n"," self.best_score = None\n"," self.early_stop = False\n"," self.val_loss_min = 0\n","\n"," def __call__(self, val_loss, model):\n","\n"," score = -val_loss\n","\n"," if self.best_score is None:\n"," self.best_score = score\n"," self.save_checkpoint(val_loss, model)\n"," elif score < self.best_score:\n"," self.counter += 1\n"," print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n"," if self.counter >= self.patience:\n"," self.early_stop = True\n"," else:\n"," self.best_score = score\n"," self.save_checkpoint(val_loss, model)\n"," self.counter = 0\n","\n"," def save_checkpoint(self, val_loss, model):\n"," '''Saves model when validation loss decrease.'''\n"," if self.verbose:\n"," print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')\n"," #torch.save(model.state_dict(), 'drive/My Drive/checkpoint.pt')\n"," self.val_loss_min = val_loss\n"," \n","early_stopping = EarlyStopping(patience=3, verbose=True)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xaaTzDnG_VgV","colab_type":"text"},"source":["# Training Loop\n","- keeps track of loss and accuracy\n","- saves the model with the highest validation accuracy"]},{"cell_type":"code","metadata":{"id":"Q0ITzH-VaLbQ","colab_type":"code","colab":{}},"source":["def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n"," since = time.time()\n"," best_model_wts = copy.deepcopy(model.state_dict())\n"," best_acc = 0.0\n"," \n"," for epoch in range(num_epochs):\n"," print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n"," print('-'*10)\n"," \n"," for phase in ['train', 'val']:\n"," if phase == 'train':\n"," scheduler.step()\n"," model.train()\n"," else:\n"," model.eval()\n"," \n"," running_loss = 0.0\n"," running_corrects = 0\n"," for inputs, labels in dataloaders[phase]:\n"," inputs=inputs.to(device)\n"," labels=labels.to(device)\n"," optimizer.zero_grad()\n"," \n"," with torch.set_grad_enabled(phase == 'train'):\n"," outputs = model(inputs)\n"," _, preds = torch.max(outputs, 1)\n"," loss = criterion(outputs,labels)\n"," \n"," if phase == 'train':\n"," with amp.scale_loss(loss, optimizer) as scaled_loss:\n"," scaled_loss.backward()\n"," optimizer.step()\n"," \n"," running_loss += loss.item()*inputs.size(0)\n"," running_corrects += torch.sum(preds == labels.data)\n"," \n"," epoch_loss = running_loss / dataset_sizes[phase]\n"," epoch_acc = running_corrects.double() / dataset_sizes[phase]\n"," print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n"," phase, epoch_loss, epoch_acc))\n"," \n"," if phase == 'val' and epoch_acc >= best_acc:\n"," best_acc = epoch_acc\n"," best_model_wts = copy.deepcopy(model.state_dict())\n"," if phase == 'val':\n"," early_stopping(epoch_loss, model)\n"," \n"," if early_stopping.early_stop:\n"," print(\"Early stopping\")\n"," break\n"," print()\n"," \n"," time_elapsed = time.time() - since\n"," print('Training complete in {:.0f}m {:.0f}s'.format(\n"," time_elapsed // 60, time_elapsed % 60))\n"," print('Best val Acc: {:4f}'.format(best_acc))\n"," model.load_state_dict(best_model_wts)\n"," return model"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"43JoGXgVA1SJ","colab_type":"text"},"source":["# Visualize Model Predictions\n","- grabs a few validation images and displays them together with their predicted 
class"]},{"cell_type":"code","metadata":{"id":"gJJcEF3VddwU","colab_type":"code","colab":{}},"source":["def visualize_model(model, num_images = 6):\n"," was_training = model.training\n"," model.eval()\n"," images_so_far = 0\n"," fig = plt.figure()\n"," \n"," with torch.no_grad():\n"," for i, (inputs, labels) in enumerate(dataloaders['val']):\n"," inputs = inputs.to(device)\n"," labels = labels.to(device)\n"," outputs = model(inputs)\n"," _, preds = torch.max(outputs, 1)\n"," \n"," for j in range(inputs.size()[0]):\n"," images_so_far += 1\n"," ax = plt.subplot(num_images // 2, 2, images_so_far)\n"," ax.axis('off')\n"," ax.set_title('predicted: {}'.format(class_names[preds[j]]))\n"," imshow(inputs.cpu().data[j])\n"," if images_so_far == num_images:\n"," model.train(mode = was_training)\n"," return model.train(mode = was_training)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"A2vlD5-7BXol","colab_type":"text"},"source":["# Spatial Pyramid Pooling Layer & Power Mean Transformation\n","- creates fixed sized representation from convolutional outputs of variable sized images\n","- uses different filters to represent bins with different features"]},{"cell_type":"code","metadata":{"id":"sWJvOR5REMex","colab_type":"code","colab":{}},"source":["## SPP from https://github.com/yueruchen/sppnet-pytorch/blob/master/spp_layer.py\n","\n","import math\n","\n","def spatial_pyramid_pool(previous_conv, num_sample, previous_conv_size, out_pool_size):\n"," '''\n"," previous_conv: a tensor vector of previous convolution layer\n"," num_sample: an int number of image in the batch\n"," previous_conv_size: an int vector [height, width] of the matrix features size of previous convolution layer\n"," out_pool_size: a int vector of expected output size of max pooling layer\n"," \n"," returns: a tensor vector with shape [1 x n] is the concentration of multi-level pooling\n"," ''' \n"," \n"," for i in range(len(out_pool_size)):\n"," h_wid = int(math.ceil(previous_conv_size[0] / out_pool_size[i]))\n"," w_wid = int(math.ceil(previous_conv_size[1] / out_pool_size[i]))\n"," h_pad = (h_wid*out_pool_size[i] - previous_conv_size[0] + 1)//2\n"," w_pad = (w_wid*out_pool_size[i] - previous_conv_size[1] + 1)//2\n"," \n"," maxpool = nn.MaxPool2d((h_wid, w_wid), stride = (h_wid, w_wid), padding = (h_pad, w_pad))\n"," #avgpool = nn.AvgPool2d((h_wid, w_wid), stride = (h_wid, w_wid), padding = (h_pad, w_pad))\n"," x = maxpool(previous_conv)\n"," ## tried to also concat average pooling but did not increase model performance, maybe dependent on application\n"," #y = avgpool(previous_conv)\n"," #z = torch.cat((x, y), dim = -1)\n"," \n"," if(i == 0):\n"," spp = x.view(num_sample, -1)\n"," else:\n"," spp = torch.cat((spp, x.view(num_sample, -1)), 1)\n"," \n"," return spp\n","\n","class PMT(nn.Module):\n"," def __init__(self):\n"," super(PMT, self).__init__()\n"," \n"," def forward(self, x):\n"," ## tried to apply PMT prior to first convolution, did not increase model performance\n"," #x_1 = torch.sign(x)*torch.log(1 + abs(x))\n"," #x_2 = torch.sign(x)*(torch.log(1 + abs(x)))**2\n"," #x = torch.cat((x, x_1, x_2), dim = 3)\n"," return x\n"," \n","class SPP(nn.Module):\n"," def __init__(self):\n"," super(SPP, self).__init__()\n"," \n"," ## features incoming from ResNet-34 (after SPP/PMT)\n"," self.lin1 = nn.Linear(2*43520, 100)\n"," \n"," self.relu = nn.ReLU()\n"," self.bn1 = nn.BatchNorm1d(100)\n"," self.dp1 = nn.Dropout(0.5)\n"," self.lin2 = nn.Linear(100, 2)\n"," \n"," def forward(self, x):\n"," # SPP\n"," x = 
spatial_pyramid_pool(x, x.shape[0], [x.shape[2], x.shape[3]], [8, 4, 2, 1])\n"," \n"," # PMT\n"," x_1 = torch.sign(x)*torch.log(1 + abs(x))\n"," x_2 = torch.sign(x)*(torch.log(1 + abs(x)))**2\n"," x = torch.cat((x_1, x_2), dim = 1)\n"," \n"," # fully connected classification part\n"," x = self.lin1(x)\n"," x = self.bn1(self.relu(x))\n"," \n"," #1\n"," x1 = self.lin2(self.dp1(x))\n"," #2\n"," x2 = self.lin2(self.dp1(x))\n"," #3\n"," x3 = self.lin2(self.dp1(x))\n"," #4\n"," x4 = self.lin2(self.dp1(x))\n"," #5\n"," x5 = self.lin2(self.dp1(x))\n"," #6\n"," x6 = self.lin2(self.dp1(x))\n"," #7\n"," x7 = self.lin2(self.dp1(x))\n"," #8\n"," x8 = self.lin2(self.dp1(x))\n"," \n"," x = torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0)\n"," \n"," return x"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4iDurpv2CAJT","colab_type":"text"},"source":["# Training the DeepConnection Model\n","- best results with a pretrained ResNet-34 model\n","- exchange the AdaptiveAvgPool layer with a spatial pyramid pooling layer\n","- add power mean transformation after SPP\n","- add some linear layers (initialized with Xavier initialization) for classification"]},{"cell_type":"code","metadata":{"id":"bKTpuVFsemuM","colab_type":"code","colab":{}},"source":["# initialization for linear layers\n","def init_weights(m):\n"," if type(m) == nn.Linear:\n"," torch.nn.init.xavier_uniform_(m.weight)\n"," m.bias.data.fill_(0.)\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","model = models.resnet34(pretrained = True)\n","\n","# freeze ResNet-34 layers\n","for param in model.parameters():\n"," param.require_grad = False \n"," \n","\n","# DeepConnection\n","class PMT_SPP(nn.Module):\n"," def __init__(self):\n"," super(PMT_SPP, self).__init__()\n"," self.fc = PMT()\n"," self.convs = nn.Sequential(*list(model.children())[:-2])\n"," self.fc2 = SPP()\n"," \n"," def forward(self, x):\n"," x = self.fc(x)\n"," x = self.convs(x)\n"," x = self.fc2(x)\n"," \n"," return x\n"," \n"," \n","# instantiate model, put on GPU and initialize linear layers \n","deep_connection = PMT_SPP().cuda()\n","deep_connection.fc2.apply(init_weights)\n"," \n","criterion = nn.CrossEntropyLoss()\n","optimizer_ft = optim.Adam(deep_connection.parameters(),lr = 0.0001, weight_decay = 1e-1)\n","scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer_ft, 15)\n","model_ft, optimizer_ft = amp.initialize(deep_connection, optimizer_ft, opt_level = \"O1\")\n","\n","model_ft = train_model(model_ft, criterion, optimizer_ft, scheduler,\n"," num_epochs = 16)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0o0I-jx7oV3x","colab_type":"text"},"source":["# Hook Function to get Inputs / Outputs of Intermediate Layers\n","- exchange 'input' to 'output.detach()' to save the output of intermediate layers"]},{"cell_type":"code","metadata":{"id":"KN5CxXsQT3to","colab_type":"code","colab":{}},"source":["activation = {}\n","\n","def get_activation(name):\n"," def hook(model, input, output):\n"," activation[name] = input\n"," return hook\n","\n","## save all inputs to layers subsequent to convolutional layers\n","for name, layer in model_ft.fc2.named_modules():\n"," layer.register_forward_hook(get_activation(name))"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"koKlsRTrC8q1","colab_type":"text"},"source":["# Test Model on a Single New Image\n","- hooks will store all designated inputs / output once you run 
this"]},{"cell_type":"code","metadata":{"id":"GNPT7vzRwEJs","colab_type":"code","colab":{}},"source":["# need to change that to your directory\n","example = 'drive/My Drive/couples/couple-2436263_960_720.jpg'\n","\n","imsize = (512,512)\n","loader = transforms.Compose([transforms.Scale(imsize), \n"," transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])\n","\n","def image_loader(image_name):\n"," \"\"\"load image, returns cuda tensor\"\"\"\n"," image = Image.open(image_name)\n"," image = loader(image).float()\n"," image = image.unsqueeze(0)\n"," return image.cuda()\n"," \n","example_img = image_loader(example)\n","out = model_ft(example_img)\n","_, preds = torch.max(out, 1)\n","\n","print('predicted: {}'.format(class_names[preds]))\n","imshow(example_img.cpu().data[0])"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"9waeB9pxpRI6","colab_type":"text"},"source":["# Create State of Activation Plots\n","- plots class-specific weights of final linear layer (100, 2) as scatter plot, which are constant for a trained model\n","- colors points according to the input value of a given image stored in the hooks\n","- multiplication of weights and input values (with subsequent summation) yields class probabilities"]},{"cell_type":"code","metadata":{"id":"Zz8kzjhWUN45","colab_type":"code","colab":{}},"source":["ws2 = model_ft.fc2.lin2.weight.data.detach().cpu().numpy()\n","\n","plt.style.use('ggplot')\n","\n","cm = plt.cm.get_cmap('YlOrRd')\n","\n","sc = plt.scatter(ws2.T[:, 0], ws2.T[:, 1], c = activation['lin2'][0].cpu().detach().numpy()[0], vmax = 1.5, vmin = -1,\n"," cmap = cm, alpha = 0.8, s = 30, edgecolor = 'gray')\n","plt.xlabel('Happy Weight')\n","plt.ylabel('Unhappy Weight')\n","plt.title('Activation of Neurons in Last Layer')\n","cb = plt.colorbar(sc)\n","cb.set_label('Incoming Activation Values')\n","\n","plt.show()"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_dxSeDlDqF97","colab_type":"text"},"source":["# Visualizes Exemplary Model Predictions and Model Architecture"]},{"cell_type":"code","metadata":{"id":"qx-0WBMMjiE7","colab_type":"code","colab":{}},"source":["visualize_model(model_ft)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"R0UV5VBVDlG0","colab_type":"text"},"source":["# Prepare Model for GradCAM\n","- need to separate convolutional part from rest (because we want the gradients at the end of the convolutional part)"]},{"cell_type":"code","metadata":{"id":"oa_CLTQ0v2P7","colab_type":"code","colab":{}},"source":["modulelist_conv = nn.Sequential(*list(model_ft.children())[:-1]).cuda().half()\n","modulelist_fc = nn.Sequential(*list(model_ft.children())[-1:]).cuda().half()\n","modulelist_fc"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GkT24R3WDr_m","colab_type":"text"},"source":["# GradCAM\n","- function to create heatmap based on convolutional gradients, indicating saliency"]},{"cell_type":"code","metadata":{"id":"rbTzHHEnw4kS","colab_type":"code","colab":{}},"source":["## from https://github.com/eclique/pytorch-gradcam/blob/master/gradcam.ipynb\n","\n","def GradCAM(img, c, features_fn, classifier_fn):\n"," feats = modulelist_conv(img.cuda().half())\n"," feats = feats.cuda()\n"," _, N, H, W = feats.size()\n"," \n"," out = modulelist_fc(feats)\n"," c_score = out[0, c]\n"," grads = torch.autograd.grad(c_score, feats)\n"," w = grads[0][0].mean(-1).mean(-1)\n"," \n"," sal = torch.matmul(w, feats.view(N, H*W))\n"," sal = 
sal.view(H, W).cpu().detach().numpy()\n"," sal = np.maximum(sal, 0)\n"," \n"," return sal"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n0bGgVmXEAVZ","colab_type":"text"},"source":["# Create GradCAM Heatmaps\n","- grab an image and plot it plus overlaying GradCAM heatmaps for both classes"]},{"cell_type":"code","metadata":{"id":"8VwXf8ssRFLL","colab_type":"code","colab":{}},"source":["model_ft.eval()\n","\n","read_tensor = transforms.Compose([\n"," lambda x: Image.open(x),\n"," transforms.Resize((512, 512)),\n"," transforms.ToTensor(),\n"," transforms.Normalize(mean=[0.485, 0.456, 0.406],\n"," std=[0.229, 0.224, 0.225]),\n"," lambda x: torch.unsqueeze(x, 0)\n","])\n","\n","def get_class_name(c):\n"," if c == 0:\n"," return 'Happy Couple'\n"," if c == 1:\n"," return 'Unhappy Couple'\n","\n","## change that to your working directory\n","direc = 'drive/My Drive/couples/couple-2436263_960_720.jpg'\n","\n","dil = read_tensor(direc)\n","\n","# get prediction probabilities and corresponding classes\n","pp, cc = torch.topk(nn.Softmax(dim = 1)(model_ft(dil.cuda())), 2)\n","\n","plt.figure(figsize = (15, 5))\n","\n","for i, (p, c) in enumerate(zip(pp[0], cc[0])):\n"," plt.subplot(1, 2, i+1)\n"," \n"," sal = GradCAM(dil.cuda(), int(c), modulelist_conv, modulelist_fc)\n"," img = Image.open(direc)\n"," sal = Image.fromarray(sal)\n"," sal = sal.resize(img.size, resample = Image.LANCZOS)\n","\n"," plt.title('{}: {:.1f}%'.format(get_class_name(c), 100*float(p)))\n"," plt.axis('off')\n"," plt.imshow(img)\n"," plt.imshow(np.array(sal), alpha = 0.5, cmap = 'jet')\n"," \n","plt.show()"],"execution_count":0,"outputs":[]}]} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepConnection 2 | Deep learning model to classify relationship state in romantic couples from images and video. 3 | 4 | This repository contains the code for training and using the deep learning binary classification model DeepConnection (described in 'DeepConnection: Classifying Relationship State from Images of Romantic Couples' by Maximiliane Uhlich and Daniel Bojar, https://psyarxiv.com/df25j/). Capitalizing on image data of romantic couples, DeepConnection uses facial and bodily expressions to predict the relationship state of the couple. This leads to a classification accuracy of nearly 97%. 5 | 6 | DeepConnection consists of a pretrained ResNet-34 base model, with a subsequent spatial pyramid pooling layer (https://arxiv.org/abs/1406.4729, with spatial bins of [8x8], [4x4], [2x2], and [1x1]) and a power mean transformation (http://lamda.nju.edu.cn/zhangcl/papers/pm2018_pr.pdf). Subsequent to this, a fully connected part with dropout and ReLU activation layers leads to the binary prediction. The motivation for DeepConnection stems from an improved classification accuracy, a substantial increase in classification speed in comparison to manual coding schemes, and potentially new insights into relevant factors for relationship state in romantic couples. 7 | 8 | The repository contains a Jupyter notebook with the relevant code for training & using the models. Additionally, representative trained models from the three stages discussed in the paper can be downloaded here (https://ln.sync.com/dl/04bec39c0/n8bei4rk-7kqhfkdz-u8vkgfkn-4kghsrpx) as pickle files. 9 | --------------------------------------------------------------------------------
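
## Minimal architecture sketch

For orientation, the sketch below shows how the pieces described above fit together. This is not the exact notebook code: it re-implements the spatial pyramid pooling with `nn.AdaptiveMaxPool2d` for brevity, and it assumes 512 x 512 RGB inputs, a frozen pretrained ResNet-34 backbone, and two output classes, as in the notebook. See `DeepConnection.ipynb` for the full training pipeline with data augmentation, mixed precision, and early stopping.

```python
# Hypothetical, simplified sketch of the DeepConnection architecture
# (ResNet-34 backbone -> spatial pyramid pooling -> power mean transformation -> FC head).
import torch
import torch.nn as nn
from torchvision import models

class SPPHead(nn.Module):
    """Spatial pyramid pooling over [8, 4, 2, 1] bins, power mean transformation,
    and a small fully connected classifier."""
    def __init__(self, in_channels=512, bins=(8, 4, 2, 1), n_classes=2):
        super().__init__()
        self.pools = nn.ModuleList([nn.AdaptiveMaxPool2d(b) for b in bins])
        n_features = in_channels * sum(b * b for b in bins)   # 512 * 85 = 43520
        self.lin1 = nn.Linear(2 * n_features, 100)             # factor 2 from the PMT concatenation
        self.bn1 = nn.BatchNorm1d(100)
        self.dp1 = nn.Dropout(0.5)
        self.lin2 = nn.Linear(100, n_classes)

    def forward(self, x):
        # SPP: fixed-length representation regardless of feature-map size
        x = torch.cat([p(x).flatten(1) for p in self.pools], dim=1)
        # PMT: concatenate signed log and signed squared-log transforms
        log_abs = torch.log1p(torch.abs(x))
        x = torch.cat([torch.sign(x) * log_abs, torch.sign(x) * log_abs ** 2], dim=1)
        x = self.bn1(torch.relu(self.lin1(x)))
        # average several dropout-perturbed passes through the final layer
        return torch.stack([self.lin2(self.dp1(x)) for _ in range(8)]).mean(dim=0)

class DeepConnection(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        for p in backbone.parameters():
            p.requires_grad = False                             # freeze the pretrained base
        self.convs = nn.Sequential(*list(backbone.children())[:-2])  # drop avgpool + fc
        self.head = SPPHead()

    def forward(self, x):
        return self.head(self.convs(x))

model = DeepConnection()
logits = model(torch.randn(2, 3, 512, 512))                     # -> shape (2, 2)
```

With 512 x 512 inputs the backbone emits a 512-channel 16 x 16 feature map, so the pyramid of 8^2 + 4^2 + 2^2 + 1^2 = 85 bins yields a 43,520-dimensional vector, which the power mean transformation doubles to 87,040 features before the classifier; this matches the `nn.Linear(2*43520, 100)` layer used in the notebook.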