├── .DS_Store
├── .ipynb_checkpoints
│   └── demo-checkpoint.ipynb
├── LICENSE
├── README.md
├── __pycache__
│   └── unet_tile_se_norm.cpython-36.pyc
├── dataset.png
├── demo.ipynb
├── example.png
├── pipeline.png
├── requirements.txt
└── unet_tile_se_norm.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VITA-Group/SSHarmonization/0bb4ce83883ec9fb8690b3c8dbefb69018a1f6db/.DS_Store
--------------------------------------------------------------------------------
- 1\n", 44 | " return tensor.unsqueeze(0)\n", 45 | "\n", 46 | "def tensor2im(input_image, imtype=np.uint8):\n", 47 | " \"\"\"\"Converts a Tensor array into a numpy image array.\n", 48 | "\n", 49 | " Parameters:\n", 50 | " input_image (tensor) -- the input image tensor array\n", 51 | " imtype (type) -- the desired type of the converted numpy array\n", 52 | " \"\"\"\n", 53 | " if not isinstance(input_image, np.ndarray):\n", 54 | " if isinstance(input_image, torch.Tensor): # get the data from a variable\n", 55 | " image_tensor = input_image.data\n", 56 | " else:\n", 57 | " return input_image\n", 58 | " image_numpy = image_tensor.float().numpy() # convert it into a numpy array\n", 59 | " if image_numpy.shape[0] == 1: # grayscale to RGB\n", 60 | " image_numpy = np.tile(image_numpy, (3, 1, 1))\n", 61 | " image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling\n", 62 | " image_numpy = np.maximum(image_numpy, 0)\n", 63 | " image_numpy = np.minimum(image_numpy, 255)\n", 64 | " else: # if it is a numpy array, do nothing\n", 65 | " image_numpy = input_image\n", 66 | " return image_numpy.astype(imtype)\n", 67 | "\n", 68 | "import torchvision.transforms as transforms\n", 69 | "\n", 70 | "def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):\n", 71 | " transform_list = []\n", 72 | " if grayscale:\n", 73 | " transform_list.append(transforms.Grayscale(1))\n", 74 | " if 'resize' in opt.preprocess:\n", 75 | " osize = [opt.load_size, opt.load_size]\n", 76 | " transform_list.append(transforms.Resize(osize, method))\n", 77 | " elif 'scale_width' in opt.preprocess:\n", 78 | " transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))\n", 79 | "\n", 80 | " if 'crop' in opt.preprocess:\n", 81 | " if params is None:\n", 82 | " transform_list.append(transforms.RandomCrop(opt.crop_size))\n", 83 | " else:\n", 84 | " transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))\n", 85 | "\n", 86 | " if opt.preprocess == 'none':\n", 87 | " pass\n", 88 | "# transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))\n", 89 | "\n", 90 | " if not opt.no_flip:\n", 91 | " if params is None:\n", 92 | " transform_list.append(transforms.RandomHorizontalFlip())\n", 93 | " elif params['flip']:\n", 94 | " transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))\n", 95 | "\n", 96 | " if convert:\n", 97 | " transform_list += [transforms.ToTensor()]\n", 98 | " if grayscale:\n", 99 | " transform_list += [transforms.Normalize((0.5,), (0.5,))]\n", 100 | " else:\n", 101 | " transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", 102 | " return transforms.Compose(transform_list)\n", 103 | "\n", 104 | "\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "===============================================================================\n", 117 | "* *\n", 118 | "* Interpreter : *\n", 119 | "* python : 3.6.10 |Anaconda, Inc.| (default, May 7 2020, 23:06:31) *\n", 120 | "* [GCC 4.2.1 Compatible Clang 4.0.1 *\n", 121 | "* (tags/RELEASE_401/final)] *\n", 122 | "* *\n", 123 | "* colour-science.org : *\n", 124 | "* colour : 0.3.15 *\n", 125 | "* colour-checker-detection : 0.1.1 *\n", 126 | "* *\n", 127 | "* Runtime : *\n", 128 | "* imageio : 
2.9.0 *\n", 129 | "* matplotlib : 3.3.0 *\n", 130 | "* networkx : 2.5 *\n", 131 | "* numpy : 1.19.5 *\n", 132 | "* scipy : 1.5.1 *\n", 133 | "* six : 1.15.0 *\n", 134 | "* opencv : 4.4.0 *\n", 135 | "* *\n", 136 | "===============================================================================\n", 137 | "216\n", 138 | "0\n", 139 | "pexels-johannes-plenio-1123445\n" 140 | ] 141 | }, 142 | { 143 | "name": "stderr", 144 | "output_type": "stream", 145 | "text": [ 146 | "/Users/yifanjiang/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/functional.py:3509: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n", 147 | " warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n", 148 | "/Users/yifanjiang/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/functional.py:3635: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n", 149 | " \"See the documentation of nn.Upsample for details.\".format(mode)\n", 150 | "/Users/yifanjiang/anaconda3/envs/torch/lib/python3.6/site-packages/colour/utilities/verbose.py:235: ColourUsageWarning: \"OpenImageIO\" related API features are not available, switching to \"Imageio\"!\n", 151 | " warn(*args, **kwargs)\n" 152 | ] 153 | }, 154 | { 155 | "ename": "KeyboardInterrupt", 156 | "evalue": "", 157 | "output_type": "error", 158 | "traceback": [ 159 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 160 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 161 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 193\u001b[0m output = colour.cctf_encoding(\n\u001b[1;32m 194\u001b[0m colour.colour_correction(\n\u001b[0;32m--> 195\u001b[0;31m colour.cctf_decoding(colour.io.read_image(fore_shift_path)), color_0, color_1, terms = 17))\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 162 | "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/colour/characterisation/correction.py\u001b[0m in \u001b[0;36mcolour_correction\u001b[0;34m(RGB, M_T, M_R, method, **kwargs)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[0mfunction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCOLOUR_CORRECTION_METHODS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 916\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 917\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRGB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM_T\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM_R\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 163 | 
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/colour/characterisation/correction.py\u001b[0m in \u001b[0;36mcolour_correction_Cheung2004\u001b[0;34m(RGB, M_T, M_R, terms)\u001b[0m\n\u001b[1;32m 681\u001b[0m \u001b[0mRGB\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRGB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 683\u001b[0;31m \u001b[0mRGB_e\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maugmented_matrix_Cheung2004\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRGB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mterms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 684\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 685\u001b[0m \u001b[0mCCM\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcolour_correction_matrix_Cheung2004\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mM_T\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM_R\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mterms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 164 | "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/colour/characterisation/correction.py\u001b[0m in \u001b[0;36maugmented_matrix_Cheung2004\u001b[0;34m(RGB, terms)\u001b[0m\n\u001b[1;32m 165\u001b[0m return tstack([\n\u001b[1;32m 166\u001b[0m \u001b[0mR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mR\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mones\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m ])\n\u001b[1;32m 169\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mterms\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m19\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 165 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 166 | ] 167 | } 168 | ], 169 | "source": [ 
170 | "import cv2\n", 171 | "import glob\n", 172 | "import matplotlib.pyplot as plt\n", 173 | "import numpy as np\n", 174 | "import os\n", 175 | "from collections import OrderedDict\n", 176 | "import pdb\n", 177 | "import colour\n", 178 | "from colour.plotting import *\n", 179 | "import imageio\n", 180 | "from colour_checker_detection import (\n", 181 | " EXAMPLES_RESOURCES_DIRECTORY,\n", 182 | " colour_checkers_coordinates_segmentation,\n", 183 | " detect_colour_checkers_segmentation)\n", 184 | "\n", 185 | "from colour_checker_detection.detection.segmentation import (\n", 186 | " adjust_image)\n", 187 | "\n", 188 | "colour.utilities.describe_environment();\n", 189 | "colour_style();\n", 190 | "\n", 191 | "opt = OPT()\n", 192 | "opt.norm = \"IN\"\n", 193 | "opt.preprocess = \"none\"\n", 194 | "opt.no_flip = True\n", 195 | "transform = get_transform(opt)\n", 196 | "\n", 197 | "\n", 198 | "fore_list = glob.glob('./portrait_image/*_cloud_input_WB.jpg') + glob.glob('./portrait2/*_cloud_input_WB.jpg')\n", 199 | "back_list = glob.glob('./street_image/*.jpg') + glob.glob('./background_image/*.jpg')+ glob.glob('./background2/*.jpg')\n", 200 | "# back_list = glob.glob('./portrait2/*_cloud_input_WB.jpg') + glob.glob('./portrait_image/*_cloud_input_WB.jpg') + glob.glob('./street_image/*_WB.jpg') + glob.glob('./background_image/*_WB.jpg')\n", 201 | "\n", 202 | "def fitting(fore_name, fore_path, dir_path, example_name, mask_0, back_image_0):\n", 203 | " EXAMPLES_RESOURCES_DIRECTORY = dir_path\n", 204 | " COLOUR_CHECKER_IMAGE_PATHS = glob.glob(\n", 205 | " os.path.join(EXAMPLES_RESOURCES_DIRECTORY, '*.jpg'))\n", 206 | " COLOUR_CHECKER_IMAGES = [\n", 207 | " colour.cctf_decoding(colour.io.read_image(path))\n", 208 | " for path in COLOUR_CHECKER_IMAGE_PATHS\n", 209 | " ]\n", 210 | "\n", 211 | " for iteration in range(1):\n", 212 | "\n", 213 | " color_0 = cv2.resize(COLOUR_CHECKER_IMAGES[1], (64,64)).reshape((-1,3))\n", 214 | " color_1 = cv2.resize(COLOUR_CHECKER_IMAGES[0], (64,64)).reshape((-1,3))\n", 215 | " output = colour.cctf_encoding(\n", 216 | " colour.colour_correction(\n", 217 | " colour.cctf_decoding(colour.io.read_image(fore_path)), color_0, color_1, terms = 17))\n", 218 | " output = np.clip(output, 0, 1)\n", 219 | "\n", 220 | " comp = output*(mask_0) + back_image_0 * (1-mask_0)\n", 221 | "\n", 222 | " imageio.imwrite('./dovenet_compare_psnr_lr/'+ fore_name+example_name,input_image*255)\n", 223 | " \n", 224 | "def process(inp_img, ref_img, inp_show, model_1, dir_path):\n", 225 | " if not os.path.isdir(dir_path):\n", 226 | " os.makedirs(dir_path)\n", 227 | " SS_scratch_output, _, _ = model_1(inp_img, ref_img)\n", 228 | " SS_scratch_show = tensor2im(SS_scratch_output[0])\n", 229 | " SS_scratch_show = SS_scratch_show.astype(float)/255\n", 230 | " Image.fromarray(np.uint8(inp_show*255)).save(\"./\" + dir_path + \"/input.jpg\")\n", 231 | " Image.fromarray(np.uint8(SS_scratch_show*255)).save(\"./\" + dir_path + \"/result.jpg\")\n", 232 | "\n", 233 | "image_small_list = []\n", 234 | "for root, dirs, files in os.walk(\"../RealHM/vendor_testing_1/\"):\n", 235 | " for file in files:\n", 236 | " if \"small\" in file:\n", 237 | " image_small_list.append(os.path.join(root, file))\n", 238 | "for root, dirs, files in os.walk(\"../RealHM/vendor_testing_2/\"):\n", 239 | " for file in files:\n", 240 | " if \"small\" in file:\n", 241 | " image_small_list.append(os.path.join(root, file))\n", 242 | " \n", 243 | "for root, dirs, files in os.walk(\"../RealHM/vendor_testing_3/\"):\n", 244 | " for file in files:\n", 245 
| " if \"small\" in file:\n", 246 | " image_small_list.append(os.path.join(root, file))\n", 247 | " \n", 248 | "print(len(image_small_list))\n", 249 | "for index, file in enumerate(image_small_list):\n", 250 | " if 1 == 1:\n", 251 | " print(index)\n", 252 | " fore_path = file\n", 253 | " fore_shift_path = file.replace(\"_small.jpg\", \".jpg\")\n", 254 | " mask_path = file.replace(\"_small.jpg\", \"_mask.jpg\")\n", 255 | " bg_path = file.replace(\"_small.jpg\", \"_fore.jpg\")\n", 256 | " gt_path = file.replace(\"_small.jpg\", \"_gt.jpg\")\n", 257 | " fore_name = bg_path.replace(\"_fore.jpg\", \"\").replace(\"../vendor_testing_2/\", \"\").replace(\"../vendor_testing_1/\", \"\").replace(\"../vendor_testing_3/\", \"\")\n", 258 | " fore_name = fore_name.split(\"/\")[-1]\n", 259 | "# if os.path.isfile('./output/'+ fore_name+'_ssh.jpg'):\n", 260 | "# continue\n", 261 | " inp_img = Image.open(fore_path).convert('RGB')\n", 262 | " inp_shift_img = Image.open(fore_shift_path).convert('RGB')\n", 263 | " ref_img = Image.open(bg_path).convert('RGB')\n", 264 | " mask_img = Image.open(mask_path).convert('RGB')\n", 265 | " gt_img = Image.open(gt_path).convert('RGB')\n", 266 | "\n", 267 | " inp_img = tf.resize(inp_img, [256, 256])\n", 268 | " inp_shift_img = tf.resize(inp_shift_img, [256, 256])\n", 269 | " ref_img = tf.resize(ref_img, [256, 256])\n", 270 | " mask_img = tf.resize(mask_img, [256, 256])\n", 271 | " gt_img = tf.resize(gt_img, [256, 256])\n", 272 | " \n", 273 | " inp_img = np.array(inp_img)\n", 274 | " inp_shift_img = np.array(inp_shift_img)\n", 275 | " ref_img = np.array(ref_img)\n", 276 | " mask_img = np.array(mask_img)\n", 277 | " gt_img = np.array(gt_img)\n", 278 | "\n", 279 | " inp_img = np2tensor(inp_img)\n", 280 | " inp_shift_img = np2tensor(inp_shift_img)\n", 281 | " ref_img = np2tensor(ref_img)\n", 282 | " mask_img = np2tensor(mask_img)\n", 283 | "\n", 284 | "# ratio = 0.95\n", 285 | "# output, _, style = model_inter(inp_img, ref_img, inp_img, ratio)\n", 286 | "\n", 287 | " inp_show = tensor2im(inp_img[0])\n", 288 | " inp_shift_show = tensor2im(inp_shift_img[0])\n", 289 | " ref_show = tensor2im(ref_img[0])\n", 290 | "# oup_show = tensor2im(output[0])\n", 291 | "\n", 292 | " # import pdb\n", 293 | " mask_img1 = tensor2im(mask_img.squeeze())\n", 294 | " mask_img1 = mask_img1.astype(float)/255\n", 295 | " # pdb.set_trace()\n", 296 | "\n", 297 | "# oup_show = oup_show.astype(float)/255\n", 298 | " ref_show = ref_show.astype(float)/255\n", 299 | " inp_show = inp_show.astype(float)/255\n", 300 | " inp_shift_show = inp_shift_show.astype(float)/255\n", 301 | "\n", 302 | "# comp_our = oup_show*(mask_img1)+ref_show*(1-mask_img1)\n", 303 | " comp_input = inp_shift_show*(mask_img1)+ref_show*(1-mask_img1)\n", 304 | "\n", 305 | " mask_img1 = mask_img1[:,:,0]\n", 306 | " comp = Image.fromarray(np.uint8(comp_input*255))\n", 307 | " mask = Image.fromarray(np.uint8(mask_img1*255))\n", 308 | " real = Image.fromarray(np.uint8(comp_input*255))\n", 309 | "\n", 310 | " # apply the same transform to composite and real images\n", 311 | " comp = transform(comp)\n", 312 | " mask = tf.to_tensor(mask)\n", 313 | " real = transform(real)\n", 314 | " # concate the composite and mask as the input of generator\n", 315 | " inputs=torch.cat([comp,mask],0)\n", 316 | "\n", 317 | "# dove_output = DoveNet_model(inputs.unsqueeze(0))\n", 318 | " mask = mask.unsqueeze(0)\n", 319 | " input_show = tensor2im(comp)\n", 320 | "# output_show = tensor2im(dove_output[0])\n", 321 | "\n", 322 | " mask_img1_3 = 
np.zeros_like(comp_input)\n", 323 | " mask_img1_3[:,:,0] = mask_img1\n", 324 | " mask_img1_3[:,:,1] = mask_img1\n", 325 | " mask_img1_3[:,:,2] = mask_img1\n", 326 | "\n", 327 | "\n", 328 | "\n", 329 | "\n", 330 | "\n", 331 | "\n", 332 | "\n", 333 | " input_image = np.array(Image.open(fore_shift_path).convert('RGB')).astype(float)/255\n", 334 | " back_image = np.array(Image.open(bg_path).convert('RGB')).astype(float)/255\n", 335 | " mask = np.array(Image.open(mask_path).convert('RGB')).astype(float)\n", 336 | "\n", 337 | " mask3 = np.zeros_like(input_image)\n", 338 | " mask3[:,:,0] = mask[:,:,0]/255\n", 339 | " mask3[:,:,1] = mask[:,:,0]/255\n", 340 | " mask3[:,:,2] = mask[:,:,0]/255 \n", 341 | " print(fore_name) \n", 342 | " \n", 343 | " checkpoint = torch.load(\"./checkpoint\", map_location='cpu')\n", 344 | " opt.norm = \"BN\"\n", 345 | " opt.style_norm = \"BN\"\n", 346 | " model_1 = UNetTileSENorm(opt)\n", 347 | " model_1.load_state_dict(checkpoint['model_state_dict'])\n", 348 | " model_1.eval()\n", 349 | " process(inp_img, ref_img, inp_show, model_1, \"NEW_SS_transform_2\")\n", 350 | "# fitting(fore_name, fore_path, 'NEW_SS_transform_1', '_64_128_cycle0.1_newcube.jpg', mask3, back_image)\n", 351 | " EXAMPLES_RESOURCES_DIRECTORY = 'NEW_SS_transform_2'\n", 352 | " COLOUR_CHECKER_IMAGE_PATHS = glob.glob(os.path.join(EXAMPLES_RESOURCES_DIRECTORY, '*.jpg'))\n", 353 | " COLOUR_CHECKER_IMAGES = [\n", 354 | " colour.cctf_decoding(colour.io.read_image(path))\n", 355 | " for path in COLOUR_CHECKER_IMAGE_PATHS\n", 356 | " ]\n", 357 | "\n", 358 | " for iteration in range(1):\n", 359 | "\n", 360 | " color_0 = cv2.resize(COLOUR_CHECKER_IMAGES[1], (64,64)).reshape((-1,3))\n", 361 | " color_1 = cv2.resize(COLOUR_CHECKER_IMAGES[0], (64,64)).reshape((-1,3))\n", 362 | " output = colour.cctf_encoding(\n", 363 | " colour.colour_correction(\n", 364 | " colour.cctf_decoding(colour.io.read_image(fore_shift_path)), color_0, color_1, terms = 17))\n", 365 | " output = np.clip(output, 0, 1)\n", 366 | "\n", 367 | " comp = output[:, :, :3]*(mask3) + back_image * (1-mask3)\n", 368 | " comp = output[:, :, :3]*(mask3) + back_image * (1-mask3)\n", 369 | " comp = comp*255\n", 370 | " comp = Image.fromarray(np.uint8(comp))\n", 371 | " comp.save('./output/'+ fore_name+'_ssh.jpg')\n", 372 | " " 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "import os\n", 382 | "import cv2" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "from skimage.measure import compare_mse as mse\n", 392 | "from skimage.measure import compare_psnr as psnr\n", 393 | "from skimage.measure import compare_ssim as ssim\n", 394 | "\n", 395 | "total_psnr = 0\n", 396 | "total_ssim = 0\n", 397 | "total_mse = 0\n", 398 | "total_lpips = 0\n", 399 | "index_1 = 0\n", 400 | "for files, dirs, root in os.walk(\"./output_1\"):\n", 401 | " for file in root:\n", 402 | " if \".jpg\" in file:\n", 403 | "# print(file)\n", 404 | " ssh = cv2.imread(\"./output_1/\"+file)\n", 405 | " for name in image_small_list:\n", 406 | " if file.replace(\"_ssh.jpg\", \"\") in name:\n", 407 | "# print(name)\n", 408 | " index_1 += 1\n", 409 | " gt = cv2.imread(name.replace(\"_small_fore\", \"_small\").replace(\"_small\", \"_gt\"))\n", 410 | " gt = cv2.resize(gt, (256,256), interpolation=cv2.INTER_AREA)\n", 411 | " ssh = cv2.resize(ssh, (256,256), interpolation=cv2.INTER_AREA)\n", 412 | " 
/LICENSE:
--------------------------------------------------------------------------------
1 | --------------------------- LICENSE FOR SSH --------------------------------
2 | BSD License
3 | 
4 | For SSH software
5 | Copyright (c) 2021, Yifan Jiang and Adobe Inc.
6 | All rights reserved.
7 | 
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met:
10 | 
11 | * Redistributions of source code must retain the above copyright notice, this
12 | list of conditions and the following disclaimer.
13 | 
14 | * Redistributions in binary form must reproduce the above copyright notice,
15 | this list of conditions and the following disclaimer in the documentation
16 | and/or other materials provided with the distribution.
17 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSH: A Self-Supervised Framework for Image Harmonization (ICCV 2021)
2 | We provide the inference code and the collected RealHM dataset for [Self-supervised Image Harmonization](https://arxiv.org/abs/2108.06805).
3 | 
4 | ## Guideline
5 | Download the [RealHM dataset](https://drive.google.com/file/d/1jBx-DBtRX8GaqMvMv-CZutK4jn9tz-fT/view?usp=sharing) and put it under `../RealHM`.
6 | Download the [pretrained weight](https://drive.google.com/file/d/19OtBUedEM3QnsUEn0ECxFCmlU0mzDRcT/view?usp=share_link) and put it under `./`.
7 | Create a directory named `output`.
8 | Install the dependencies and run `demo.ipynb`:
9 | ```
10 | pip install -r requirements.txt
11 | ```
12 | ## Updates
13 | Added environment file (Feb. 2022).
14 | Fixed bugs in the latest version (Jan. 2022).
15 | 
16 | ## Representative Examples
17 | ![Visual_Examples](./example.png)
18 | ## Main Pipeline
19 | ![Pipeline](./pipeline.png)
20 | 
21 | ## Dataset Pipeline
22 | 
23 | ![Dataset_Pipeline](./dataset.png)
24 | 
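## Minimal Inference Example
A minimal inference sketch based on `demo.ipynb`: the option settings, checkpoint loading, and `[-1, 1]` tensor conversion follow the notebook, while the image paths (`foreground.jpg`, `background.jpg`) and the output name are placeholders.
```
import numpy as np
import torch
from PIL import Image

from unet_tile_se_norm import UNetTileSENorm

class OPT:
    pass

# Option container as used in demo.ipynb; the released checkpoint expects BatchNorm.
opt = OPT()
opt.norm = "BN"
opt.style_norm = "BN"

model = UNetTileSENorm(opt)
checkpoint = torch.load("./checkpoint", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

def np2tensor(arr):
    # HWC uint8 image -> 1x3xHxW float tensor in [-1, 1], matching the notebook helper.
    t = torch.from_numpy(np.transpose(arr.copy(), (2, 0, 1))).float() / 255. * 2. - 1.
    return t.unsqueeze(0)

# Placeholder inputs: the image to harmonize and the reference image providing
# the target colour style, both resized to 256x256 as in the notebook.
fore = np.array(Image.open("foreground.jpg").convert("RGB").resize((256, 256)))
back = np.array(Image.open("background.jpg").convert("RGB").resize((256, 256)))

with torch.no_grad():
    harmonized, _, _ = model(np2tensor(fore), np2tensor(back))

# Map the [-1, 1] network output back to an 8-bit image, as tensor2im does in the notebook.
out = (harmonized[0].numpy().transpose(1, 2, 0) + 1) / 2.0 * 255.0
Image.fromarray(np.uint8(np.clip(out, 0, 255))).save("./output/result.jpg")
```
`demo.ipynb` additionally runs a colour-checker based colour correction (via the `colour` package) on this raw output and composites the corrected foreground with the background using the mask; see the notebook for that full pipeline.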
25 | 26 | ## Citation 27 | ``` 28 | @inproceedings{jiang2021ssh, 29 | title={Ssh: A self-supervised framework for image harmonization}, 30 | author={Jiang, Yifan and Zhang, He and Zhang, Jianming and Wang, Yilin and Lin, Zhe and Sunkavalli, Kalyan and Chen, Simon and Amirghodsi, Sohrab and Kong, Sarah and Wang, Zhangyang}, 31 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 32 | pages={4832--4841}, 33 | year={2021} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /__pycache__/unet_tile_se_norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SSHarmonization/0bb4ce83883ec9fb8690b3c8dbefb69018a1f6db/__pycache__/unet_tile_se_norm.cpython-36.pyc -------------------------------------------------------------------------------- /dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SSHarmonization/0bb4ce83883ec9fb8690b3c8dbefb69018a1f6db/dataset.png -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import division\n", 10 | "%matplotlib inline\n", 11 | "import os\n", 12 | "import torch\n", 13 | "from torch import nn\n", 14 | "import torchvision.transforms.functional as tf\n", 15 | "from matplotlib.pyplot import imshow\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "from PIL import Image\n", 18 | "import numpy as np\n", 19 | "import pdb\n", 20 | "import cv2\n", 21 | "import random\n", 22 | "import glob\n", 23 | "import imageio" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "\n", 33 | "from unet_tile_se_norm import UNetTileSENorm\n", 34 | "from PIL import ImageFile\n", 35 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", 36 | "# unet_tile_se_inter\n", 37 | "# from utils import np2tensor\n", 38 | "\n", 39 | "class OPT():\n", 40 | " pass\n", 41 | "\n", 42 | "def np2tensor(numpy_array):\n", 43 | " tensor = torch.from_numpy(np.transpose(numpy_array.copy(), (2, 0, 1))).float()/255.*2. 
- 1\n", 44 | " return tensor.unsqueeze(0)\n", 45 | "\n", 46 | "def tensor2im(input_image, imtype=np.uint8):\n", 47 | " \"\"\"\"Converts a Tensor array into a numpy image array.\n", 48 | "\n", 49 | " Parameters:\n", 50 | " input_image (tensor) -- the input image tensor array\n", 51 | " imtype (type) -- the desired type of the converted numpy array\n", 52 | " \"\"\"\n", 53 | " if not isinstance(input_image, np.ndarray):\n", 54 | " if isinstance(input_image, torch.Tensor): # get the data from a variable\n", 55 | " image_tensor = input_image.data\n", 56 | " else:\n", 57 | " return input_image\n", 58 | " image_numpy = image_tensor.float().numpy() # convert it into a numpy array\n", 59 | " if image_numpy.shape[0] == 1: # grayscale to RGB\n", 60 | " image_numpy = np.tile(image_numpy, (3, 1, 1))\n", 61 | " image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling\n", 62 | " image_numpy = np.maximum(image_numpy, 0)\n", 63 | " image_numpy = np.minimum(image_numpy, 255)\n", 64 | " else: # if it is a numpy array, do nothing\n", 65 | " image_numpy = input_image\n", 66 | " return image_numpy.astype(imtype)\n", 67 | "\n", 68 | "import torchvision.transforms as transforms\n", 69 | "\n", 70 | "def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):\n", 71 | " transform_list = []\n", 72 | " if grayscale:\n", 73 | " transform_list.append(transforms.Grayscale(1))\n", 74 | " if 'resize' in opt.preprocess:\n", 75 | " osize = [opt.load_size, opt.load_size]\n", 76 | " transform_list.append(transforms.Resize(osize, method))\n", 77 | " elif 'scale_width' in opt.preprocess:\n", 78 | " transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))\n", 79 | "\n", 80 | " if 'crop' in opt.preprocess:\n", 81 | " if params is None:\n", 82 | " transform_list.append(transforms.RandomCrop(opt.crop_size))\n", 83 | " else:\n", 84 | " transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))\n", 85 | "\n", 86 | " if opt.preprocess == 'none':\n", 87 | " pass\n", 88 | "# transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))\n", 89 | "\n", 90 | " if not opt.no_flip:\n", 91 | " if params is None:\n", 92 | " transform_list.append(transforms.RandomHorizontalFlip())\n", 93 | " elif params['flip']:\n", 94 | " transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))\n", 95 | "\n", 96 | " if convert:\n", 97 | " transform_list += [transforms.ToTensor()]\n", 98 | " if grayscale:\n", 99 | " transform_list += [transforms.Normalize((0.5,), (0.5,))]\n", 100 | " else:\n", 101 | " transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", 102 | " return transforms.Compose(transform_list)\n", 103 | "\n", 104 | "\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "===============================================================================\n", 117 | "* *\n", 118 | "* Interpreter : *\n", 119 | "* python : 3.6.10 |Anaconda, Inc.| (default, May 7 2020, 23:06:31) *\n", 120 | "* [GCC 4.2.1 Compatible Clang 4.0.1 *\n", 121 | "* (tags/RELEASE_401/final)] *\n", 122 | "* *\n", 123 | "* colour-science.org : *\n", 124 | "* colour : 0.3.15 *\n", 125 | "* colour-checker-detection : 0.1.1 *\n", 126 | "* *\n", 127 | "* Runtime : *\n", 128 | "* imageio : 
2.9.0 *\n", 129 | "* matplotlib : 3.3.0 *\n", 130 | "* networkx : 2.5 *\n", 131 | "* numpy : 1.19.5 *\n", 132 | "* scipy : 1.5.1 *\n", 133 | "* six : 1.15.0 *\n", 134 | "* opencv : 4.4.0 *\n", 135 | "* *\n", 136 | "===============================================================================\n", 137 | "216\n", 138 | "0\n", 139 | "pexels-johannes-plenio-1123445\n" 140 | ] 141 | }, 142 | { 143 | "name": "stderr", 144 | "output_type": "stream", 145 | "text": [ 146 | "/Users/yifanjiang/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/functional.py:3509: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n", 147 | " warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n", 148 | "/Users/yifanjiang/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/functional.py:3635: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n", 149 | " \"See the documentation of nn.Upsample for details.\".format(mode)\n", 150 | "/Users/yifanjiang/anaconda3/envs/torch/lib/python3.6/site-packages/colour/utilities/verbose.py:235: ColourUsageWarning: \"OpenImageIO\" related API features are not available, switching to \"Imageio\"!\n", 151 | " warn(*args, **kwargs)\n" 152 | ] 153 | }, 154 | { 155 | "ename": "KeyboardInterrupt", 156 | "evalue": "", 157 | "output_type": "error", 158 | "traceback": [ 159 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 160 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 161 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 193\u001b[0m output = colour.cctf_encoding(\n\u001b[1;32m 194\u001b[0m colour.colour_correction(\n\u001b[0;32m--> 195\u001b[0;31m colour.cctf_decoding(colour.io.read_image(fore_shift_path)), color_0, color_1, terms = 17))\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 162 | "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/colour/characterisation/correction.py\u001b[0m in \u001b[0;36mcolour_correction\u001b[0;34m(RGB, M_T, M_R, method, **kwargs)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[0mfunction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCOLOUR_CORRECTION_METHODS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 916\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 917\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRGB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM_T\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM_R\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 163 | 
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/colour/characterisation/correction.py\u001b[0m in \u001b[0;36mcolour_correction_Cheung2004\u001b[0;34m(RGB, M_T, M_R, terms)\u001b[0m\n\u001b[1;32m 681\u001b[0m \u001b[0mRGB\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRGB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 683\u001b[0;31m \u001b[0mRGB_e\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maugmented_matrix_Cheung2004\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRGB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mterms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 684\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 685\u001b[0m \u001b[0mCCM\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcolour_correction_matrix_Cheung2004\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mM_T\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM_R\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mterms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 164 | "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/colour/characterisation/correction.py\u001b[0m in \u001b[0;36maugmented_matrix_Cheung2004\u001b[0;34m(RGB, terms)\u001b[0m\n\u001b[1;32m 165\u001b[0m return tstack([\n\u001b[1;32m 166\u001b[0m \u001b[0mR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mR\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mR\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mR\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mB\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mones\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m ])\n\u001b[1;32m 169\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mterms\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m19\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 165 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 166 | ] 167 | } 168 | ], 169 | "source": [ 
170 | "import cv2\n", 171 | "import glob\n", 172 | "import matplotlib.pyplot as plt\n", 173 | "import numpy as np\n", 174 | "import os\n", 175 | "from collections import OrderedDict\n", 176 | "import pdb\n", 177 | "import colour\n", 178 | "from colour.plotting import *\n", 179 | "import imageio\n", 180 | "from colour_checker_detection import (\n", 181 | " EXAMPLES_RESOURCES_DIRECTORY,\n", 182 | " colour_checkers_coordinates_segmentation,\n", 183 | " detect_colour_checkers_segmentation)\n", 184 | "\n", 185 | "from colour_checker_detection.detection.segmentation import (\n", 186 | " adjust_image)\n", 187 | "\n", 188 | "colour.utilities.describe_environment();\n", 189 | "colour_style();\n", 190 | "\n", 191 | "opt = OPT()\n", 192 | "opt.norm = \"IN\"\n", 193 | "opt.preprocess = \"none\"\n", 194 | "opt.no_flip = True\n", 195 | "transform = get_transform(opt)\n", 196 | "\n", 197 | "\n", 198 | "fore_list = glob.glob('./portrait_image/*_cloud_input_WB.jpg') + glob.glob('./portrait2/*_cloud_input_WB.jpg')\n", 199 | "back_list = glob.glob('./street_image/*.jpg') + glob.glob('./background_image/*.jpg')+ glob.glob('./background2/*.jpg')\n", 200 | "# back_list = glob.glob('./portrait2/*_cloud_input_WB.jpg') + glob.glob('./portrait_image/*_cloud_input_WB.jpg') + glob.glob('./street_image/*_WB.jpg') + glob.glob('./background_image/*_WB.jpg')\n", 201 | "\n", 202 | "def fitting(fore_name, fore_path, dir_path, example_name, mask_0, back_image_0):\n", 203 | " EXAMPLES_RESOURCES_DIRECTORY = dir_path\n", 204 | " COLOUR_CHECKER_IMAGE_PATHS = glob.glob(\n", 205 | " os.path.join(EXAMPLES_RESOURCES_DIRECTORY, '*.jpg'))\n", 206 | " COLOUR_CHECKER_IMAGES = [\n", 207 | " colour.cctf_decoding(colour.io.read_image(path))\n", 208 | " for path in COLOUR_CHECKER_IMAGE_PATHS\n", 209 | " ]\n", 210 | "\n", 211 | " for iteration in range(1):\n", 212 | "\n", 213 | " color_0 = cv2.resize(COLOUR_CHECKER_IMAGES[1], (64,64)).reshape((-1,3))\n", 214 | " color_1 = cv2.resize(COLOUR_CHECKER_IMAGES[0], (64,64)).reshape((-1,3))\n", 215 | " output = colour.cctf_encoding(\n", 216 | " colour.colour_correction(\n", 217 | " colour.cctf_decoding(colour.io.read_image(fore_path)), color_0, color_1, terms = 17))\n", 218 | " output = np.clip(output, 0, 1)\n", 219 | "\n", 220 | " comp = output*(mask_0) + back_image_0 * (1-mask_0)\n", 221 | "\n", 222 | " imageio.imwrite('./dovenet_compare_psnr_lr/'+ fore_name+example_name,input_image*255)\n", 223 | " \n", 224 | "def process(inp_img, ref_img, inp_show, model_1, dir_path):\n", 225 | " if not os.path.isdir(dir_path):\n", 226 | " os.makedirs(dir_path)\n", 227 | " SS_scratch_output, _, _ = model_1(inp_img, ref_img)\n", 228 | " SS_scratch_show = tensor2im(SS_scratch_output[0])\n", 229 | " SS_scratch_show = SS_scratch_show.astype(float)/255\n", 230 | " Image.fromarray(np.uint8(inp_show*255)).save(\"./\" + dir_path + \"/input.jpg\")\n", 231 | " Image.fromarray(np.uint8(SS_scratch_show*255)).save(\"./\" + dir_path + \"/result.jpg\")\n", 232 | "\n", 233 | "image_small_list = []\n", 234 | "for root, dirs, files in os.walk(\"../RealHM/vendor_testing_1/\"):\n", 235 | " for file in files:\n", 236 | " if \"small\" in file:\n", 237 | " image_small_list.append(os.path.join(root, file))\n", 238 | "for root, dirs, files in os.walk(\"../RealHM/vendor_testing_2/\"):\n", 239 | " for file in files:\n", 240 | " if \"small\" in file:\n", 241 | " image_small_list.append(os.path.join(root, file))\n", 242 | " \n", 243 | "for root, dirs, files in os.walk(\"../RealHM/vendor_testing_3/\"):\n", 244 | " for file in files:\n", 245 
| " if \"small\" in file:\n", 246 | " image_small_list.append(os.path.join(root, file))\n", 247 | " \n", 248 | "print(len(image_small_list))\n", 249 | "for index, file in enumerate(image_small_list):\n", 250 | " if 1 == 1:\n", 251 | " print(index)\n", 252 | " fore_path = file\n", 253 | " fore_shift_path = file.replace(\"_small.jpg\", \".jpg\")\n", 254 | " mask_path = file.replace(\"_small.jpg\", \"_mask.jpg\")\n", 255 | " bg_path = file.replace(\"_small.jpg\", \"_fore.jpg\")\n", 256 | " gt_path = file.replace(\"_small.jpg\", \"_gt.jpg\")\n", 257 | " fore_name = bg_path.replace(\"_fore.jpg\", \"\").replace(\"../vendor_testing_2/\", \"\").replace(\"../vendor_testing_1/\", \"\").replace(\"../vendor_testing_3/\", \"\")\n", 258 | " fore_name = fore_name.split(\"/\")[-1]\n", 259 | "# if os.path.isfile('./output/'+ fore_name+'_ssh.jpg'):\n", 260 | "# continue\n", 261 | " inp_img = Image.open(fore_path).convert('RGB')\n", 262 | " inp_shift_img = Image.open(fore_shift_path).convert('RGB')\n", 263 | " ref_img = Image.open(bg_path).convert('RGB')\n", 264 | " mask_img = Image.open(mask_path).convert('RGB')\n", 265 | " gt_img = Image.open(gt_path).convert('RGB')\n", 266 | "\n", 267 | " inp_img = tf.resize(inp_img, [256, 256])\n", 268 | " inp_shift_img = tf.resize(inp_shift_img, [256, 256])\n", 269 | " ref_img = tf.resize(ref_img, [256, 256])\n", 270 | " mask_img = tf.resize(mask_img, [256, 256])\n", 271 | " gt_img = tf.resize(gt_img, [256, 256])\n", 272 | " \n", 273 | " inp_img = np.array(inp_img)\n", 274 | " inp_shift_img = np.array(inp_shift_img)\n", 275 | " ref_img = np.array(ref_img)\n", 276 | " mask_img = np.array(mask_img)\n", 277 | " gt_img = np.array(gt_img)\n", 278 | "\n", 279 | " inp_img = np2tensor(inp_img)\n", 280 | " inp_shift_img = np2tensor(inp_shift_img)\n", 281 | " ref_img = np2tensor(ref_img)\n", 282 | " mask_img = np2tensor(mask_img)\n", 283 | "\n", 284 | "# ratio = 0.95\n", 285 | "# output, _, style = model_inter(inp_img, ref_img, inp_img, ratio)\n", 286 | "\n", 287 | " inp_show = tensor2im(inp_img[0])\n", 288 | " inp_shift_show = tensor2im(inp_shift_img[0])\n", 289 | " ref_show = tensor2im(ref_img[0])\n", 290 | "# oup_show = tensor2im(output[0])\n", 291 | "\n", 292 | " # import pdb\n", 293 | " mask_img1 = tensor2im(mask_img.squeeze())\n", 294 | " mask_img1 = mask_img1.astype(float)/255\n", 295 | " # pdb.set_trace()\n", 296 | "\n", 297 | "# oup_show = oup_show.astype(float)/255\n", 298 | " ref_show = ref_show.astype(float)/255\n", 299 | " inp_show = inp_show.astype(float)/255\n", 300 | " inp_shift_show = inp_shift_show.astype(float)/255\n", 301 | "\n", 302 | "# comp_our = oup_show*(mask_img1)+ref_show*(1-mask_img1)\n", 303 | " comp_input = inp_shift_show*(mask_img1)+ref_show*(1-mask_img1)\n", 304 | "\n", 305 | " mask_img1 = mask_img1[:,:,0]\n", 306 | " comp = Image.fromarray(np.uint8(comp_input*255))\n", 307 | " mask = Image.fromarray(np.uint8(mask_img1*255))\n", 308 | " real = Image.fromarray(np.uint8(comp_input*255))\n", 309 | "\n", 310 | " # apply the same transform to composite and real images\n", 311 | " comp = transform(comp)\n", 312 | " mask = tf.to_tensor(mask)\n", 313 | " real = transform(real)\n", 314 | " # concate the composite and mask as the input of generator\n", 315 | " inputs=torch.cat([comp,mask],0)\n", 316 | "\n", 317 | "# dove_output = DoveNet_model(inputs.unsqueeze(0))\n", 318 | " mask = mask.unsqueeze(0)\n", 319 | " input_show = tensor2im(comp)\n", 320 | "# output_show = tensor2im(dove_output[0])\n", 321 | "\n", 322 | " mask_img1_3 = 
np.zeros_like(comp_input)\n", 323 | " mask_img1_3[:,:,0] = mask_img1\n", 324 | " mask_img1_3[:,:,1] = mask_img1\n", 325 | " mask_img1_3[:,:,2] = mask_img1\n", 326 | "\n", 327 | "\n", 328 | "\n", 329 | "\n", 330 | "\n", 331 | "\n", 332 | "\n", 333 | " input_image = np.array(Image.open(fore_shift_path).convert('RGB')).astype(float)/255\n", 334 | " back_image = np.array(Image.open(bg_path).convert('RGB')).astype(float)/255\n", 335 | " mask = np.array(Image.open(mask_path).convert('RGB')).astype(float)\n", 336 | "\n", 337 | " mask3 = np.zeros_like(input_image)\n", 338 | " mask3[:,:,0] = mask[:,:,0]/255\n", 339 | " mask3[:,:,1] = mask[:,:,0]/255\n", 340 | " mask3[:,:,2] = mask[:,:,0]/255 \n", 341 | " print(fore_name) \n", 342 | " \n", 343 | " checkpoint = torch.load(\"./checkpoint\", map_location='cpu')\n", 344 | " opt.norm = \"BN\"\n", 345 | " opt.style_norm = \"BN\"\n", 346 | " model_1 = UNetTileSENorm(opt)\n", 347 | " model_1.load_state_dict(checkpoint['model_state_dict'])\n", 348 | " model_1.eval()\n", 349 | " process(inp_img, ref_img, inp_show, model_1, \"NEW_SS_transform_2\")\n", 350 | "# fitting(fore_name, fore_path, 'NEW_SS_transform_1', '_64_128_cycle0.1_newcube.jpg', mask3, back_image)\n", 351 | " EXAMPLES_RESOURCES_DIRECTORY = 'NEW_SS_transform_2'\n", 352 | " COLOUR_CHECKER_IMAGE_PATHS = glob.glob(os.path.join(EXAMPLES_RESOURCES_DIRECTORY, '*.jpg'))\n", 353 | " COLOUR_CHECKER_IMAGES = [\n", 354 | " colour.cctf_decoding(colour.io.read_image(path))\n", 355 | " for path in COLOUR_CHECKER_IMAGE_PATHS\n", 356 | " ]\n", 357 | "\n", 358 | " for iteration in range(1):\n", 359 | "\n", 360 | " color_0 = cv2.resize(COLOUR_CHECKER_IMAGES[1], (64,64)).reshape((-1,3))\n", 361 | " color_1 = cv2.resize(COLOUR_CHECKER_IMAGES[0], (64,64)).reshape((-1,3))\n", 362 | " output = colour.cctf_encoding(\n", 363 | " colour.colour_correction(\n", 364 | " colour.cctf_decoding(colour.io.read_image(fore_shift_path)), color_0, color_1, terms = 17))\n", 365 | " output = np.clip(output, 0, 1)\n", 366 | "\n", 367 | " comp = output[:, :, :3]*(mask3) + back_image * (1-mask3)\n", 368 | " comp = output[:, :, :3]*(mask3) + back_image * (1-mask3)\n", 369 | " comp = comp*255\n", 370 | " comp = Image.fromarray(np.uint8(comp))\n", 371 | " comp.save('./output/'+ fore_name+'_ssh.jpg')\n", 372 | " " 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "import os\n", 382 | "import cv2" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "from skimage.measure import compare_mse as mse\n", 392 | "from skimage.measure import compare_psnr as psnr\n", 393 | "from skimage.measure import compare_ssim as ssim\n", 394 | "\n", 395 | "total_psnr = 0\n", 396 | "total_ssim = 0\n", 397 | "total_mse = 0\n", 398 | "total_lpips = 0\n", 399 | "index_1 = 0\n", 400 | "for files, dirs, root in os.walk(\"./output_1\"):\n", 401 | " for file in root:\n", 402 | " if \".jpg\" in file:\n", 403 | "# print(file)\n", 404 | " ssh = cv2.imread(\"./output_1/\"+file)\n", 405 | " for name in image_small_list:\n", 406 | " if file.replace(\"_ssh.jpg\", \"\") in name:\n", 407 | "# print(name)\n", 408 | " index_1 += 1\n", 409 | " gt = cv2.imread(name.replace(\"_small_fore\", \"_small\").replace(\"_small\", \"_gt\"))\n", 410 | " gt = cv2.resize(gt, (256,256), interpolation=cv2.INTER_AREA)\n", 411 | " ssh = cv2.resize(ssh, (256,256), interpolation=cv2.INTER_AREA)\n", 412 | " 
psnr_1 = psnr(gt, ssh, data_range=ssh.max() - ssh.min())\n", 413 | " \n", 414 | " total_psnr += psnr_1\n", 415 | "# print(len(image_small_list))\n", 416 | "print(index_1)\n", 417 | "print(total_psnr/index_1)\n", 418 | " \n", 419 | "# img = cv2.imread(f\"../camera_ready/vendor_testing_3/{file}\")\n", 420 | "# cv2.imwrite(f\"../camera_ready/vendor_testing_3/{file}\", img)\n", 421 | "# replace = file.replace(\".png\", \".jpg\")\n", 422 | "# os.system(f\"mv ../camera_ready/vendor_testing_3/{file} ../camera_ready/vendor_testing_3/{replace}\")\n", 423 | "# print(f\"mv ../camera_ready/vendor_testing_3/{file} ../camera_ready/vendor_testing_3/{replace}\")" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "index_1\n", 433 | "total_psnr/200" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [] 477 | } 478 | ], 479 | "metadata": { 480 | "kernelspec": { 481 | "display_name": "Python 3", 482 | "language": "python", 483 | "name": "python3" 484 | }, 485 | "language_info": { 486 | "codemirror_mode": { 487 | "name": "ipython", 488 | "version": 3 489 | }, 490 | "file_extension": ".py", 491 | "mimetype": "text/x-python", 492 | "name": "python", 493 | "nbconvert_exporter": "python", 494 | "pygments_lexer": "ipython3", 495 | "version": "3.6.10" 496 | } 497 | }, 498 | "nbformat": 4, 499 | "nbformat_minor": 2 500 | } 501 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SSHarmonization/0bb4ce83883ec9fb8690b3c8dbefb69018a1f6db/example.png -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SSHarmonization/0bb4ce83883ec9fb8690b3c8dbefb69018a1f6db/pipeline.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colour-checker-detection==0.1.1 2 | colour-science==0.3.15 3 | torch==1.7 4 | matplotlib==3.3.4 5 | torchvision==0.8.0 6 | opencv-python==4.5.5 7 | -------------------------------------------------------------------------------- /unet_tile_se_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | def pad_tensor(input): 6 | 7 | height_org, width_org = input.shape[2], input.shape[3] 8 | divide = 16 9 | 10 | if width_org % divide != 0 or height_org % divide != 0: 11 
| 12 | width_res = width_org % divide 13 | height_res = height_org % divide 14 | if width_res != 0: 15 | width_div = divide - width_res 16 | pad_left = int(width_div / 2) 17 | pad_right = int(width_div - pad_left) 18 | else: 19 | pad_left = 0 20 | pad_right = 0 21 | 22 | if height_res != 0: 23 | height_div = divide - height_res 24 | pad_top = int(height_div / 2) 25 | pad_bottom = int(height_div - pad_top) 26 | else: 27 | pad_top = 0 28 | pad_bottom = 0 29 | 30 | padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom)) 31 | input = padding(input) 32 | else: 33 | pad_left = 0 34 | pad_right = 0 35 | pad_top = 0 36 | pad_bottom = 0 37 | 38 | height, width = input.data.shape[2], input.data.shape[3] 39 | assert width % divide == 0, 'width cant divided by stride' 40 | assert height % divide == 0, 'height cant divided by stride' 41 | 42 | return input, pad_left, pad_right, pad_top, pad_bottom 43 | 44 | def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom): 45 | height, width = input.shape[2], input.shape[3] 46 | return input[:,:, pad_top: height - pad_bottom, pad_left: width - pad_right] 47 | 48 | def Normalization(opt, dim): 49 | if opt.norm == 'BN': 50 | return nn.BatchNorm2d(dim) 51 | elif opt.norm == 'IN': 52 | return nn.InstanceNorm2d(dim) 53 | 54 | def Pooling(opt, stride): 55 | return nn.MaxPool2d(stride) 56 | 57 | class h_sigmoid(nn.Module): 58 | def __init__(self, inplace=True): 59 | super(h_sigmoid, self).__init__() 60 | self.relu = nn.ReLU6(inplace=inplace) 61 | 62 | def forward(self, x): 63 | return self.relu(x + 3) / 6 64 | 65 | class SELayer(nn.Module): 66 | def __init__(self, channel, reduction=4): 67 | super(SELayer, self).__init__() 68 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 69 | self.fc = nn.Sequential( 70 | nn.Linear(channel, channel // reduction), 71 | nn.ReLU(inplace=True), 72 | nn.Linear(channel // reduction, channel), 73 | h_sigmoid(), 74 | ) 75 | 76 | def forward(self, x): 77 | b, c, _, _ = x.size() 78 | y = self.avg_pool(x).view(b, c) 79 | y = self.fc(y).view(b, c, 1, 1) 80 | return x * y 81 | 82 | class Conv2dBlock(nn.Module): 83 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 84 | padding=0, norm='none', activation='relu', pad_type='zero'): 85 | super(Conv2dBlock, self).__init__() 86 | self.use_bias = True 87 | # initialize padding 88 | if pad_type == 'reflect': 89 | self.pad = nn.ReflectionPad2d(padding) 90 | elif pad_type == 'replicate': 91 | self.pad = nn.ReplicationPad2d(padding) 92 | elif pad_type == 'zero': 93 | self.pad = nn.ZeroPad2d(padding) 94 | else: 95 | assert 0, "Unsupported padding type: {}".format(pad_type) 96 | 97 | # initialize normalization 98 | norm_dim = output_dim 99 | if norm == 'BN': 100 | self.norm = nn.BatchNorm2d(norm_dim) 101 | elif norm == 'IN': 102 | #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) 103 | self.norm = nn.InstanceNorm2d(norm_dim) 104 | elif norm == 'ln': 105 | self.norm = LayerNorm(norm_dim) 106 | elif norm == 'adain': 107 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 108 | elif norm == 'none' or norm == 'sn': 109 | self.norm = None 110 | else: 111 | assert 0, "Unsupported normalization: {}".format(norm) 112 | 113 | # initialize activation 114 | if activation == 'relu': 115 | self.activation = nn.ReLU(inplace=True) 116 | elif activation == 'lrelu': 117 | self.activation = nn.LeakyReLU(0.2, inplace=True) 118 | elif activation == 'prelu': 119 | self.activation = nn.PReLU() 120 | elif activation == 'selu': 121 | self.activation = nn.SELU(inplace=True) 122 | 
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.activation:
            x = self.activation(x)
        if self.norm:
            x = self.norm(x)

        return x


class StyleEncoder(nn.Module):
    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for i in range(2):
            self.model += [Conv2dBlock(dim, 2 * dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        for i in range(n_downsample - 2):
            self.model += [Conv2dBlock(dim, 2 * dim, 3, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # self.model += [nn.AdaptiveAvgPool2d(1)]  # global average pooling
        # self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*self.model)
        self.se = SELayer(512)
        self.output_dim = dim

    def forward(self, x):
        x = self.model(x)
        out = self.se(x)
        return out


class UNetTileSENorm(nn.Module):
    def __init__(self, opt):
        super(UNetTileSENorm, self).__init__()

        self.opt = opt
        p = 1
        # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
        self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_1 = Normalization(opt, 32)

        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_2 = Normalization(opt, 32)
        self.max_pool1 = Pooling(opt, 2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_1 = Normalization(opt, 64)

        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_2 = Normalization(opt, 64)
        self.max_pool2 = Pooling(opt, 2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_1 = Normalization(opt, 128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_2 = Normalization(opt, 128)
        self.max_pool3 = Pooling(opt, 2)

        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_1 = Normalization(opt, 256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_2 = Normalization(opt, 256)
        self.max_pool4 = Pooling(opt, 2)

        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn5_1 = Normalization(opt, 512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn5_2 = Normalization(opt, 512)

        self.se_1 = SELayer(512)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The same encoder extracts the style code of the reference and of the input itself.
        self.style_encoder = StyleEncoder(4, 3, 32, 512, norm=opt.style_norm, activ="relu", pad_type='reflect')
        self.content_style_conv = nn.Conv2d(1024, 512, 3, padding=p)

        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn6_1 = Normalization(opt, 256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn6_2 = Normalization(opt, 256)

        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_1 = Normalization(opt, 128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_2 = Normalization(opt, 128)

        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_1 = Normalization(opt, 64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_2 = Normalization(opt, 64)

        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn9_1 = Normalization(opt, 32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)

    def forward(self, input, ref=None):
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)

        # Style code: taken from the reference image if given, otherwise from the input itself.
        if ref is not None:
            ref_h = self.style_encoder(ref)
            content_style = self.style_encoder(input)
        else:
            ref_h = self.style_encoder(input)
            content_style = ref_h

        # Encoder: four downsampling stages; conv1..conv4 are kept as skip connections.
        x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
        conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
        x = self.max_pool1(conv1)

        x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
        conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
        x = self.max_pool2(conv2)

        x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
        conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
        x = self.max_pool3(conv3)

        x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
        conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
        x = self.max_pool4(conv4)

        x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
        h = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))
        content = self.se_1(h)

        # Pool the reference style to a 512-d vector, broadcast it over the bottleneck,
        # and fuse it with the content features before decoding.
        style = self.avg_pool(ref_h)
        _, _, h, w = content.size()
        ref_h = style.repeat(1, 1, h, w).view(-1, 512, h, w)
        content_style_cat = torch.cat([content, ref_h], 1)
        content_style_cat = self.content_style_conv(content_style_cat)
        # F.upsample is deprecated in recent PyTorch; F.interpolate is the equivalent call.
        content_style_cat = F.upsample(content_style_cat, scale_factor=2, mode='bilinear')
        up6 = torch.cat([self.deconv5(content_style_cat), conv4], 1)
        x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
        conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

        conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
        up7 = torch.cat([self.deconv6(conv6), conv3], 1)
        x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
        conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

        conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
        up8 = torch.cat([self.deconv7(conv7), conv2], 1)
        x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
        conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

        conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
        up9 = torch.cat([self.deconv8(conv8), conv1], 1)
        x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
        conv9 = self.LReLU9_2(self.conv9_2(x))

        latent = self.conv10(conv9)

        output = latent
        target_style = self.avg_pool(self.style_encoder(output))
        # target_style = target_style.repeat(1, 1, h, w).view(-1, 512, h, w)
        content_style = self.avg_pool(content_style)
        # content_style = content_style.repeat(1, 1, h, w).view(-1, 512, h, w)
        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        return output, content, [style, content_style, target_style]


if __name__ == '__main__':
    # Smoke test. The original called UNet(None), but no UNet class exists in this file
    # and a None options object would fail as soon as opt.norm is accessed; build a
    # minimal stand-in instead (the 'BN' values are assumptions, not the demo's settings).
    class _Opt:
        norm = 'BN'
        style_norm = 'BN'

    model = UNetTileSENorm(_Opt())
    model = model.cuda()
    model.train()
    tensor = torch.randn(1, 3, 512, 512).cuda()
    output = model(tensor)
    print("finished")
--------------------------------------------------------------------------------
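Usage note: a minimal inference sketch for the network above. The checkpoint path, image filenames, and the OPT attribute values are placeholders (the real settings come from the demo notebook's OPT object); the [-1, 1] normalization mirrors the demo's np2tensor helper.

import torch
import numpy as np
from PIL import Image
from unet_tile_se_norm import UNetTileSENorm

class OPT:
    norm = 'BN'          # assumed; set to match the trained checkpoint
    style_norm = 'BN'    # assumed

def to_tensor(img):
    # HWC uint8 image -> 1xCxHxW float tensor in [-1, 1]
    arr = np.transpose(np.asarray(img), (2, 0, 1)).copy()
    return torch.from_numpy(arr).float().div(255.).mul(2.).sub(1.).unsqueeze(0)

model = UNetTileSENorm(OPT())
state = torch.load('model.pth', map_location='cpu')  # hypothetical checkpoint file
model.load_state_dict(state)
model.eval()

composite = to_tensor(Image.open('composite.png').convert('RGB'))  # hypothetical inputs
reference = to_tensor(Image.open('reference.png').convert('RGB'))

with torch.no_grad():
    # forward returns (harmonized image, bottleneck content, [style, content_style, target_style])
    output, content, styles = model(composite, ref=reference)

result = ((output[0].clamp(-1, 1) + 1) / 2 * 255).permute(1, 2, 0).byte().cpu().numpy()
Image.fromarray(result).save('harmonized.png')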