├── .gitignore
├── README.md
├── dataset
│   ├── mnist
│   │   └── __init__.py
│   └── svhn
│       └── __init__.py
├── extra
│   ├── functional.py
│   ├── network.jpg
│   ├── result.jpg
│   └── transforms.py
├── main.py
├── model.py
├── models
│   └── __init__.py
├── rec_image.py
├── recovery_image
│   └── __init__.py
└── test.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## This is a PyTorch implementation of the model [Deep Reconstruction-Classification Network for Unsupervised Domain Adaptation (DRCN)](https://arxiv.org/abs/1607.03516).
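DRCN trains a label classifier and an image reconstructor that share the same convolutional encoder, so features learned for classifying the labeled source domain (SVHN here) are also constrained to reconstruct images well. The sketch below is only a minimal illustration of how such a combined objective can be weighted; it is not the training loop of this repository. The `model(x) -> (logits, reconstruction)` interface is a hypothetical stand-in for the real forward pass defined in `model.py`, and `main.py` in fact sets up two separate RMSprop optimizers for the classification and reconstruction branches rather than a single summed loss.

```python
import torch.nn as nn

# Assumed, illustrative interface: model(x) returns (class_logits, reconstruction).
loss_class = nn.CrossEntropyLoss()  # label prediction on labeled (source) images
loss_rec = nn.MSELoss()             # pixel-wise reconstruction
m_lambda = 0.7                      # trade-off weight, same value as in main.py

def drcn_objective(model, src_img, src_label, rec_img):
    src_logits, _ = model(src_img)   # classify labeled source images
    _, rec_out = model(rec_img)      # reconstruct images (target images in the DRCN paper)
    l_class = loss_class(src_logits, src_label)
    l_rec = loss_rec(rec_out, rec_img)
    # classification dominates; reconstruction regularizes the shared encoder
    return m_lambda * l_class + (1 - m_lambda) * l_rec
```

Alternating between these two signals, as the two optimizers in `main.py` suggest, pushes the shared encoder toward features that transfer from SVHN to MNIST.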
2 | 3 | ## Environment 4 | 5 | - PyTorch 0.4.0 6 | - Python 2.7 7 | 8 | ## Structure 9 | 10 | ![DRCN](./extra/network.jpg) 11 | 12 | ## Usage 13 | 14 | - put the MNIST and SVHN data in the corresponding subdirectories of `dataset` 15 | - if there is no Grayscale transform in your torchvision, replace your `functional.py` and `transforms.py` 16 | with the files provided in `extra` 17 | - run `python main.py` for training 18 | - the trained model will be saved in `models`, and the reconstructed images will be saved in `recovery_image` 19 | - in our implementation, no denoising is included 20 | 21 | ## Result 22 | 23 | ![real svhn](./extra/result.jpg) 24 | 25 | **Real and Recovered SVHN images** 26 | -------------------------------------------------------------------------------- /dataset/mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DRCN/ab4fafd7a58ade83e42b33849720d268ba19a701/dataset/mnist/__init__.py -------------------------------------------------------------------------------- /dataset/svhn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DRCN/ab4fafd7a58ade83e42b33849720d268ba19a701/dataset/svhn/__init__.py -------------------------------------------------------------------------------- /extra/functional.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | 16 | 17 | def _is_pil_image(img): 18 | if accimage is not None: 19 | return isinstance(img, (Image.Image, accimage.Image)) 20 | else: 21 | return isinstance(img, Image.Image) 22 | 23 | 24 | def _is_tensor_image(img): 25 | return torch.is_tensor(img) and img.ndimension() == 3 26 | 27 | 28 | def _is_numpy_image(img): 29 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 30 | 31 | 32 | def to_tensor(pic): 33 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 34 | 35 | See ``ToTensor`` for more details. 36 | 37 | Args: 38 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 39 | 40 | Returns: 41 | Tensor: Converted image. 42 | """ 43 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 44 | raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 45 | 46 | if isinstance(pic, np.ndarray): 47 | # handle numpy array 48 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 49 | # backward compatibility 50 | if isinstance(img, torch.ByteTensor): 51 | return img.float().div(255) 52 | else: 53 | return img 54 | 55 | if accimage is not None and isinstance(pic, accimage.Image): 56 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) 57 | pic.copyto(nppic) 58 | return torch.from_numpy(nppic) 59 | 60 | # handle PIL Image 61 | if pic.mode == 'I': 62 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 63 | elif pic.mode == 'I;16': 64 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 65 | elif pic.mode == 'F': 66 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 67 | else: 68 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 69 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 70 | if pic.mode == 'YCbCr': 71 | nchannel = 3 72 | elif pic.mode == 'I;16': 73 | nchannel = 1 74 | else: 75 | nchannel = len(pic.mode) 76 | img = img.view(pic.size[1], pic.size[0], nchannel) 77 | # put it from HWC to CHW format 78 | # yikes, this transpose takes 80% of the loading time/CPU 79 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 80 | if isinstance(img, torch.ByteTensor): 81 | return img.float().div(255) 82 | else: 83 | return img 84 | 85 | 86 | def to_pil_image(pic, mode=None): 87 | """Convert a tensor or an ndarray to PIL Image. 88 | 89 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 90 | 91 | Args: 92 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 93 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 94 | 95 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 96 | 97 | Returns: 98 | PIL Image: Image converted to PIL Image. 99 | """ 100 | if not(_is_numpy_image(pic) or _is_tensor_image(pic)): 101 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 102 | 103 | npimg = pic 104 | if isinstance(pic, torch.FloatTensor): 105 | pic = pic.mul(255).byte() 106 | if torch.is_tensor(pic): 107 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 108 | 109 | if not isinstance(npimg, np.ndarray): 110 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 111 | 'not {}'.format(type(npimg))) 112 | 113 | if npimg.shape[2] == 1: 114 | expected_mode = None 115 | npimg = npimg[:, :, 0] 116 | if npimg.dtype == np.uint8: 117 | expected_mode = 'L' 118 | elif npimg.dtype == np.int16: 119 | expected_mode = 'I;16' 120 | elif npimg.dtype == np.int32: 121 | expected_mode = 'I' 122 | elif npimg.dtype == np.float32: 123 | expected_mode = 'F' 124 | if mode is not None and mode != expected_mode: 125 | raise ValueError("Incorrect mode ({}) supplied for input type {}. 
Should be {}" 126 | .format(mode, np.dtype, expected_mode)) 127 | mode = expected_mode 128 | 129 | elif npimg.shape[2] == 4: 130 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 131 | if mode is not None and mode not in permitted_4_channel_modes: 132 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 133 | 134 | if mode is None and npimg.dtype == np.uint8: 135 | mode = 'RGBA' 136 | else: 137 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 138 | if mode is not None and mode not in permitted_3_channel_modes: 139 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 140 | if mode is None and npimg.dtype == np.uint8: 141 | mode = 'RGB' 142 | 143 | if mode is None: 144 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 145 | 146 | return Image.fromarray(npimg, mode=mode) 147 | 148 | 149 | def normalize(tensor, mean, std): 150 | """Normalize a tensor image with mean and standard deviation. 151 | 152 | See ``Normalize`` for more details. 153 | 154 | Args: 155 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 156 | mean (sequence): Sequence of means for each channel. 157 | std (sequence): Sequence of standard deviations for each channely. 158 | 159 | Returns: 160 | Tensor: Normalized Tensor image. 161 | """ 162 | if not _is_tensor_image(tensor): 163 | raise TypeError('tensor is not a torch image.') 164 | # TODO: make efficient 165 | for t, m, s in zip(tensor, mean, std): 166 | t.sub_(m).div_(s) 167 | return tensor 168 | 169 | 170 | def resize(img, size, interpolation=Image.BILINEAR): 171 | """Resize the input PIL Image to the given size. 172 | 173 | Args: 174 | img (PIL Image): Image to be resized. 175 | size (sequence or int): Desired output size. If size is a sequence like 176 | (h, w), the output size will be matched to this. If size is an int, 177 | the smaller edge of the image will be matched to this number maintaing 178 | the aspect ratio. i.e, if height > width, then image will be rescaled to 179 | (size * height / width, size) 180 | interpolation (int, optional): Desired interpolation. Default is 181 | ``PIL.Image.BILINEAR`` 182 | 183 | Returns: 184 | PIL Image: Resized image. 185 | """ 186 | if not _is_pil_image(img): 187 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 188 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 189 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 190 | 191 | if isinstance(size, int): 192 | w, h = img.size 193 | if (w <= h and w == size) or (h <= w and h == size): 194 | return img 195 | if w < h: 196 | ow = size 197 | oh = int(size * h / w) 198 | return img.resize((ow, oh), interpolation) 199 | else: 200 | oh = size 201 | ow = int(size * w / h) 202 | return img.resize((ow, oh), interpolation) 203 | else: 204 | return img.resize(size[::-1], interpolation) 205 | 206 | 207 | def scale(*args, **kwargs): 208 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 209 | "please use transforms.Resize instead.") 210 | return resize(*args, **kwargs) 211 | 212 | 213 | def pad(img, padding, fill=0): 214 | """Pad the given PIL Image on all sides with the given "pad" value. 215 | 216 | Args: 217 | img (PIL Image): Image to be padded. 218 | padding (int or tuple): Padding on each border. If a single int is provided this 219 | is used to pad all borders. 
If tuple of length 2 is provided this is the padding 220 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 221 | this is the padding for the left, top, right and bottom borders 222 | respectively. 223 | fill: Pixel fill value. Default is 0. If a tuple of 224 | length 3, it is used to fill R, G, B channels respectively. 225 | 226 | Returns: 227 | PIL Image: Padded image. 228 | """ 229 | if not _is_pil_image(img): 230 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 231 | 232 | if not isinstance(padding, (numbers.Number, tuple)): 233 | raise TypeError('Got inappropriate padding arg') 234 | if not isinstance(fill, (numbers.Number, str, tuple)): 235 | raise TypeError('Got inappropriate fill arg') 236 | 237 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 238 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 239 | "{} element tuple".format(len(padding))) 240 | 241 | return ImageOps.expand(img, border=padding, fill=fill) 242 | 243 | 244 | def crop(img, i, j, h, w): 245 | """Crop the given PIL Image. 246 | 247 | Args: 248 | img (PIL Image): Image to be cropped. 249 | i: Upper pixel coordinate. 250 | j: Left pixel coordinate. 251 | h: Height of the cropped image. 252 | w: Width of the cropped image. 253 | 254 | Returns: 255 | PIL Image: Cropped image. 256 | """ 257 | if not _is_pil_image(img): 258 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 259 | 260 | return img.crop((j, i, j + w, i + h)) 261 | 262 | 263 | def center_crop(img, output_size): 264 | if isinstance(output_size, numbers.Number): 265 | output_size = (int(output_size), int(output_size)) 266 | w, h = img.size 267 | th, tw = output_size 268 | i = int(round((h - th) / 2.)) 269 | j = int(round((w - tw) / 2.)) 270 | return crop(img, i, j, th, tw) 271 | 272 | 273 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 274 | """Crop the given PIL Image and resize it to desired size. 275 | 276 | Notably used in RandomResizedCrop. 277 | 278 | Args: 279 | img (PIL Image): Image to be cropped. 280 | i: Upper pixel coordinate. 281 | j: Left pixel coordinate. 282 | h: Height of the cropped image. 283 | w: Width of the cropped image. 284 | size (sequence or int): Desired output size. Same semantics as ``scale``. 285 | interpolation (int, optional): Desired interpolation. Default is 286 | ``PIL.Image.BILINEAR``. 287 | Returns: 288 | PIL Image: Cropped image. 289 | """ 290 | assert _is_pil_image(img), 'img should be PIL Image' 291 | img = crop(img, i, j, h, w) 292 | img = resize(img, size, interpolation) 293 | return img 294 | 295 | 296 | def hflip(img): 297 | """Horizontally flip the given PIL Image. 298 | 299 | Args: 300 | img (PIL Image): Image to be flipped. 301 | 302 | Returns: 303 | PIL Image: Horizontall flipped image. 304 | """ 305 | if not _is_pil_image(img): 306 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 307 | 308 | return img.transpose(Image.FLIP_LEFT_RIGHT) 309 | 310 | 311 | def vflip(img): 312 | """Vertically flip the given PIL Image. 313 | 314 | Args: 315 | img (PIL Image): Image to be flipped. 316 | 317 | Returns: 318 | PIL Image: Vertically flipped image. 319 | """ 320 | if not _is_pil_image(img): 321 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 322 | 323 | return img.transpose(Image.FLIP_TOP_BOTTOM) 324 | 325 | 326 | def five_crop(img, size): 327 | """Crop the given PIL Image into four corners and the central crop. 328 | 329 | .. 
Note:: 330 | This transform returns a tuple of images and there may be a 331 | mismatch in the number of inputs and targets your ``Dataset`` returns. 332 | 333 | Args: 334 | size (sequence or int): Desired output size of the crop. If size is an 335 | int instead of sequence like (h, w), a square crop (size, size) is 336 | made. 337 | Returns: 338 | tuple: tuple (tl, tr, bl, br, center) corresponding top left, 339 | top right, bottom left, bottom right and center crop. 340 | """ 341 | if isinstance(size, numbers.Number): 342 | size = (int(size), int(size)) 343 | else: 344 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 345 | 346 | w, h = img.size 347 | crop_h, crop_w = size 348 | if crop_w > w or crop_h > h: 349 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 350 | (h, w))) 351 | tl = img.crop((0, 0, crop_w, crop_h)) 352 | tr = img.crop((w - crop_w, 0, w, crop_h)) 353 | bl = img.crop((0, h - crop_h, crop_w, h)) 354 | br = img.crop((w - crop_w, h - crop_h, w, h)) 355 | center = center_crop(img, (crop_h, crop_w)) 356 | return (tl, tr, bl, br, center) 357 | 358 | 359 | def ten_crop(img, size, vertical_flip=False): 360 | """Crop the given PIL Image into four corners and the central crop plus the 361 | flipped version of these (horizontal flipping is used by default). 362 | 363 | .. Note:: 364 | This transform returns a tuple of images and there may be a 365 | mismatch in the number of inputs and targets your ``Dataset`` returns. 366 | 367 | Args: 368 | size (sequence or int): Desired output size of the crop. If size is an 369 | int instead of sequence like (h, w), a square crop (size, size) is 370 | made. 371 | vertical_flip (bool): Use vertical flipping instead of horizontal 372 | 373 | Returns: 374 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, 375 | br_flip, center_flip) corresponding top left, top right, 376 | bottom left, bottom right and center crop and same for the 377 | flipped image. 378 | """ 379 | if isinstance(size, numbers.Number): 380 | size = (int(size), int(size)) 381 | else: 382 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 383 | 384 | first_five = five_crop(img, size) 385 | 386 | if vertical_flip: 387 | img = vflip(img) 388 | else: 389 | img = hflip(img) 390 | 391 | second_five = five_crop(img, size) 392 | return first_five + second_five 393 | 394 | 395 | def adjust_brightness(img, brightness_factor): 396 | """Adjust brightness of an Image. 397 | 398 | Args: 399 | img (PIL Image): PIL Image to be adjusted. 400 | brightness_factor (float): How much to adjust the brightness. Can be 401 | any non negative number. 0 gives a black image, 1 gives the 402 | original image while 2 increases the brightness by a factor of 2. 403 | 404 | Returns: 405 | PIL Image: Brightness adjusted image. 406 | """ 407 | if not _is_pil_image(img): 408 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 409 | 410 | enhancer = ImageEnhance.Brightness(img) 411 | img = enhancer.enhance(brightness_factor) 412 | return img 413 | 414 | 415 | def adjust_contrast(img, contrast_factor): 416 | """Adjust contrast of an Image. 417 | 418 | Args: 419 | img (PIL Image): PIL Image to be adjusted. 420 | contrast_factor (float): How much to adjust the contrast. Can be any 421 | non negative number. 0 gives a solid gray image, 1 gives the 422 | original image while 2 increases the contrast by a factor of 2. 423 | 424 | Returns: 425 | PIL Image: Contrast adjusted image. 
426 | """ 427 | if not _is_pil_image(img): 428 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 429 | 430 | enhancer = ImageEnhance.Contrast(img) 431 | img = enhancer.enhance(contrast_factor) 432 | return img 433 | 434 | 435 | def adjust_saturation(img, saturation_factor): 436 | """Adjust color saturation of an image. 437 | 438 | Args: 439 | img (PIL Image): PIL Image to be adjusted. 440 | saturation_factor (float): How much to adjust the saturation. 0 will 441 | give a black and white image, 1 will give the original image while 442 | 2 will enhance the saturation by a factor of 2. 443 | 444 | Returns: 445 | PIL Image: Saturation adjusted image. 446 | """ 447 | if not _is_pil_image(img): 448 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 449 | 450 | enhancer = ImageEnhance.Color(img) 451 | img = enhancer.enhance(saturation_factor) 452 | return img 453 | 454 | 455 | def adjust_hue(img, hue_factor): 456 | """Adjust hue of an image. 457 | 458 | The image hue is adjusted by converting the image to HSV and 459 | cyclically shifting the intensities in the hue channel (H). 460 | The image is then converted back to original image mode. 461 | 462 | `hue_factor` is the amount of shift in H channel and must be in the 463 | interval `[-0.5, 0.5]`. 464 | 465 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 466 | 467 | Args: 468 | img (PIL Image): PIL Image to be adjusted. 469 | hue_factor (float): How much to shift the hue channel. Should be in 470 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 471 | HSV space in positive and negative direction respectively. 472 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 473 | with complementary colors while 0 gives the original image. 474 | 475 | Returns: 476 | PIL Image: Hue adjusted image. 477 | """ 478 | if not(-0.5 <= hue_factor <= 0.5): 479 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 480 | 481 | if not _is_pil_image(img): 482 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 483 | 484 | input_mode = img.mode 485 | if input_mode in {'L', '1', 'I', 'F'}: 486 | return img 487 | 488 | h, s, v = img.convert('HSV').split() 489 | 490 | np_h = np.array(h, dtype=np.uint8) 491 | # uint8 addition take cares of rotation across boundaries 492 | with np.errstate(over='ignore'): 493 | np_h += np.uint8(hue_factor * 255) 494 | h = Image.fromarray(np_h, 'L') 495 | 496 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 497 | return img 498 | 499 | 500 | def adjust_gamma(img, gamma, gain=1): 501 | """Perform gamma correction on an image. 502 | 503 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 504 | based on the following equation: 505 | 506 | I_out = 255 * gain * ((I_in / 255) ** gamma) 507 | 508 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 509 | 510 | Args: 511 | img (PIL Image): PIL Image to be adjusted. 512 | gamma (float): Non negative real number. gamma larger than 1 make the 513 | shadows darker, while gamma smaller than 1 make dark regions 514 | lighter. 515 | gain (float): The constant multiplier. 516 | """ 517 | if not _is_pil_image(img): 518 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 519 | 520 | if gamma < 0: 521 | raise ValueError('Gamma should be a non-negative real number') 522 | 523 | input_mode = img.mode 524 | img = img.convert('RGB') 525 | 526 | gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 527 | img = img.point(gamma_map) # use PIL's point-function to accelerate this part 528 | 529 | img = img.convert(input_mode) 530 | return img 531 | 532 | 533 | def rotate(img, angle, resample=False, expand=False, center=None): 534 | """Rotate the image by angle. 535 | 536 | 537 | Args: 538 | img (PIL Image): PIL Image to be rotated. 539 | angle ({float, int}): In degrees degrees counter clockwise order. 540 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 541 | An optional resampling filter. 542 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 543 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 544 | expand (bool, optional): Optional expansion flag. 545 | If true, expands the output image to make it large enough to hold the entire rotated image. 546 | If false or omitted, make the output image the same size as the input image. 547 | Note that the expand flag assumes rotation around the center and no translation. 548 | center (2-tuple, optional): Optional center of rotation. 549 | Origin is the upper left corner. 550 | Default is the center of the image. 551 | """ 552 | 553 | if not _is_pil_image(img): 554 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 555 | 556 | return img.rotate(angle, resample, expand, center) 557 | 558 | 559 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear): 560 | # Helper method to compute inverse matrix for affine transformation 561 | 562 | # As it is explained in PIL.Image.rotate 563 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 564 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 565 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 566 | # RSS is rotation with scale and shear matrix 567 | # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0] 568 | # [ sin(a)*scale cos(a + shear)*scale 0] 569 | # [ 0 0 1] 570 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 571 | 572 | angle = math.radians(angle) 573 | shear = math.radians(shear) 574 | scale = 1.0 / scale 575 | 576 | # Inverted rotation matrix with scale and shear 577 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) 578 | matrix = [ 579 | math.cos(angle + shear), math.sin(angle + shear), 0, 580 | -math.sin(angle), math.cos(angle), 0 581 | ] 582 | matrix = [scale / d * m for m in matrix] 583 | 584 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 585 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) 586 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) 587 | 588 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 589 | matrix[2] += center[0] 590 | matrix[5] += center[1] 591 | return matrix 592 | 593 | 594 | def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): 595 | """Apply affine transformation on the image keeping image center invariant 596 | 597 | Args: 598 | img (PIL Image): PIL Image to be rotated. 599 | angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction. 
600 | translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) 601 | scale (float): overall scale 602 | shear (float): shear angle value in degrees between -180 to 180, clockwise direction. 603 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 604 | An optional resampling filter. 605 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 606 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 607 | fillcolor (int): Optional fill color for the area outside the transform in the output image. 608 | """ 609 | if not _is_pil_image(img): 610 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 611 | 612 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 613 | "Argument translate should be a list or tuple of length 2" 614 | 615 | assert scale > 0.0, "Argument scale should be positive" 616 | 617 | output_size = img.size 618 | center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) 619 | matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) 620 | return img.transform(output_size, Image.AFFINE, matrix, resample, fillcolor=fillcolor) 621 | 622 | 623 | def to_grayscale(img, num_output_channels=1): 624 | """Convert image to grayscale version of image. 625 | 626 | Args: 627 | img (PIL Image): Image to be converted to grayscale. 628 | 629 | Returns: 630 | PIL Image: Grayscale version of the image. 631 | if num_output_channels == 1 : returned image is single channel 632 | if num_output_channels == 3 : returned image is 3 channel with r == g == b 633 | """ 634 | if not _is_pil_image(img): 635 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 636 | 637 | if num_output_channels == 1: 638 | img = img.convert('L') 639 | elif num_output_channels == 3: 640 | img = img.convert('L') 641 | np_img = np.array(img, dtype=np.uint8) 642 | np_img = np.dstack([np_img, np_img, np_img]) 643 | img = Image.fromarray(np_img, 'RGB') 644 | else: 645 | raise ValueError('num_output_channels should be either 1 or 3') 646 | 647 | return img 648 | -------------------------------------------------------------------------------- /extra/network.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DRCN/ab4fafd7a58ade83e42b33849720d268ba19a701/extra/network.jpg -------------------------------------------------------------------------------- /extra/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fungtion/DRCN/ab4fafd7a58ade83e42b33849720d268ba19a701/extra/result.jpg -------------------------------------------------------------------------------- /extra/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | 16 | from . 
import functional as F 17 | 18 | __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", 19 | "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", 20 | "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", 21 | "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"] 22 | 23 | _pil_interpolation_to_str = { 24 | Image.NEAREST: 'PIL.Image.NEAREST', 25 | Image.BILINEAR: 'PIL.Image.BILINEAR', 26 | Image.BICUBIC: 'PIL.Image.BICUBIC', 27 | Image.LANCZOS: 'PIL.Image.LANCZOS', 28 | } 29 | 30 | 31 | class Compose(object): 32 | """Composes several transforms together. 33 | 34 | Args: 35 | transforms (list of ``Transform`` objects): list of transforms to compose. 36 | 37 | Example: 38 | >>> transforms.Compose([ 39 | >>> transforms.CenterCrop(10), 40 | >>> transforms.ToTensor(), 41 | >>> ]) 42 | """ 43 | 44 | def __init__(self, transforms): 45 | self.transforms = transforms 46 | 47 | def __call__(self, img): 48 | for t in self.transforms: 49 | img = t(img) 50 | return img 51 | 52 | def __repr__(self): 53 | format_string = self.__class__.__name__ + '(' 54 | for t in self.transforms: 55 | format_string += '\n' 56 | format_string += ' {0}'.format(t) 57 | format_string += '\n)' 58 | return format_string 59 | 60 | 61 | class ToTensor(object): 62 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 63 | 64 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 65 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 66 | """ 67 | 68 | def __call__(self, pic): 69 | """ 70 | Args: 71 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 72 | 73 | Returns: 74 | Tensor: Converted image. 75 | """ 76 | return F.to_tensor(pic) 77 | 78 | def __repr__(self): 79 | return self.__class__.__name__ + '()' 80 | 81 | 82 | class ToPILImage(object): 83 | """Convert a tensor or an ndarray to PIL Image. 84 | 85 | Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape 86 | H x W x C to a PIL Image while preserving the value range. 87 | 88 | Args: 89 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 90 | If ``mode`` is ``None`` (default) there are some assumptions made about the input data: 91 | 1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. 92 | 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 93 | 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, 94 | ``int``, ``float``, ``short``). 95 | 96 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 97 | """ 98 | def __init__(self, mode=None): 99 | self.mode = mode 100 | 101 | def __call__(self, pic): 102 | """ 103 | Args: 104 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 105 | 106 | Returns: 107 | PIL Image: Image converted to PIL Image. 108 | 109 | """ 110 | return F.to_pil_image(pic, self.mode) 111 | 112 | def __repr__(self): 113 | format_string = self.__class__.__name__ + '(' 114 | if self.mode is not None: 115 | format_string += 'mode={0}'.format(self.mode) 116 | format_string += ')' 117 | return format_string 118 | 119 | 120 | class Normalize(object): 121 | """Normalize a tensor image with mean and standard deviation. 
122 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 123 | will normalize each channel of the input ``torch.*Tensor`` i.e. 124 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 125 | 126 | Args: 127 | mean (sequence): Sequence of means for each channel. 128 | std (sequence): Sequence of standard deviations for each channel. 129 | """ 130 | 131 | def __init__(self, mean, std): 132 | self.mean = mean 133 | self.std = std 134 | 135 | def __call__(self, tensor): 136 | """ 137 | Args: 138 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 139 | 140 | Returns: 141 | Tensor: Normalized Tensor image. 142 | """ 143 | return F.normalize(tensor, self.mean, self.std) 144 | 145 | def __repr__(self): 146 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 147 | 148 | 149 | class Resize(object): 150 | """Resize the input PIL Image to the given size. 151 | 152 | Args: 153 | size (sequence or int): Desired output size. If size is a sequence like 154 | (h, w), output size will be matched to this. If size is an int, 155 | smaller edge of the image will be matched to this number. 156 | i.e, if height > width, then image will be rescaled to 157 | (size * height / width, size) 158 | interpolation (int, optional): Desired interpolation. Default is 159 | ``PIL.Image.BILINEAR`` 160 | """ 161 | 162 | def __init__(self, size, interpolation=Image.BILINEAR): 163 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 164 | self.size = size 165 | self.interpolation = interpolation 166 | 167 | def __call__(self, img): 168 | """ 169 | Args: 170 | img (PIL Image): Image to be scaled. 171 | 172 | Returns: 173 | PIL Image: Rescaled image. 174 | """ 175 | return F.resize(img, self.size, self.interpolation) 176 | 177 | def __repr__(self): 178 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 179 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 180 | 181 | 182 | class Scale(Resize): 183 | """ 184 | Note: This transform is deprecated in favor of Resize. 185 | """ 186 | def __init__(self, *args, **kwargs): 187 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 188 | "please use transforms.Resize instead.") 189 | super(Scale, self).__init__(*args, **kwargs) 190 | 191 | 192 | class CenterCrop(object): 193 | """Crops the given PIL Image at the center. 194 | 195 | Args: 196 | size (sequence or int): Desired output size of the crop. If size is an 197 | int instead of sequence like (h, w), a square crop (size, size) is 198 | made. 199 | """ 200 | 201 | def __init__(self, size): 202 | if isinstance(size, numbers.Number): 203 | self.size = (int(size), int(size)) 204 | else: 205 | self.size = size 206 | 207 | def __call__(self, img): 208 | """ 209 | Args: 210 | img (PIL Image): Image to be cropped. 211 | 212 | Returns: 213 | PIL Image: Cropped image. 214 | """ 215 | return F.center_crop(img, self.size) 216 | 217 | def __repr__(self): 218 | return self.__class__.__name__ + '(size={0})'.format(self.size) 219 | 220 | 221 | class Pad(object): 222 | """Pad the given PIL Image on all sides with the given "pad" value. 223 | 224 | Args: 225 | padding (int or tuple): Padding on each border. If a single int is provided this 226 | is used to pad all borders. If tuple of length 2 is provided this is the padding 227 | on left/right and top/bottom respectively. 
If a tuple of length 4 is provided 228 | this is the padding for the left, top, right and bottom borders 229 | respectively. 230 | fill: Pixel fill value. Default is 0. If a tuple of 231 | length 3, it is used to fill R, G, B channels respectively. 232 | """ 233 | 234 | def __init__(self, padding, fill=0): 235 | assert isinstance(padding, (numbers.Number, tuple)) 236 | assert isinstance(fill, (numbers.Number, str, tuple)) 237 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 238 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 239 | "{} element tuple".format(len(padding))) 240 | 241 | self.padding = padding 242 | self.fill = fill 243 | 244 | def __call__(self, img): 245 | """ 246 | Args: 247 | img (PIL Image): Image to be padded. 248 | 249 | Returns: 250 | PIL Image: Padded image. 251 | """ 252 | return F.pad(img, self.padding, self.fill) 253 | 254 | def __repr__(self): 255 | return self.__class__.__name__ + '(padding={0}, fill={1})'.format(self.padding, self.fill) 256 | 257 | 258 | class Lambda(object): 259 | """Apply a user-defined lambda as a transform. 260 | 261 | Args: 262 | lambd (function): Lambda/function to be used for transform. 263 | """ 264 | 265 | def __init__(self, lambd): 266 | assert isinstance(lambd, types.LambdaType) 267 | self.lambd = lambd 268 | 269 | def __call__(self, img): 270 | return self.lambd(img) 271 | 272 | def __repr__(self): 273 | return self.__class__.__name__ + '()' 274 | 275 | 276 | class RandomTransforms(object): 277 | """Base class for a list of transformations with randomness 278 | 279 | Args: 280 | transforms (list or tuple): list of transformations 281 | """ 282 | 283 | def __init__(self, transforms): 284 | assert isinstance(transforms, (list, tuple)) 285 | self.transforms = transforms 286 | 287 | def __call__(self, *args, **kwargs): 288 | raise NotImplementedError() 289 | 290 | def __repr__(self): 291 | format_string = self.__class__.__name__ + '(' 292 | for t in self.transforms: 293 | format_string += '\n' 294 | format_string += ' {0}'.format(t) 295 | format_string += '\n)' 296 | return format_string 297 | 298 | 299 | class RandomApply(RandomTransforms): 300 | """Apply randomly a list of transformations with a given probability 301 | 302 | Args: 303 | transforms (list or tuple): list of transformations 304 | p (float): probability 305 | """ 306 | 307 | def __init__(self, transforms, p=0.5): 308 | super(RandomApply, self).__init__(transforms) 309 | self.p = p 310 | 311 | def __call__(self, img): 312 | if self.p < random.random(): 313 | return img 314 | for t in self.transforms: 315 | img = t(img) 316 | return img 317 | 318 | def __repr__(self): 319 | format_string = self.__class__.__name__ + '(' 320 | format_string += '\n p={}'.format(self.p) 321 | for t in self.transforms: 322 | format_string += '\n' 323 | format_string += ' {0}'.format(t) 324 | format_string += '\n)' 325 | return format_string 326 | 327 | 328 | class RandomOrder(RandomTransforms): 329 | """Apply a list of transformations in a random order 330 | """ 331 | def __call__(self, img): 332 | order = list(range(len(self.transforms))) 333 | random.shuffle(order) 334 | for i in order: 335 | img = self.transforms[i](img) 336 | return img 337 | 338 | 339 | class RandomChoice(RandomTransforms): 340 | """Apply single transformation randomly picked from a list 341 | """ 342 | def __call__(self, img): 343 | t = random.choice(self.transforms) 344 | return t(img) 345 | 346 | 347 | class RandomCrop(object): 348 | """Crop the given PIL Image 
at a random location. 349 | 350 | Args: 351 | size (sequence or int): Desired output size of the crop. If size is an 352 | int instead of sequence like (h, w), a square crop (size, size) is 353 | made. 354 | padding (int or sequence, optional): Optional padding on each border 355 | of the image. Default is 0, i.e no padding. If a sequence of length 356 | 4 is provided, it is used to pad left, top, right, bottom borders 357 | respectively. 358 | """ 359 | 360 | def __init__(self, size, padding=0): 361 | if isinstance(size, numbers.Number): 362 | self.size = (int(size), int(size)) 363 | else: 364 | self.size = size 365 | self.padding = padding 366 | 367 | @staticmethod 368 | def get_params(img, output_size): 369 | """Get parameters for ``crop`` for a random crop. 370 | 371 | Args: 372 | img (PIL Image): Image to be cropped. 373 | output_size (tuple): Expected output size of the crop. 374 | 375 | Returns: 376 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 377 | """ 378 | w, h = img.size 379 | th, tw = output_size 380 | if w == tw and h == th: 381 | return 0, 0, h, w 382 | 383 | i = random.randint(0, h - th) 384 | j = random.randint(0, w - tw) 385 | return i, j, th, tw 386 | 387 | def __call__(self, img): 388 | """ 389 | Args: 390 | img (PIL Image): Image to be cropped. 391 | 392 | Returns: 393 | PIL Image: Cropped image. 394 | """ 395 | if self.padding > 0: 396 | img = F.pad(img, self.padding) 397 | 398 | i, j, h, w = self.get_params(img, self.size) 399 | 400 | return F.crop(img, i, j, h, w) 401 | 402 | def __repr__(self): 403 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 404 | 405 | 406 | class RandomHorizontalFlip(object): 407 | """Horizontally flip the given PIL Image randomly with a given probability. 408 | 409 | Args: 410 | p (float): probability of the image being flipped. Default value is 0.5 411 | """ 412 | 413 | def __init__(self, p=0.5): 414 | self.p = p 415 | 416 | def __call__(self, img): 417 | """ 418 | Args: 419 | img (PIL Image): Image to be flipped. 420 | 421 | Returns: 422 | PIL Image: Randomly flipped image. 423 | """ 424 | if random.random() < self.p: 425 | return F.hflip(img) 426 | return img 427 | 428 | def __repr__(self): 429 | return self.__class__.__name__ + '(p={})'.format(self.p) 430 | 431 | 432 | class RandomVerticalFlip(object): 433 | """Vertically flip the given PIL Image randomly with a given probability. 434 | 435 | Args: 436 | p (float): probability of the image being flipped. Default value is 0.5 437 | """ 438 | 439 | def __init__(self, p=0.5): 440 | self.p = p 441 | 442 | def __call__(self, img): 443 | """ 444 | Args: 445 | img (PIL Image): Image to be flipped. 446 | 447 | Returns: 448 | PIL Image: Randomly flipped image. 449 | """ 450 | if random.random() < self.p: 451 | return F.vflip(img) 452 | return img 453 | 454 | def __repr__(self): 455 | return self.__class__.__name__ + '(p={})'.format(self.p) 456 | 457 | 458 | class RandomResizedCrop(object): 459 | """Crop the given PIL Image to random size and aspect ratio. 460 | 461 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 462 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 463 | is finally resized to given size. 464 | This is popularly used to train the Inception networks. 
465 | 466 | Args: 467 | size: expected output size of each edge 468 | scale: range of size of the origin size cropped 469 | ratio: range of aspect ratio of the origin aspect ratio cropped 470 | interpolation: Default: PIL.Image.BILINEAR 471 | """ 472 | 473 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 474 | self.size = (size, size) 475 | self.interpolation = interpolation 476 | self.scale = scale 477 | self.ratio = ratio 478 | 479 | @staticmethod 480 | def get_params(img, scale, ratio): 481 | """Get parameters for ``crop`` for a random sized crop. 482 | 483 | Args: 484 | img (PIL Image): Image to be cropped. 485 | scale (tuple): range of size of the origin size cropped 486 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 487 | 488 | Returns: 489 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 490 | sized crop. 491 | """ 492 | for attempt in range(10): 493 | area = img.size[0] * img.size[1] 494 | target_area = random.uniform(*scale) * area 495 | aspect_ratio = random.uniform(*ratio) 496 | 497 | w = int(round(math.sqrt(target_area * aspect_ratio))) 498 | h = int(round(math.sqrt(target_area / aspect_ratio))) 499 | 500 | if random.random() < 0.5: 501 | w, h = h, w 502 | 503 | if w <= img.size[0] and h <= img.size[1]: 504 | i = random.randint(0, img.size[1] - h) 505 | j = random.randint(0, img.size[0] - w) 506 | return i, j, h, w 507 | 508 | # Fallback 509 | w = min(img.size[0], img.size[1]) 510 | i = (img.size[1] - w) // 2 511 | j = (img.size[0] - w) // 2 512 | return i, j, w, w 513 | 514 | def __call__(self, img): 515 | """ 516 | Args: 517 | img (PIL Image): Image to be cropped and resized. 518 | 519 | Returns: 520 | PIL Image: Randomly cropped and resized image. 521 | """ 522 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 523 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 524 | 525 | def __repr__(self): 526 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 527 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 528 | format_string += ', scale={0}'.format(round(self.scale, 4)) 529 | format_string += ', ratio={0}'.format(round(self.ratio, 4)) 530 | format_string += ', interpolation={0})'.format(interpolate_str) 531 | return format_string 532 | 533 | 534 | class RandomSizedCrop(RandomResizedCrop): 535 | """ 536 | Note: This transform is deprecated in favor of RandomResizedCrop. 537 | """ 538 | def __init__(self, *args, **kwargs): 539 | warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + 540 | "please use transforms.RandomResizedCrop instead.") 541 | super(RandomSizedCrop, self).__init__(*args, **kwargs) 542 | 543 | 544 | class FiveCrop(object): 545 | """Crop the given PIL Image into four corners and the central crop 546 | 547 | .. Note:: 548 | This transform returns a tuple of images and there may be a mismatch in the number of 549 | inputs and targets your Dataset returns. See below for an example of how to deal with 550 | this. 551 | 552 | Args: 553 | size (sequence or int): Desired output size of the crop. If size is an ``int`` 554 | instead of sequence like (h, w), a square crop of size (size, size) is made. 
555 | 556 | Example: 557 | >>> transform = Compose([ 558 | >>> FiveCrop(size), # this is a list of PIL Images 559 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 560 | >>> ]) 561 | >>> #In your test loop you can do the following: 562 | >>> input, target = batch # input is a 5d tensor, target is 2d 563 | >>> bs, ncrops, c, h, w = input.size() 564 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 565 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 566 | """ 567 | 568 | def __init__(self, size): 569 | self.size = size 570 | if isinstance(size, numbers.Number): 571 | self.size = (int(size), int(size)) 572 | else: 573 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 574 | self.size = size 575 | 576 | def __call__(self, img): 577 | return F.five_crop(img, self.size) 578 | 579 | def __repr__(self): 580 | return self.__class__.__name__ + '(size={0})'.format(self.size) 581 | 582 | 583 | class TenCrop(object): 584 | """Crop the given PIL Image into four corners and the central crop plus the flipped version of 585 | these (horizontal flipping is used by default) 586 | 587 | .. Note:: 588 | This transform returns a tuple of images and there may be a mismatch in the number of 589 | inputs and targets your Dataset returns. See below for an example of how to deal with 590 | this. 591 | 592 | Args: 593 | size (sequence or int): Desired output size of the crop. If size is an 594 | int instead of sequence like (h, w), a square crop (size, size) is 595 | made. 596 | vertical_flip(bool): Use vertical flipping instead of horizontal 597 | 598 | Example: 599 | >>> transform = Compose([ 600 | >>> TenCrop(size), # this is a list of PIL Images 601 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 602 | >>> ]) 603 | >>> #In your test loop you can do the following: 604 | >>> input, target = batch # input is a 5d tensor, target is 2d 605 | >>> bs, ncrops, c, h, w = input.size() 606 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 607 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 608 | """ 609 | 610 | def __init__(self, size, vertical_flip=False): 611 | self.size = size 612 | if isinstance(size, numbers.Number): 613 | self.size = (int(size), int(size)) 614 | else: 615 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 616 | self.size = size 617 | self.vertical_flip = vertical_flip 618 | 619 | def __call__(self, img): 620 | return F.ten_crop(img, self.size, self.vertical_flip) 621 | 622 | def __repr__(self): 623 | return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) 624 | 625 | 626 | class LinearTransformation(object): 627 | """Transform a tensor image with a square transformation matrix computed 628 | offline. 629 | 630 | Given transformation_matrix, will flatten the torch.*Tensor, compute the dot 631 | product with the transformation matrix and reshape the tensor to its 632 | original shape. 633 | 634 | Applications: 635 | - whitening: zero-center the data, compute the data covariance matrix 636 | [D x D] with np.dot(X.T, X), perform SVD on this matrix and 637 | pass it as transformation_matrix. 
638 | 639 | Args: 640 | transformation_matrix (Tensor): tensor [D x D], D = C x H x W 641 | """ 642 | 643 | def __init__(self, transformation_matrix): 644 | if transformation_matrix.size(0) != transformation_matrix.size(1): 645 | raise ValueError("transformation_matrix should be square. Got " + 646 | "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) 647 | self.transformation_matrix = transformation_matrix 648 | 649 | def __call__(self, tensor): 650 | """ 651 | Args: 652 | tensor (Tensor): Tensor image of size (C, H, W) to be whitened. 653 | 654 | Returns: 655 | Tensor: Transformed image. 656 | """ 657 | if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): 658 | raise ValueError("tensor and transformation matrix have incompatible shape." + 659 | "[{} x {} x {}] != ".format(*tensor.size()) + 660 | "{}".format(self.transformation_matrix.size(0))) 661 | flat_tensor = tensor.view(1, -1) 662 | transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) 663 | tensor = transformed_tensor.view(tensor.size()) 664 | return tensor 665 | 666 | def __repr__(self): 667 | format_string = self.__class__.__name__ + '(' 668 | format_string += (str(self.transformation_matrix.numpy().tolist()) + ')') 669 | return format_string 670 | 671 | 672 | class ColorJitter(object): 673 | """Randomly change the brightness, contrast and saturation of an image. 674 | 675 | Args: 676 | brightness (float): How much to jitter brightness. brightness_factor 677 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 678 | contrast (float): How much to jitter contrast. contrast_factor 679 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 680 | saturation (float): How much to jitter saturation. saturation_factor 681 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 682 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 683 | [-hue, hue]. Should be >=0 and <= 0.5. 684 | """ 685 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 686 | self.brightness = brightness 687 | self.contrast = contrast 688 | self.saturation = saturation 689 | self.hue = hue 690 | 691 | @staticmethod 692 | def get_params(brightness, contrast, saturation, hue): 693 | """Get a randomized transform to be applied on image. 694 | 695 | Arguments are same as that of __init__. 696 | 697 | Returns: 698 | Transform which randomly adjusts brightness, contrast and 699 | saturation in a random order. 700 | """ 701 | transforms = [] 702 | if brightness > 0: 703 | brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) 704 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 705 | 706 | if contrast > 0: 707 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) 708 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 709 | 710 | if saturation > 0: 711 | saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) 712 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 713 | 714 | if hue > 0: 715 | hue_factor = random.uniform(-hue, hue) 716 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 717 | 718 | random.shuffle(transforms) 719 | transform = Compose(transforms) 720 | 721 | return transform 722 | 723 | def __call__(self, img): 724 | """ 725 | Args: 726 | img (PIL Image): Input image. 727 | 728 | Returns: 729 | PIL Image: Color jittered image. 
730 | """ 731 | transform = self.get_params(self.brightness, self.contrast, 732 | self.saturation, self.hue) 733 | return transform(img) 734 | 735 | def __repr__(self): 736 | format_string = self.__class__.__name__ + '(' 737 | format_string += 'brightness={0}'.format(self.brightness) 738 | format_string += ', contrast={0}'.format(self.contrast) 739 | format_string += ', saturation={0}'.format(self.saturation) 740 | format_string += ', hue={0})'.format(self.hue) 741 | return format_string 742 | 743 | 744 | class RandomRotation(object): 745 | """Rotate the image by angle. 746 | 747 | Args: 748 | degrees (sequence or float or int): Range of degrees to select from. 749 | If degrees is a number instead of sequence like (min, max), the range of degrees 750 | will be (-degrees, +degrees). 751 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 752 | An optional resampling filter. 753 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 754 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 755 | expand (bool, optional): Optional expansion flag. 756 | If true, expands the output to make it large enough to hold the entire rotated image. 757 | If false or omitted, make the output image the same size as the input image. 758 | Note that the expand flag assumes rotation around the center and no translation. 759 | center (2-tuple, optional): Optional center of rotation. 760 | Origin is the upper left corner. 761 | Default is the center of the image. 762 | """ 763 | 764 | def __init__(self, degrees, resample=False, expand=False, center=None): 765 | if isinstance(degrees, numbers.Number): 766 | if degrees < 0: 767 | raise ValueError("If degrees is a single number, it must be positive.") 768 | self.degrees = (-degrees, degrees) 769 | else: 770 | if len(degrees) != 2: 771 | raise ValueError("If degrees is a sequence, it must be of len 2.") 772 | self.degrees = degrees 773 | 774 | self.resample = resample 775 | self.expand = expand 776 | self.center = center 777 | 778 | @staticmethod 779 | def get_params(degrees): 780 | """Get parameters for ``rotate`` for a random rotation. 781 | 782 | Returns: 783 | sequence: params to be passed to ``rotate`` for random rotation. 784 | """ 785 | angle = random.uniform(degrees[0], degrees[1]) 786 | 787 | return angle 788 | 789 | def __call__(self, img): 790 | """ 791 | img (PIL Image): Image to be rotated. 792 | 793 | Returns: 794 | PIL Image: Rotated image. 795 | """ 796 | 797 | angle = self.get_params(self.degrees) 798 | 799 | return F.rotate(img, angle, self.resample, self.expand, self.center) 800 | 801 | def __repr__(self): 802 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) 803 | format_string += ', resample={0}'.format(self.resample) 804 | format_string += ', expand={0}'.format(self.expand) 805 | if self.center is not None: 806 | format_string += ', center={0}'.format(self.center) 807 | format_string += ')' 808 | return format_string 809 | 810 | 811 | class RandomAffine(object): 812 | """Random affine transformation of the image keeping center invariant 813 | 814 | Args: 815 | degrees (sequence or float or int): Range of degrees to select from. 816 | If degrees is a number instead of sequence like (min, max), the range of degrees 817 | will be (-degrees, +degrees). Set to 0 to desactivate rotations. 818 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 819 | and vertical translations. 
For example translate=(a, b), then horizontal shift 820 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 821 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 822 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 823 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 824 | shear (sequence or float or int, optional): Range of degrees to select from. 825 | If degrees is a number instead of sequence like (min, max), the range of degrees 826 | will be (-degrees, +degrees). Will not apply shear by default 827 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 828 | An optional resampling filter. 829 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 830 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 831 | fillcolor (int): Optional fill color for the area outside the transform in the output image. 832 | """ 833 | 834 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 835 | if isinstance(degrees, numbers.Number): 836 | if degrees < 0: 837 | raise ValueError("If degrees is a single number, it must be positive.") 838 | self.degrees = (-degrees, degrees) 839 | else: 840 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 841 | "degrees should be a list or tuple and it must be of length 2." 842 | self.degrees = degrees 843 | 844 | if translate is not None: 845 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 846 | "translate should be a list or tuple and it must be of length 2." 847 | for t in translate: 848 | if not (0.0 <= t <= 1.0): 849 | raise ValueError("translation values should be between 0 and 1") 850 | self.translate = translate 851 | 852 | if scale is not None: 853 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 854 | "scale should be a list or tuple and it must be of length 2." 855 | for s in scale: 856 | if s <= 0: 857 | raise ValueError("scale values should be positive") 858 | self.scale = scale 859 | 860 | if shear is not None: 861 | if isinstance(shear, numbers.Number): 862 | if shear < 0: 863 | raise ValueError("If shear is a single number, it must be positive.") 864 | self.shear = (-shear, shear) 865 | else: 866 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 867 | "shear should be a list or tuple and it must be of length 2." 
868 | self.shear = shear 869 | else: 870 | self.shear = shear 871 | 872 | self.resample = resample 873 | self.fillcolor = fillcolor 874 | 875 | @staticmethod 876 | def get_params(degrees, translate, scale_ranges, shears, img_size): 877 | """Get parameters for affine transformation 878 | 879 | Returns: 880 | sequence: params to be passed to the affine transformation 881 | """ 882 | angle = random.uniform(degrees[0], degrees[1]) 883 | if translate is not None: 884 | max_dx = translate[0] * img_size[0] 885 | max_dy = translate[1] * img_size[1] 886 | translations = (np.round(random.uniform(-max_dx, max_dx)), 887 | np.round(random.uniform(-max_dy, max_dy))) 888 | else: 889 | translations = (0, 0) 890 | 891 | if scale_ranges is not None: 892 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 893 | else: 894 | scale = 1.0 895 | 896 | if shears is not None: 897 | shear = random.uniform(shears[0], shears[1]) 898 | else: 899 | shear = 0.0 900 | 901 | return angle, translations, scale, shear 902 | 903 | def __call__(self, img): 904 | """ 905 | img (PIL Image): Image to be transformed. 906 | 907 | Returns: 908 | PIL Image: Affine transformed image. 909 | """ 910 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 911 | return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) 912 | 913 | def __repr__(self): 914 | s = '{name}(degrees={degrees}' 915 | if self.translate is not None: 916 | s += ', translate={translate}' 917 | if self.scale is not None: 918 | s += ', scale={scale}' 919 | if self.shear is not None: 920 | s += ', shear={shear}' 921 | if self.resample > 0: 922 | s += ', resample={resample}' 923 | if self.fillcolor != 0: 924 | s += ', fillcolor={fillcolor}' 925 | s += ')' 926 | d = dict(self.__dict__) 927 | d['resample'] = _pil_interpolation_to_str[d['resample']] 928 | return s.format(name=self.__class__.__name__, **d) 929 | 930 | 931 | class Grayscale(object): 932 | """Convert image to grayscale. 933 | 934 | Args: 935 | num_output_channels (int): (1 or 3) number of channels desired for output image 936 | 937 | Returns: 938 | PIL Image: Grayscale version of the input. 939 | - If num_output_channels == 1 : returned image is single channel 940 | - If num_output_channels == 3 : returned image is 3 channel with r == g == b 941 | 942 | """ 943 | 944 | def __init__(self, num_output_channels=1): 945 | self.num_output_channels = num_output_channels 946 | 947 | def __call__(self, img): 948 | """ 949 | Args: 950 | img (PIL Image): Image to be converted to grayscale. 951 | 952 | Returns: 953 | PIL Image: Randomly grayscaled image. 954 | """ 955 | return F.to_grayscale(img, num_output_channels=self.num_output_channels) 956 | 957 | def __repr__(self): 958 | return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) 959 | 960 | 961 | class RandomGrayscale(object): 962 | """Randomly convert image to grayscale with a probability of p (default 0.1). 963 | 964 | Args: 965 | p (float): probability that image should be converted to grayscale. 966 | 967 | Returns: 968 | PIL Image: Grayscale version of the input image with probability p and unchanged 969 | with probability (1-p). 970 | - If input image is 1 channel: grayscale version is 1 channel 971 | - If input image is 3 channel: grayscale version is 3 channel with r == g == b 972 | 973 | """ 974 | 975 | def __init__(self, p=0.1): 976 | self.p = p 977 | 978 | def __call__(self, img): 979 | """ 980 | Args: 981 | img (PIL Image): Image to be converted to grayscale. 
982 | 
983 |         Returns:
984 |             PIL Image: Randomly grayscaled image.
985 |         """
986 |         num_output_channels = 1 if img.mode == 'L' else 3
987 |         if random.random() < self.p:
988 |             return F.to_grayscale(img, num_output_channels=num_output_channels)
989 |         return img
990 | 
991 |     def __repr__(self):
992 |         return self.__class__.__name__ + '(p={0})'.format(self.p)
993 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import torch.backends.cudnn as cudnn
4 | import torch.optim as optim
5 | import torch.utils.data
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torchvision import datasets
9 | from torchvision import transforms
10 | from model import DRCN
11 | from test import test
12 | from torchvision.utils import save_image
13 | from rec_image import rec_image
14 | 
15 | source_dataset_name = 'SVHN'
16 | target_dataset_name = 'mnist'
17 | source_dataset = os.path.join('.', 'dataset', 'svhn')
18 | target_dataset = os.path.join('.', 'dataset', 'mnist')
19 | model_root = 'models'  # directory to save trained models
20 | cuda = True
21 | cudnn.benchmark = True
22 | lr = 1e-4
23 | batch_size = 64
24 | image_size = 32
25 | n_epoch = 100
26 | weight_decay = 5e-6
27 | m_lambda = 0.7
28 | 
29 | 
30 | def weights_init(m):
31 |     if isinstance(m, nn.Conv2d):
32 |         nn.init.xavier_uniform(m.weight.data, gain=1)
33 |         nn.init.constant(m.bias.data, 0.1)
34 | 
35 | manual_seed = random.randint(1, 10000)
36 | random.seed(manual_seed)
37 | torch.manual_seed(manual_seed)
38 | 
39 | # load data
40 | img_transform_svhn = transforms.Compose([
41 |     transforms.Grayscale(),
42 |     transforms.RandomRotation(20),
43 |     transforms.ToTensor()
44 | ])
45 | 
46 | img_transform_mnist = transforms.Compose([
47 |     transforms.Resize(image_size),
48 |     transforms.RandomRotation(20),
49 |     transforms.ToTensor()
50 | ])
51 | 
52 | dataset_source = datasets.SVHN(
53 |     root=source_dataset,
54 |     split='train',
55 |     transform=img_transform_svhn,
56 | )
57 | 
58 | datasetloader_source = torch.utils.data.DataLoader(
59 |     dataset=dataset_source,
60 |     batch_size=batch_size,
61 |     shuffle=True,
62 |     num_workers=8
63 | )
64 | 
65 | dataset_target = datasets.MNIST(
66 |     root=target_dataset,
67 |     train=True,
68 |     transform=img_transform_mnist,
69 | )
70 | 
71 | datasetloader_target = torch.utils.data.DataLoader(
72 |     dataset=dataset_target,
73 |     batch_size=batch_size,
74 |     shuffle=True,
75 |     num_workers=8
76 | )
77 | 
78 | # load models
79 | my_net = DRCN(n_class=10)
80 | my_net.apply(weights_init)
81 | 
82 | # setup optimizer
83 | optimizer_classify = optim.RMSprop([{'params': my_net.enc_feat.parameters()},
84 |                                     {'params': my_net.enc_dense.parameters()},
85 |                                     {'params': my_net.pred.parameters()}], lr=lr, weight_decay=weight_decay)
86 | 
87 | optimizer_rec = optim.RMSprop([{'params': my_net.enc_feat.parameters()},
88 |                                {'params': my_net.enc_dense.parameters()},
89 |                                {'params': my_net.rec_dense.parameters()},
90 |                                {'params': my_net.rec_feat.parameters()}], lr=lr, weight_decay=weight_decay)
91 | 
92 | loss_class = nn.CrossEntropyLoss()
93 | loss_rec = nn.MSELoss()
94 | 
95 | if cuda:
96 |     my_net = my_net.cuda()
97 |     loss_class = loss_class.cuda()
98 |     loss_rec = loss_rec.cuda()
99 | 
100 | for p in my_net.parameters():
101 |     p.requires_grad = True
102 | 
103 | len_source = len(datasetloader_source)
104 | len_target = len(datasetloader_target)
105 | 
106 | # training
107 | for epoch in xrange(n_epoch):
108 | 
109 |     # train reconstruction
110 |     dataset_target_iter = iter(datasetloader_target)
111 | 
112 |     i = 0
113 | 
114 |     while i < len_target:
115 |         my_net.zero_grad()
116 | 
117 |         data_target = dataset_target_iter.next()
118 |         t_img, _ = data_target
119 | 
120 |         batch_size = len(t_img)
121 | 
122 |         input_img = torch.FloatTensor(batch_size, 1, image_size, image_size)
123 | 
124 |         if cuda:
125 |             t_img = t_img.cuda()
126 |             input_img = input_img.cuda()
127 | 
128 |         input_img.resize_as_(t_img).copy_(t_img)
129 |         inputv_img = Variable(input_img)
130 | 
131 |         _, rec_img = my_net(input_data=inputv_img)
132 |         save_image(rec_img.data, './recovery_image/mnist_rec' + str(epoch) + '.png', nrow=8)
133 | 
134 |         rec_img = rec_img.view(-1, 1 * image_size * image_size)
135 |         inputv_img_img = inputv_img.contiguous().view(-1, 1 * image_size * image_size)
136 |         err_rec = (1 - m_lambda) * loss_rec(rec_img, inputv_img_img)  # reconstruction loss on flattened images
137 |         err_rec.backward()
138 |         optimizer_rec.step()
139 | 
140 |         i += 1
141 | 
142 |     print 'epoch: %d, err_rec: %f' \
143 |           % (epoch, err_rec.cpu().data.numpy())
144 | 
145 |     # training label classifier
146 | 
147 |     dataset_source_iter = iter(datasetloader_source)
148 | 
149 |     i = 0
150 | 
151 |     while i < len_source:
152 |         my_net.zero_grad()
153 | 
154 |         data_source = dataset_source_iter.next()
155 |         s_img, s_label = data_source
156 |         s_label = s_label.long().squeeze()
157 | 
158 |         batch_size = len(s_label)
159 | 
160 |         input_img = torch.FloatTensor(batch_size, 1, image_size, image_size)
161 |         class_label = torch.LongTensor(batch_size)
162 |         if cuda:
163 |             s_img = s_img.cuda()
164 |             s_label = s_label.cuda()
165 |             input_img = input_img.cuda()
166 |             class_label = class_label.cuda()
167 | 
168 |         input_img.resize_as_(s_img).copy_(s_img)
169 |         class_label.resize_as_(s_label).copy_(s_label)
170 |         inputv_img = Variable(input_img)
171 |         classv_label = Variable(class_label)
172 | 
173 |         pred_label, _ = my_net(input_data=inputv_img)
174 |         err_class = m_lambda * loss_class(pred_label, classv_label)
175 |         err_class.backward()
176 |         optimizer_classify.step()
177 | 
178 |         i += 1
179 | 
180 |     print 'epoch: %d, err_class: %f' \
181 |           % (epoch, err_class.cpu().data.numpy())
182 | 
183 |     torch.save(my_net, '{0}/svhn_mnist_model_epoch_{1}.pth'.format(model_root, epoch))
184 | 
185 |     rec_image(epoch)
186 |     test(epoch)
187 | 
188 | print 'done'
189 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class DRCN(nn.Module):
6 |     def __init__(self, n_class):
7 |         super(DRCN, self).__init__()
8 | 
9 |         # convolutional encoder
10 | 
11 |         self.enc_feat = nn.Sequential()
12 |         self.enc_feat.add_module('conv1', nn.Conv2d(in_channels=1, out_channels=100, kernel_size=5,
13 |                                                     padding=2))
14 |         self.enc_feat.add_module('relu1', nn.ReLU(True))
15 |         self.enc_feat.add_module('pool1', nn.MaxPool2d(kernel_size=2, stride=2))
16 | 
17 |         self.enc_feat.add_module('conv2', nn.Conv2d(in_channels=100, out_channels=150, kernel_size=5,
18 |                                                     padding=2))
19 |         self.enc_feat.add_module('relu2', nn.ReLU(True))
20 |         self.enc_feat.add_module('pool2', nn.MaxPool2d(kernel_size=2, stride=2))
21 | 
22 |         self.enc_feat.add_module('conv3', nn.Conv2d(in_channels=150, out_channels=200, kernel_size=3,
23 |                                                     padding=1))
24 |         self.enc_feat.add_module('relu3', nn.ReLU(True))
25 | 
26 |         self.enc_dense = nn.Sequential()
27 |         self.enc_dense.add_module('fc4', nn.Linear(in_features=200 * 8 * 8, out_features=1024))
28 |         self.enc_dense.add_module('relu4', nn.ReLU(True))
29 |         self.enc_dense.add_module('drop4', nn.Dropout2d())
30 | 
31 |         self.enc_dense.add_module('fc5', nn.Linear(in_features=1024, out_features=1024))
32 |         self.enc_dense.add_module('relu5', nn.ReLU(True))
33 | 
34 |         # label predict layer
35 |         self.pred = nn.Sequential()
36 |         self.pred.add_module('dropout6', nn.Dropout2d())
37 |         self.pred.add_module('predict6', nn.Linear(in_features=1024, out_features=n_class))
38 | 
39 |         # convolutional decoder
40 | 
41 |         self.rec_dense = nn.Sequential()
42 |         self.rec_dense.add_module('fc5_', nn.Linear(in_features=1024, out_features=1024))
43 |         self.rec_dense.add_module('relu5_', nn.ReLU(True))
44 | 
45 |         self.rec_dense.add_module('fc4_', nn.Linear(in_features=1024, out_features=200 * 8 * 8))
46 |         self.rec_dense.add_module('relu4_', nn.ReLU(True))
47 | 
48 |         self.rec_feat = nn.Sequential()
49 | 
50 |         self.rec_feat.add_module('conv3_', nn.Conv2d(in_channels=200, out_channels=150,
51 |                                                      kernel_size=3, padding=1))
52 |         self.rec_feat.add_module('relu3_', nn.ReLU(True))
53 |         self.rec_feat.add_module('pool3_', nn.Upsample(scale_factor=2))
54 | 
55 |         self.rec_feat.add_module('conv2_', nn.Conv2d(in_channels=150, out_channels=100,
56 |                                                      kernel_size=5, padding=2))
57 |         self.rec_feat.add_module('relu2_', nn.ReLU(True))
58 |         self.rec_feat.add_module('pool2_', nn.Upsample(scale_factor=2))
59 | 
60 |         self.rec_feat.add_module('conv1_', nn.Conv2d(in_channels=100, out_channels=1,
61 |                                                      kernel_size=5, padding=2))
62 | 
63 |     def forward(self, input_data):
64 |         feat = self.enc_feat(input_data)
65 |         feat = feat.view(-1, 200 * 8 * 8)
66 |         feat_code = self.enc_dense(feat)
67 | 
68 |         pred_label = self.pred(feat_code)
69 | 
70 |         feat_encode = self.rec_dense(feat_code)
71 |         feat_encode = feat_encode.view(-1, 200, 8, 8)
72 |         img_rec = self.rec_feat(feat_encode)
73 | 
74 |         return pred_label, img_rec
75 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fungtion/DRCN/ab4fafd7a58ade83e42b33849720d268ba19a701/models/__init__.py
--------------------------------------------------------------------------------
/rec_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.backends.cudnn as cudnn
3 | import torch.utils.data
4 | import torchvision.utils as vutils
5 | from torch.autograd import Variable
6 | from torchvision import transforms
7 | from torchvision import datasets
8 | 
9 | 
10 | def rec_image(epoch):
11 | 
12 |     model_root = 'models'
13 |     image_root = os.path.join('dataset', 'svhn')
14 | 
15 |     cuda = True
16 |     cudnn.benchmark = True
17 |     batch_size = 64
18 |     image_size = 32
19 | 
20 |     # load data
21 |     img_transform = transforms.Compose([
22 |         transforms.Grayscale(),
23 |         transforms.ToTensor()
24 |     ])
25 | 
26 |     dataset = datasets.SVHN(
27 |         root=image_root,
28 |         split='test',
29 |         transform=img_transform
30 |     )
31 | 
32 |     data_loader = torch.utils.data.DataLoader(
33 |         dataset=dataset,
34 |         batch_size=batch_size,
35 |         shuffle=False,
36 |         num_workers=8
37 |     )
38 | 
39 |     # test
40 |     my_net = torch.load(os.path.join(
41 |         model_root, 'svhn_mnist_model_epoch_' + str(epoch) + '.pth')
42 |     )
43 | 
44 |     my_net = my_net.eval()
45 |     if cuda:
46 |         my_net = my_net.cuda()
47 | 
48 |     data_iter = iter(data_loader)
49 |     data = data_iter.next()
50 |     img, _ = data
51 | 
52 |     batch_size = len(img)
53 | 
54 |     input_img = torch.FloatTensor(batch_size, 1, image_size, image_size)
55 | 
56 |     if cuda:
57 |         img = img.cuda()
58 |         input_img = input_img.cuda()
59 | 
60 |     input_img.resize_as_(img).copy_(img)
61 |     inputv_img = Variable(input_img)
62 | 
63 |     _, rec_img = my_net(input_data=inputv_img)
64 | 
65 |     vutils.save_image(input_img, './recovery_image/svhn_real_epoch_' + str(epoch) + '.png', nrow=8)
66 |     vutils.save_image(rec_img.data, './recovery_image/svhn_rec_' + str(epoch) + '.png', nrow=8)
67 | 
68 | 
--------------------------------------------------------------------------------
/recovery_image/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fungtion/DRCN/ab4fafd7a58ade83e42b33849720d268ba19a701/recovery_image/__init__.py
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.backends.cudnn as cudnn
3 | import torch.utils.data
4 | from torch.autograd import Variable
5 | from torchvision import transforms
6 | from torchvision import datasets
7 | 
8 | 
9 | def test(epoch):
10 | 
11 |     model_root = 'models'
12 |     image_root = os.path.join('dataset', 'mnist')
13 | 
14 |     cuda = True
15 |     cudnn.benchmark = True
16 |     batch_size = 64
17 |     image_size = 32
18 | 
19 |     # load data
20 |     img_transform = transforms.Compose([
21 |         transforms.Resize(image_size),
22 |         transforms.ToTensor()
23 |     ])
24 | 
25 |     dataset = datasets.MNIST(
26 |         root=image_root,
27 |         train=False,
28 |         transform=img_transform
29 |     )
30 | 
31 |     data_loader = torch.utils.data.DataLoader(
32 |         dataset=dataset,
33 |         batch_size=batch_size,
34 |         shuffle=False,
35 |         num_workers=8
36 |     )
37 | 
38 |     # test
39 |     my_net = torch.load(os.path.join(
40 |         model_root, 'svhn_mnist_model_epoch_' + str(epoch) + '.pth')
41 |     )
42 | 
43 |     my_net = my_net.eval()
44 |     if cuda:
45 |         my_net = my_net.cuda()
46 | 
47 |     len_dataloader = len(data_loader)
48 |     data_iter = iter(data_loader)
49 | 
50 |     i = 0
51 |     n_total = 0
52 |     n_correct = 0
53 | 
54 |     while i < len_dataloader:
55 | 
56 |         data = data_iter.next()
57 |         img, label = data
58 | 
59 |         batch_size = len(label)
60 | 
61 |         input_img = torch.FloatTensor(batch_size, 1, image_size, image_size)
62 |         class_label = torch.LongTensor(batch_size)
63 | 
64 |         if cuda:
65 |             img = img.cuda()
66 |             label = label.cuda()
67 |             input_img = input_img.cuda()
68 |             class_label = class_label.cuda()
69 | 
70 |         input_img.resize_as_(img).copy_(img)
71 |         class_label.resize_as_(label).copy_(label)
72 |         inputv_img = Variable(input_img)
73 |         classv_label = Variable(class_label)
74 | 
75 |         pred_label, _ = my_net(input_data=inputv_img)
76 |         pred = pred_label.data.max(1, keepdim=True)[1]
77 |         n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
78 |         n_total += batch_size
79 | 
80 |         i += 1
81 | 
82 |     accu = n_correct * 1.0 / n_total
83 | 
84 |     print 'epoch: %d, accuracy: %f' %(epoch, accu)
85 | 
--------------------------------------------------------------------------------
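
As a quick sanity check of the `DRCN` module defined in `model.py` above, the short sketch below (illustrative only, not a file in the repository) pushes a dummy batch through the network and prints the shapes of its two outputs. It assumes `model.py` is importable from the working directory and uses the same 32x32 grayscale input size as `main.py`.

```python
import torch

from model import DRCN

net = DRCN(n_class=10)
net.eval()  # disable the dropout layers for a deterministic forward pass

# dummy batch of four 32x32 grayscale images, matching image_size in main.py
x = torch.randn(4, 1, 32, 32)

pred_label, rec_img = net(input_data=x)

print(pred_label.size())   # torch.Size([4, 10])        -- class logits
print(rec_img.size())      # torch.Size([4, 1, 32, 32]) -- reconstructed images
```

During training, `main.py` optimizes the two heads in alternating passes per epoch: the reconstruction loss (weighted by `1 - m_lambda`) is computed on flattened views of the unlabeled target-domain (MNIST) batches, while the classification loss (weighted by `m_lambda`) is computed on labeled source-domain (SVHN) batches.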