├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── README.md ├── code └── +p2p │ ├── +data │ ├── PairedImageDatastore.m │ └── transformImagePair.m │ ├── +networks │ ├── block.m │ ├── discriminator.m │ ├── downBlock.m │ ├── generator.m │ ├── instanceNormalizationLayer.m │ └── upBlock.m │ ├── +util │ ├── AdamOptimiser.m │ └── downloadFacades.m │ ├── +vis │ └── TrainingPlot.m │ ├── train.m │ ├── trainingOptions.m │ └── translate.m ├── docs ├── getting_started.mlx ├── labels.png ├── output.jpg ├── target.jpg └── training.gif ├── facades.rights ├── install.m ├── runAllTests.m ├── tests ├── +tests │ ├── DatastoreTests.m │ ├── DiscriminatorTest.m │ ├── GeneratorTest.m │ ├── InstanceNormTest.m │ ├── TrainingPlotTest.m │ ├── TrainingTest.m │ └── WithWorkingDirectory.m ├── resources │ ├── badImagePairs │ │ ├── A │ │ │ ├── cmp_b0003 - Copy.jpg │ │ │ ├── cmp_b0003.jpg │ │ │ ├── cmp_b0005.jpg │ │ │ └── cmp_b0011.jpg │ │ └── B │ │ │ ├── cmp_b0003.jpg │ │ │ ├── cmp_b0005.jpg │ │ │ └── cmp_b0011.jpg │ └── imagePairs │ │ ├── A │ │ ├── cmp_b0003.jpg │ │ ├── cmp_b0005.jpg │ │ └── cmp_b0011.jpg │ │ └── B │ │ ├── cmp_b0003.jpg │ │ ├── cmp_b0005.jpg │ │ └── cmp_b0011.jpg └── testRoot.m └── uninstall.m /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | machine: 5 | image: ubuntu-1604:201903-01 6 | steps: 7 | - checkout 8 | - run: wget -qO- --retry-connrefused https://storage.googleapis.com/matlabimagesus/public/install.sh | sudo -E bash 9 | - run: matlab -batch "runAllTests" 10 | - store_test_results: 11 | path: artifacts 12 | - store_artifacts: 13 | path: artifacts/junit 14 | - run: bash <(curl -s https://codecov.io/bash) -f artifacts/coverage/codeCoverage.xml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | scratch/ 3 | datasets/ 4 | 5 | 
*.asv 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, The MathWorks, Inc. 2 | All rights reserved. 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings. 7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
8 | 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pix2pix - Image to Image Translation Using Generative Adversarial Networks 2 | 3 | ![circleci](https://circleci.com/gh/matlab-deep-learning/pix2pix.svg?style=svg) 4 | [![codecov](https://codecov.io/gh/matlab-deep-learning/pix2pix/branch/master/graph/badge.svg)](https://codecov.io/gh/matlab-deep-learning/pix2pix) 5 | 6 | This repository contains MATLAB code to implement the pix2pix image-to-image translation method described in the paper by [Isola et al. *Image-to-Image Translation with Conditional Adversarial Nets*](https://phillipi.github.io/pix2pix/). 7 | 8 | - [Before you begin](#before-you-begin) 9 | - [Getting started](#getting-started) 10 | - [Installation](#installation) 11 | - [Training a model](#training-a-model) 12 | - [Generating images](#generating-images) 13 | - [Any problems?](#any-problems) 14 | - [Finally](#finally) 15 | 16 | ## Before you begin 17 | 18 | Make sure you meet the following minimum requirements: 19 | 20 | - MATLAB R2019b or greater 21 | - Deep Learning Toolbox 22 | - Image Processing Toolbox 23 | 24 | ## Getting started 25 | 26 | ### Installation 27 | 28 | 29 | First, [clone](https://github.com/matlab-deep-learning/pix2pix.git) or [download](https://github.com/matlab-deep-learning/pix2pix/archive/master.zip) the repository to get a copy of the code. Then run `install.m` to add all required files to the MATLAB path. 30 | 31 | ```matlab 32 | install(); 33 | ``` 34 | 35 | ### Training a model 36 | 37 | To train a model you need many pairs of "before" and "after" images. The classic example is the [facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/), which contains label images of building fronts and the corresponding original photos.
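`p2p.data.PairedImageDatastore` pairs an A image with a B image purely by their position in each folder's file list, so before a long training run it can be worth checking that the two folders line up. The sketch below assumes one folder of inputs and one folder of targets whose files share base names (for example `cmp_b0003.png` in A pairing with `cmp_b0003.jpg` in B). `checkPairing` is a hypothetical helper, not part of this repository.

```matlab
function checkPairing(folderA, folderB)
% checkPairing  Hypothetical helper (not part of this repository).
% Errors if folderA and folderB do not contain images with the same base
% file names in the same order, which is the order that
% p2p.data.PairedImageDatastore relies on to pair images.
    imdsA = imageDatastore(folderA, 'IncludeSubfolders', true);
    imdsB = imageDatastore(folderB, 'IncludeSubfolders', true);

    % Strip folders and extensions so "A/x.png" can match "B/x.jpg"
    [~, namesA] = cellfun(@fileparts, imdsA.Files, 'UniformOutput', false);
    [~, namesB] = cellfun(@fileparts, imdsB.Files, 'UniformOutput', false);

    if numel(namesA) ~= numel(namesB)
        error('checkPairing:count', ...
            'Folder A has %d images but folder B has %d.', ...
            numel(namesA), numel(namesB));
    end

    % Report the first position where the names differ, if any
    mismatch = find(~strcmp(namesA, namesB), 1);
    if ~isempty(mismatch)
        error('checkPairing:name', 'Pair %d does not match: %s vs %s.', ...
            mismatch, namesA{mismatch}, namesB{mismatch});
    end
end
```

For example, calling `checkPairing(labelFolder, targetFolder)` once the folders are ready should raise an error if either folder contains an extra or misnamed file, rather than silently training on mismatched pairs.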
38 | 39 | ![](docs/labels.png)![](docs/target.jpg) 40 | 41 | Use the helper function `p2p.util.downloadFacades` to download and prepare the dataset for model training. Once that's ready you will have two folders: 'A', containing the input labels, and 'B', containing the desired output images. 42 | 43 | To train the model we need to provide the locations of the A and B images, as well as any training options. The model then tries to learn to convert A images into B images. 44 | 45 | ```matlab 46 | [labelFolder, targetFolder] = p2p.util.downloadFacades(); 47 | ``` 48 | 49 | We will use the default options, which approximately reproduce the settings from the original pix2pix paper. 50 | 51 | ```matlab 52 | options = p2p.trainingOptions(); 53 | p2pModel = p2p.train(labelFolder, targetFolder, options); 54 | ``` 55 | 56 | _Note that with the default options, training the model takes several hours on a GPU and requires around 6GB of memory._ 57 | 58 | ![](docs/training.gif) 59 | 60 | ### Generating images 61 | 62 | Once the model is trained, we can use the generator to generate a new image. 63 | 64 | ```matlab 65 | exampleInput = imread("docs/labels.png"); 66 | ``` 67 | 68 | We can then use the `p2p.translate` function to convert the input image using the trained model. (Note that the default generator expects an input image whose pixel dimensions are divisible by 256.) 69 | 70 | ```matlab 71 | exampleOutput = p2p.translate(p2pModel, exampleInput); 72 | imshowpair(exampleInput, exampleOutput, "montage"); 73 | ``` 74 | 75 | ![](docs/labels.png)![](docs/output.jpg) 76 | 77 | For an example you can run directly in MATLAB, see the [Getting Started](docs/getting_started.mlx) live script. 78 | 79 | ## Any problems? 80 | 81 | If you have any trouble using this code, want to report a bug, or want to request a feature, please use the [GitHub issues](https://github.com/matlab-deep-learning/pix2pix/issues).
82 | 83 | ## Finally 84 | 85 | This repository uses some images from the [facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/), used under the [CC BY-SA licence](facades.rights). 86 | 87 | _Copyright 2020 The MathWorks, Inc._ -------------------------------------------------------------------------------- /code/+p2p/+data/PairedImageDatastore.m: -------------------------------------------------------------------------------- 1 | classdef PairedImageDatastore < matlab.io.Datastore & ... 2 | matlab.io.datastore.Shuffleable & ... 3 | matlab.io.datastore.MiniBatchable 4 | % PairedImageDatastore A datastore to provide pairs of images. 5 | % 6 | % This datastore allows mini-batching and shuffling of matching pairs of 7 | % images in two folders, while preserving the pairing of images. 8 | 9 | % Copyright 2020 The MathWorks, Inc. 10 | 11 | properties (Dependent) 12 | MiniBatchSize 13 | end 14 | 15 | properties (SetAccess = protected) 16 | DirA 17 | DirB 18 | ImagesA 19 | ImagesB 20 | NumObservations 21 | MiniBatchSize_ 22 | Augmenter 23 | PreSize 24 | CropSize 25 | ARange 26 | BRange 27 | end 28 | 29 | methods (Static) 30 | function [inputs, remaining] = parseInputs(varargin) 31 | parser = inputParser(); 32 | % Remaining inputs are passed on to the imageDataAugmenter 33 | parser.KeepUnmatched = true; 34 | parser.addParameter('PreSize', [256, 256]); 35 | parser.addParameter('CropSize', [256, 256]); 36 | parser.addParameter('ARange', 255); 37 | parser.addParameter('BRange', 255); 38 | parser.parse(varargin{:}); 39 | inputs = parser.Results; 40 | remaining = parser.Unmatched; 41 | end 42 | end 43 | 44 | methods 45 | function obj = PairedImageDatastore(dirA, dirB, miniBatchSize, varargin) 46 | % Create a PairedImageDatastore 47 | % 48 | % Args: 49 | % dirA - directory or cell array of filenames 50 | % dirB - directory or cell array of filenames 51 | % miniBatchSize - Number of image pairs to provide in each 52 | % minibatch 53 | % Optional name-value pairs: PreSize, CropSize, ARange,
54 | % BRange; any others are passed to imageDataAugmenter 55 | % 56 | % Note: 57 | % This datastore relies on the image files in the two 58 | % directories appearing in the same order for correct 59 | % pairing. The simplest way to ensure this is to give both 60 | % images in each pair the same name. 61 | 62 | includeSubFolders = true; 63 | 64 | obj.DirA = dirA; 65 | obj.DirB = dirB; 66 | obj.ImagesA = imageDatastore(obj.DirA, "IncludeSubfolders", includeSubFolders); 67 | obj.ImagesB = imageDatastore(obj.DirB, "IncludeSubfolders", includeSubFolders); 68 | obj.MiniBatchSize = miniBatchSize; 69 | 70 | assert(numel(obj.ImagesA.Files) == numel(obj.ImagesB.Files), ... 71 | 'p2p:datastore:notMatched', ... 72 | 'Number of files in A and B folders does not match'); 73 | obj.NumObservations = numel(obj.ImagesA.Files); 74 | 75 | % Handle optional arguments 76 | [inputs, remaining] = obj.parseInputs(varargin{:}); 77 | 78 | obj.ARange = inputs.ARange; 79 | obj.BRange = inputs.BRange; 80 | obj.Augmenter = imageDataAugmenter(remaining); 81 | obj.PreSize = inputs.PreSize; 82 | obj.CropSize = inputs.CropSize; 83 | 84 | end 85 | 86 | function tf = hasdata(obj) 87 | tf = obj.ImagesA.hasdata() && obj.ImagesB.hasdata(); 88 | end 89 | 90 | function data = read(obj) 91 | imagesA = obj.ImagesA.read(); 92 | imagesB = obj.ImagesB.read(); 93 | 94 | % for batch size 1, imageDatastore doesn't wrap the image in a cell 95 | if ~iscell(imagesA) 96 | imagesA = {imagesA}; 97 | imagesB = {imagesB}; 98 | end 99 | [transformedA, transformedB] = ... 100 | p2p.data.transformImagePair(imagesA, imagesB, ... 101 | obj.PreSize, obj.CropSize, ...
102 | obj.Augmenter); 103 | [A, B] = obj.normaliseImages(transformedA, transformedB); 104 | data = table(A, B); 105 | end 106 | 107 | function reset(obj) 108 | obj.ImagesA.reset(); 109 | obj.ImagesB.reset(); 110 | end 111 | 112 | function objNew = shuffle(obj) 113 | objNew = obj.copy(); 114 | numObservations = objNew.NumObservations; 115 | objNew.ImagesA = copy(obj.ImagesA); 116 | objNew.ImagesB = copy(obj.ImagesB); 117 | idx = randperm(numObservations); 118 | 119 | objNew.ImagesA.Files = objNew.ImagesA.Files(idx); 120 | objNew.ImagesB.Files = objNew.ImagesB.Files(idx); 121 | end 122 | 123 | function [aOut, bOut] = normaliseImages(obj, aIn, bIn) 124 | aOut = cellfun(@(x) 2*(single(x)/obj.ARange) - 1, aIn, 'UniformOutput', false); 125 | bOut = cellfun(@(x) 2*(single(x)/obj.BRange) - 1, bIn, 'UniformOutput', false); 126 | end 127 | 128 | function val = get.MiniBatchSize(obj) 129 | val = obj.MiniBatchSize_; 130 | end 131 | 132 | function set.MiniBatchSize(obj, val) 133 | obj.ImagesA.ReadSize = val; 134 | obj.ImagesB.ReadSize = val; 135 | obj.MiniBatchSize_ = val; 136 | end 137 | end 138 | end -------------------------------------------------------------------------------- /code/+p2p/+data/transformImagePair.m: -------------------------------------------------------------------------------- 1 | function [transformedA, transformedB] = transformImagePair(imagesA, imagesB, preSize, cropSize, augmenter) 2 | % transformImagePair Apply a matching set of transformations to images 3 | % 4 | % Args: 5 | % imagesA - cell array of images to transform 6 | % imagesB - cell array of images to transform 7 | % preSize - [1x2] dimensions to initially resize image to 8 | % cropSize - [1x2] dimensions to crop image to 9 | % augmenter - imageDataAugmenter to use for image transforms 10 | % 11 | % Returns: 12 | % transformedA - cell array of transformed images A 13 | % transformedB - cell array of transformed images B 14 | 15 | % Copyright 2020 The MathWorks, Inc.
16 | 17 | % Default to identity transform 18 | transformedA = imagesA; 19 | transformedB = imagesB; 20 | 21 | % Apply a resize operation 22 | if ~isempty(preSize) 23 | transformedA = cellfun(@(im) imresize(im, preSize), ... 24 | transformedA, ... 25 | 'UniformOutput', false); 26 | transformedB = cellfun(@(im) imresize(im, preSize), ... 27 | transformedB, ... 28 | 'UniformOutput', false); 29 | end 30 | 31 | % Apply the imageDataAugmenter 32 | if ~isempty(augmenter) 33 | [transformedA, transformedB] = augmenter.augmentPair(transformedA, transformedB); 34 | end 35 | 36 | % Apply a random crop 37 | if ~isempty(cropSize) 38 | [transformedA, transformedB] = randCrop(transformedA, transformedB, cropSize); 39 | end 40 | 41 | end 42 | 43 | function [imOut1, imOut2] = randCrop(im1, im2, cropSize) 44 | rect = augmentedImageDatastore.randCropRect(im1, cropSize); 45 | doCrop = @(im) augmentedImageDatastore.cropGivenDiscreteValuedRect(im, rect); 46 | imOut1 = cellfun(doCrop, im1, 'UniformOutput', false); 47 | imOut2 = cellfun(doCrop, im2, 'UniformOutput', false); 48 | 49 | end -------------------------------------------------------------------------------- /code/+p2p/+networks/block.m: -------------------------------------------------------------------------------- 1 | function layers = block(id, nChannels, direction, varargin) 2 | % block Base building block of networks 3 | 4 | % Copyright 2020 The MathWorks, Inc.
5 | 6 | parser = inputParser(); 7 | parser.addRequired('nChannels'); 8 | parser.addRequired('direction'); 9 | parser.addParameter('NormType', 'batch'); 10 | parser.addParameter('DoNorm', true); 11 | parser.addParameter('Dropout', false); 12 | parser.addParameter('KernelSize', 4); 13 | parser.parse(nChannels, direction, varargin{:}); 14 | inputs = parser.Results; 15 | 16 | switch inputs.direction 17 | case 'down' 18 | conv = @convolution2dLayer; 19 | padName = 'Padding'; 20 | case 'up' 21 | conv = @transposedConv2dLayer; 22 | padName = 'Cropping'; 23 | otherwise 24 | error('p2p:networks:badDirection', 'unrecognised direction ''%s''.', inputs.direction); 25 | end 26 | 27 | layers = conv(inputs.KernelSize, inputs.nChannels, ... 28 | 'Name', sprintf('conv_%s', id), ... 29 | padName, 'same', ... 30 | 'Stride', 2, ... 31 | 'WeightsInitializer', @(sz) 0.02*randn(sz, 'single')); 32 | 33 | if inputs.DoNorm 34 | switch inputs.NormType 35 | case 'instance' 36 | layers = [layers; p2p.networks.instanceNormalizationLayer(sprintf('in_%s', id))]; 37 | case 'batch' 38 | layers = [layers; batchNormalizationLayer('Name', sprintf('bn_%s', id))]; 39 | case 'none' 40 | % no normalization 41 | otherwise 42 | error('p2p:networks:badNorm', 'unrecognised normalisation type ''%s''.', inputs.NormType) 43 | end 44 | end 45 | if inputs.Dropout 46 | layers = [layers; dropoutLayer(0.5, 'Name', sprintf('drop_%s', id))]; 47 | end 48 | layers = [layers; leakyReluLayer(0.2, 'Name', sprintf('lrelu_%s', id))]; 49 | end -------------------------------------------------------------------------------- /code/+p2p/+networks/discriminator.m: -------------------------------------------------------------------------------- 1 | function model = discriminator(inputSize, inputChannels, depth) 2 | % discriminator pix2pix discriminator network 3 | 4 | % Copyright 2020 The MathWorks, Inc.
5 | 6 | layers = imageInputLayer([inputSize, inputChannels], 'Name', 'inputImage', 'Normalization', 'none'); 7 | 8 | downChannels = [64, 128, 256, 512]; 9 | 10 | if nargin < 3 11 | depth = numel(downChannels); 12 | else 13 | assert(depth <= 4, ... 14 | 'p2p:networks:discriminator', ... 15 | 'Current max depth is 4'); 16 | % Keep only the channel counts for the requested depth 17 | downChannels = downChannels(1:depth); 18 | end 19 | 20 | for iLevel = 1:depth 21 | 22 | if iLevel == 1 23 | doNorm = false; 24 | else 25 | doNorm = true; 26 | end 27 | 28 | layers = [layers 29 | p2p.networks.downBlock(sprintf('D_%d', iLevel), downChannels(iLevel), ... 30 | 'DoNorm', doNorm)]; 31 | end 32 | 33 | layers = [layers 34 | convolution2dLayer(1, 1, ... 35 | 'Name', 'outputLayer', ... 36 | 'Padding', 'same', ... 37 | 'Stride', 1, ... 38 | 'WeightsInitializer', @(sz) 0.02*randn(sz, 'single'))]; 39 | lg = layerGraph(layers); 40 | model = dlnetwork(lg); 41 | end -------------------------------------------------------------------------------- /code/+p2p/+networks/downBlock.m: -------------------------------------------------------------------------------- 1 | function out = downBlock(id, nChannels, varargin) 2 | % downBlock Downsampling block 3 | 4 | % Copyright 2020 The MathWorks, Inc. 5 | 6 | out = p2p.networks.block(id, nChannels, 'down', varargin{:}); 7 | 8 | end 9 | -------------------------------------------------------------------------------- /code/+p2p/+networks/generator.m: -------------------------------------------------------------------------------- 1 | function model = generator(inputSize, inputChannels, outputChannels, depth) 2 | % generator pix2pix generator network 3 | 4 | % Copyright 2020 The MathWorks, Inc. 5 | 6 | downChannels = [64, 128, 256, 512, 512, 512, 512, 512]; 7 | upChannels = [512, 512, 512, 512, 512, 256, 128, 64]; 8 | 9 | if nargin < 4 10 | depth = numel(downChannels); 11 | else 12 | assert(depth <= 8, ... 13 | 'p2p:networks:generator', ...
14 | 'Current max depth is 8'); 15 | % Modify down and up channels accordingly 16 | downChannels = downChannels(1:depth); 17 | upChannels = upChannels(9-depth:end); 18 | end 19 | 20 | layers = imageInputLayer([inputSize, inputChannels], ... 21 | 'Name', 'inputImage', ... 22 | 'Normalization', 'none'); 23 | 24 | for iLevel = 1:depth 25 | 26 | if iLevel == 1 27 | doNorm = false; 28 | else 29 | doNorm = true; 30 | end 31 | 32 | layers = [layers 33 | p2p.networks.downBlock(sprintf('down_%d', iLevel), downChannels(iLevel), ... 34 | 'DoNorm', doNorm)]; 35 | end 36 | 37 | for iLevel = depth:-1:1 38 | if iLevel >= (depth-3) 39 | doDropout = true; 40 | else 41 | doDropout = false; 42 | end 43 | 44 | layers = [layers 45 | p2p.networks.upBlock(sprintf('up_%d', iLevel), ... 46 | upChannels(depth-iLevel+1), ... 47 | 'Dropout', doDropout)]; 48 | end 49 | 50 | layers = [layers 51 | convolution2dLayer(1, outputChannels, ... 52 | 'Padding', 'same', ... 53 | 'Stride', 1, ... 54 | 'Name', 'Output', ... 55 | 'WeightsInitializer', @(sz) 0.02*randn(sz, 'single'))]; 56 | 57 | lg = layerGraph(layers); 58 | % add the skip connections 59 | for iLevel = 1:depth-1 60 | lg = lg.connectLayers(sprintf('lrelu_down_%d', iLevel), sprintf('cat_up_%d/in2', iLevel+1)); 61 | end 62 | lg = lg.connectLayers('inputImage', sprintf('cat_up_%d/in2', 1)); 63 | 64 | model = dlnetwork(lg); 65 | end 66 | -------------------------------------------------------------------------------- /code/+p2p/+networks/instanceNormalizationLayer.m: -------------------------------------------------------------------------------- 1 | classdef instanceNormalizationLayer < nnet.layer.Layer 2 | % instanceNormalizationLayer Instance Normalization 3 | 4 | % Copyright 2020 The MathWorks, Inc. 
5 | 6 | properties (Learnable) 7 | Scale 8 | Offset 9 | end 10 | 11 | properties 12 | Epsilon = 1e-5; 13 | end 14 | 15 | methods 16 | function layer = instanceNormalizationLayer(name) 17 | layer.Name = name; 18 | layer.Scale = 1; 19 | layer.Offset = 0; 20 | end 21 | 22 | function Y = predict(layer, X) 23 | % Apply instance normalization to X 24 | 25 | means = mean(X, [1, 2]); 26 | variances = var(X, 1, [1, 2]); 27 | Y = (X - means)./sqrt(variances + layer.Epsilon); 28 | 29 | Y = layer.Scale.*Y + layer.Offset; 30 | end 31 | 32 | end 33 | end -------------------------------------------------------------------------------- /code/+p2p/+networks/upBlock.m: -------------------------------------------------------------------------------- 1 | function out = upBlock(id, nChannels, varargin) 2 | % upBlock Upsampling block 3 | 4 | % Copyright 2020 The MathWorks, Inc. 5 | 6 | out = p2p.networks.block(id, nChannels, 'up', varargin{:}); 7 | out = [out; depthConcatenationLayer(2, 'Name', sprintf('cat_%s', id))]; 8 | 9 | end -------------------------------------------------------------------------------- /code/+p2p/+util/AdamOptimiser.m: -------------------------------------------------------------------------------- 1 | classdef AdamOptimiser < handle 2 | % AdamOptimiser A convenience class for handling Adam state. 3 | % 4 | % This class takes care of keeping track of the running statistics of 5 | % average gradient and average squared gradient. It also automatically 6 | % increments the current iteration on every call to update. 7 | % 8 | % Note: 9 | % The parameter update does NOT use weight decay. 10 | 11 | % Copyright 2020 The MathWorks, Inc. 
12 | 13 | properties 14 | LearnRate 15 | Beta1 16 | Beta2 17 | Iteration = 1 18 | AvgGradient 19 | AvgGradientSq 20 | end 21 | 22 | methods 23 | function obj = AdamOptimiser(lr, beta1, beta2) 24 | obj.LearnRate = lr; 25 | obj.Beta1 = beta1; 26 | obj.Beta2 = beta2; 27 | end 28 | 29 | function updatedParams = update(obj, params, gradients) 30 | % Apply the Adam update to parameters given a set of gradients 31 | 32 | if ~isempty(obj.AvgGradient) 33 | assert(height(params) == height(obj.AvgGradient), ... 34 | "p2p:util:AdamOptimiser", ... 35 | "Size of parameters should not change during optimisation."); 36 | end 37 | 38 | [updatedParams, obj.AvgGradient, obj.AvgGradientSq] = ... 39 | adamupdate(params, gradients, ... 40 | obj.AvgGradient, obj.AvgGradientSq, ... 41 | obj.Iteration, obj.LearnRate, obj.Beta1, obj.Beta2); 42 | obj.Iteration = obj.Iteration + 1; 43 | end 44 | end 45 | end -------------------------------------------------------------------------------- /code/+p2p/+util/downloadFacades.m: -------------------------------------------------------------------------------- 1 | function [aFolder, bFolder] = downloadFacades(destination) 2 | % downloadFacades Saves a copy of the facades dataset images. 3 | % 4 | % Inputs: 5 | % destination - Location to save dataset to (default: "./datasets/facades") 6 | % Returns: 7 | % aFolder - Location of label images 8 | % bFolder - Location of target images 9 | 10 | % Copyright 2020 The MathWorks, Inc. 
11 | 12 | if nargin < 1 13 | destination = "./datasets/facades"; 14 | end 15 | 16 | aFolder = fullfile(destination, "A"); 17 | bFolder = fullfile(destination, "B"); 18 | 19 | if ~isfolder(destination) 20 | mkdir(destination); 21 | mkdir(aFolder); 22 | mkdir(bFolder); 23 | end 24 | 25 | dataUrl = "http://cmp.felk.cvut.cz/~tylecr1/facade/CMP_facade_DB_base.zip"; 26 | tempZipFile = tempname; 27 | tempUnzippedFolder = tempname; 28 | fprintf("Downloading facades dataset...") 29 | websave(tempZipFile, dataUrl); 30 | fprintf("done.\n") 31 | 32 | fprintf("Extracting zip...") 33 | unzip(tempZipFile, tempUnzippedFolder); 34 | fprintf("done.\n") 35 | 36 | 37 | % Labels are indexed pngs 38 | movefile(fullfile(tempUnzippedFolder, "base", "*.png"), aFolder); 39 | % Convert them all to RGB 40 | convertToRgb(aFolder); 41 | 42 | % Photos are RGB jpgs 43 | movefile(fullfile(tempUnzippedFolder, "base", "*.jpg"), bFolder); 44 | 45 | fprintf("done.\n") 46 | 47 | end 48 | 49 | function convertToRgb(directory) 50 | % Converts all the images in the directory to RGB. 51 | ims = imageDatastore(directory); 52 | for iIm = 1:numel(ims.Files) 53 | filename = ims.Files{iIm}; 54 | [im, map] = imread(filename); 55 | rgbIm = ind2rgb(im, map); 56 | imwrite(rgbIm, filename); 57 | end 58 | end -------------------------------------------------------------------------------- /code/+p2p/+vis/TrainingPlot.m: -------------------------------------------------------------------------------- 1 | classdef TrainingPlot < handle 2 | % TrainingPlot Displays training progress 3 | 4 | % Copyright 2020 The MathWorks, Inc. 
5 | 6 | properties (Access = private) 7 | TiledChart 8 | InputsAx 9 | OutputsAx 10 | LossAx1 11 | LossAx2 12 | InputsIm 13 | OutputsIm 14 | ExampleInputs 15 | Lines = matlab.graphics.animation.AnimatedLine.empty 16 | StartTime 17 | end 18 | 19 | methods 20 | function obj = TrainingPlot(exampleInputs) 21 | 22 | obj.StartTime = datetime("now"); 23 | 24 | trainingName = sprintf("pix2pix training started at %s", ... 25 | obj.StartTime); 26 | fig = figure("Units", "Normalized", ... 27 | "Position", [0.1, 0.1, 0.7, 0.6], ... 28 | "Name", trainingName, ... 29 | "NumberTitle", "off", ... 30 | "Tag", "p2p.vis.TrainingPlot"); 31 | obj.TiledChart = tiledlayout(fig, 3, 4, ... 32 | "TileSpacing", "compact", ... 33 | "Padding", "compact"); 34 | obj.InputsAx = nexttile(obj.TiledChart, 1, [2, 2]); 35 | obj.OutputsAx = nexttile(obj.TiledChart, 3, [2, 2]); 36 | obj.LossAx1 = nexttile(obj.TiledChart, 9, [1, 2]); 37 | obj.LossAx2 = nexttile(obj.TiledChart, 11, [1, 2]); 38 | 39 | obj.ExampleInputs = exampleInputs; 40 | obj.initImages(); 41 | obj.initLines(); 42 | drawnow(); 43 | end 44 | 45 | function update(obj, epoch, iteration, ... 46 | gLoss, lossL1, ganLoss, dLoss, generator) 47 | obj.updateImages(generator) 48 | obj.updateLines(epoch, iteration, gLoss, lossL1, ganLoss, dLoss); 49 | drawnow(); 50 | end 51 | 52 | function initImages(obj) 53 | displayIm = obj.prepForPlot(obj.ExampleInputs); 54 | montageIm = imtile(displayIm); 55 | obj.InputsIm = imshow(montageIm, "Parent", obj.InputsAx); 56 | 57 | zeroIm = 0*montageIm; 58 | obj.OutputsIm = imshow(zeroIm, "Parent", obj.OutputsAx); 59 | end 60 | 61 | function updateImages(obj, generator) 62 | output = tanh(generator.forward(obj.ExampleInputs)); 63 | displayIm = obj.prepForPlot(output); 64 | obj.OutputsIm.CData = imtile(displayIm); 65 | end 66 | 67 | function initLines(obj) 68 | % First plot just for generator 69 | obj.Lines(1) = animatedline(obj.LossAx1, ... 70 | "LineWidth", 1, ... 
71 | "DisplayName", "Generator total"); 72 | xlabel(obj.LossAx1, "Iteration"); 73 | ylabel(obj.LossAx1, "Loss"); 74 | grid(obj.LossAx1, "on"); 75 | legend(obj.LossAx1); 76 | 77 | % Remaining plots for other losses 78 | nLines = 3; 79 | cMap = parula(nLines); 80 | labels = ["L1 loss", "GAN loss", "Discriminator loss"]; 81 | for iLine = 1:nLines 82 | obj.Lines(iLine + 1) = animatedline(obj.LossAx2, ... 83 | "Color", cMap(iLine, :), ... 84 | "LineWidth", 1, ... 85 | "DisplayName", labels(iLine)); 86 | end 87 | xlabel(obj.LossAx2, "Iteration"); 88 | ylabel(obj.LossAx2, "Loss"); 89 | grid(obj.LossAx2, "on"); 90 | legend(obj.LossAx2); 91 | end 92 | 93 | function updateLines(obj, epoch, iteration, gLoss, lossL1, ganLoss, dLoss) 94 | titleString = sprintf("Current epoch: %d, elapsed time: %s", ... 95 | epoch, datetime("now") - obj.StartTime); 96 | title(obj.LossAx1, titleString); 97 | addpoints(obj.Lines(1), iteration, double(gLoss)); 98 | addpoints(obj.Lines(2), iteration, double(lossL1)); 99 | addpoints(obj.Lines(3), iteration, double(ganLoss)); 100 | addpoints(obj.Lines(4), iteration, double(dLoss)); 101 | end 102 | 103 | end 104 | 105 | methods (Static) 106 | function imOut = prepForPlot(im) 107 | nChannels = size(im, 3); 108 | imOut = (gather(extractdata(im)) + 1)/2; 109 | 110 | % only take the first channel for n != 3 111 | if nChannels ~= 3 112 | imOut = imOut(:,:,1,:); 113 | end 114 | end 115 | end 116 | end 117 | -------------------------------------------------------------------------------- /code/+p2p/train.m: -------------------------------------------------------------------------------- 1 | function p2pModel = train(inData, outData, options) 2 | % train Train a pix2pix model. 3 | % 4 | % A pix2pix model that attempts to learn how to convert images from 5 | % inData to outData is trained, following the approach described in 6 | % 'Isola et al. Image-to-Image Translation with Conditional Adversarial 7 | % Nets'. 
8 | % 9 | % Args: 10 | % inData - Training data input images 11 | % outData - Training data target images 12 | % options - Training options as a struct generated by p2p.trainingOptions 13 | % 14 | % Returns: 15 | % p2pModel - A struct containing trained networks and optimisers 16 | % 17 | % See also: p2p.trainingOptions 18 | 19 | % Copyright 2020 The MathWorks, Inc. 20 | 21 | if nargin < 3 22 | options = p2p.trainingOptions(); 23 | end 24 | 25 | if (options.ExecutionEnvironment == "auto" && canUseGPU) || ... 26 | options.ExecutionEnvironment == "gpu" 27 | env = @gpuArray; 28 | else 29 | env = @(x) x; 30 | end 31 | 32 | if ~isempty(options.CheckpointPath) 33 | % Make a subfolder for storing checkpoints 34 | timestamp = strcat("p2p-", datestr(now, 'yyyymmdd-HHMMSS')); 35 | checkpointSubDir = fullfile(options.CheckpointPath, timestamp); 36 | mkdir(checkpointSubDir) 37 | end 38 | 39 | combinedChannels = options.InputChannels + options.OutputChannels; 40 | 41 | % model learns A to B mapping 42 | imageAndLabel = p2p.data.PairedImageDatastore(inData, outData, options.MiniBatchSize, ...
43 | "PreSize", options.PreSize, "CropSize", options.InputSize, "ARange", options.ARange, "BRange", options.BRange, "RandXReflection", options.RandXReflection); 44 | 45 | if options.Plots == "training-progress" 46 | examples = imageAndLabel.shuffle(); 47 | nExamples = 9; 48 | examples.MiniBatchSize = nExamples; 49 | data = examples.read(); 50 | thisInput = cat(4, data.A{:}); 51 | exampleInputs = dlarray(env(thisInput), 'SSCB'); 52 | trainingPlot = p2p.vis.TrainingPlot(exampleInputs); 53 | end 54 | 55 | if isempty(options.ResumeFrom) 56 | g = p2p.networks.generator(options.InputSize, options.InputChannels, options.OutputChannels, options.GDepth); 57 | d = p2p.networks.discriminator(options.InputSize, combinedChannels, options.DDepth); 58 | 59 | gOptimiser = p2p.util.AdamOptimiser(options.GLearnRate, options.GBeta1, options.GBeta2); 60 | dOptimiser = p2p.util.AdamOptimiser(options.DLearnRate, options.DBeta1, options.DBeta2); 61 | 62 | iteration = 0; 63 | startEpoch = 1; 64 | else 65 | data = load(options.ResumeFrom, 'p2pModel'); 66 | g = data.p2pModel.g; 67 | d = data.p2pModel.d; 68 | gOptimiser = data.p2pModel.gOptimiser; 69 | dOptimiser = data.p2pModel.dOptimiser; 70 | 71 | iteration = gOptimiser.Iteration - 1; % Iteration starts at 1, so this is the number of updates made 72 | startEpoch = floor(iteration/ceil(imageAndLabel.NumObservations/options.MiniBatchSize))+1; 73 | end 74 | 75 | %% Training loop 76 | for epoch = startEpoch:options.MaxEpochs 77 | 78 | imageAndLabel = imageAndLabel.shuffle(); 79 | 80 | while imageAndLabel.hasdata 81 | 82 | iteration = iteration + 1; 83 | 84 | data = imageAndLabel.read(); 85 | thisInput = cat(4, data.A{:}); 86 | thisTarget = cat(4, data.B{:}); 87 | 88 | inputImage = dlarray(env(thisInput), 'SSCB'); 89 | targetImage = dlarray(env(thisTarget), 'SSCB'); 90 | 91 | [g, gLoss, d, dLoss, lossL1, ganLoss, ~] = ... 92 | dlfeval(@stepBoth, g, d, gOptimiser, dOptimiser, inputImage, targetImage, options); 93 | 94 | if mod(iteration, options.VerboseFrequency) == 0 95 | logArgs = {epoch, iteration, ...
96 | gLoss, lossL1, ganLoss, dLoss}; 97 | fprintf('epoch: %d, it: %d, G: %f (L1: %f, GAN: %f), D: %f\n', ... 98 | logArgs{:}); 99 | if options.Plots == "training-progress" 100 | trainingPlot.update(logArgs{:}, g); 101 | end 102 | end 103 | end 104 | 105 | p2pModel = struct('g', g, 'd', d, 'gOptimiser', gOptimiser, 'dOptimiser', dOptimiser); 106 | if ~isempty(options.CheckpointPath) 107 | checkpointFilename = sprintf('p2p_checkpoint_%s_%04d.mat', datestr(now, 'yyyy-mm-ddTHH-MM-SS'), epoch); 108 | p2pModel = gather(p2pModel); 109 | save(fullfile(checkpointSubDir, checkpointFilename), 'p2pModel') 110 | end 111 | end 112 | end 113 | 114 | function [g, gLoss, d, dLoss, lossL1, ganLoss, images] = stepBoth(g, d, gOpt, dOpt, inputImage, targetImage, options) 115 | 116 | % Make a fake image 117 | fakeImage = tanh(g.forward(inputImage)); 118 | 119 | %% D update 120 | % Apply the discriminator 121 | realPredictions = sigmoid(d.forward(... 122 | cat(3, targetImage, inputImage) ... 123 | )); 124 | fakePredictions = sigmoid(d.forward(... 125 | cat(3, fakeImage, inputImage)... 126 | )); 127 | 128 | % calculate D losses 129 | labels = ones(size(fakePredictions), 'single'); 130 | % crossentropy only divides by the batch size, so also divide by the number of patch predictions per image 131 | dLoss = options.DRelLearnRate*(crossentropy(realPredictions, labels)/numel(fakePredictions(:,:,1,1)) + ...
132 | crossentropy(1-fakePredictions, labels)/numel(fakePredictions(:,:,1,1))); 133 | 134 | % get d gradients 135 | dGrads = dlgradient(dLoss, d.Learnables, "RetainData", true); 136 | dLoss = extractdata(dLoss); 137 | 138 | %% G update 139 | % to save time, reuse the fake predictions already computed for the D update 140 | 141 | % calculate G losses 142 | ganLoss = crossentropy(fakePredictions, labels)/numel(fakePredictions(:,:,1,1)); 143 | lossL1 = mean(abs(fakeImage - targetImage), 'all'); 144 | gLoss = options.Lambda*lossL1 + ganLoss; 145 | 146 | % get g grads 147 | gGrads = dlgradient(gLoss, g.Learnables); 148 | 149 | % update g with the generator's optimiser 150 | g.Learnables = gOpt.update(g.Learnables, gGrads); 151 | % update d with the discriminator's optimiser 152 | d.Learnables = dOpt.update(d.Learnables, dGrads); 153 | % things for plotting 154 | gLoss = extractdata(gLoss); 155 | lossL1 = extractdata(lossL1); 156 | ganLoss = extractdata(ganLoss); 157 | 158 | images = {fakeImage, inputImage, targetImage}; 159 | end 160 | -------------------------------------------------------------------------------- /code/+p2p/trainingOptions.m: -------------------------------------------------------------------------------- 1 | function options = trainingOptions(varargin) 2 | % trainingOptions Create options struct for training pix2pix model 3 | % 4 | % By default the struct will contain parameters which are close to those 5 | % described in the original pix2pix paper. To change any parameters 6 | % either modify the struct after creation, or pass in Name-Value pairs to 7 | % this function.
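%
%   Example (a sketch; the values below are arbitrary, not recommendations):
%       % Pass Name-Value pairs when creating the struct...
%       options = p2p.trainingOptions("MiniBatchSize", 4, "MaxEpochs", 50);
%       % ...or edit fields of the returned struct afterwards
%       options.Lambda = 50;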
8 | % 9 | % trainingOptions accepts the following Name-Value pairs: 10 | % 11 | % ExecutionEnvironment - Processor to use for training: 12 | % "auto", "cpu", or "gpu" (default: "auto") 13 | % InputChannels - Number of channels in the input image 14 | % (default: 3) 15 | % OutputChannels - Number of channels in the target image 16 | % (default: 3) 17 | % MiniBatchSize - Mini-batch size during training (default: 1) 18 | % RandXReflection - Whether to apply horizontal flipping data 19 | % augmentation (default: true) 20 | % PreSize - Dimensions to initially resize images to 21 | % (before cropping) (default: [286, 286]) 22 | % InputSize - Dimensions to crop images to 23 | % (default: [256, 256]) 24 | % ARange - Maximum numeric value of input images 25 | % (default: 255) 26 | % BRange - Maximum numeric value of target images 27 | % (default: 255) 28 | % ResumeFrom - Path to a checkpoint file to resume training from 29 | % (default: []) 30 | % GLearnRate - Learning rate of the generator's optimizer 31 | % (default: 0.0002) 32 | % GBeta1 - Beta 1 parameter of the generator's 33 | % optimizer (default: 0.5) 34 | % GBeta2 - Beta 2 parameter of the generator's 35 | % optimizer (default: 0.999) 36 | % DLearnRate - Learning rate of the discriminator's optimizer 37 | % (default: 0.0002) 38 | % DBeta1 - Beta 1 parameter of the discriminator's 39 | % optimizer (default: 0.5) 40 | % DBeta2 - Beta 2 parameter of the discriminator's 41 | % optimizer (default: 0.999) 42 | % MaxEpochs - Total epochs for training (default: 200) 43 | % CheckpointPath - Path to a folder to save checkpoints to; use "" 44 | % to disable checkpointing (default: "checkpoints") 45 | % DRelLearnRate - Relative scaling factor for the 46 | % discriminator's loss (default: 0.5) 47 | % Lambda - Relative scaling factor for the L1 loss 48 | % (default: 100) 49 | % GDepth - Depth of the generator (default: 8) 50 | % DDepth - Depth of the discriminator (default: 4) 51 | % Verbose - Whether to print status to the command line 52 | % (default: true) 53 | %
VerboseFrequency - Frequency of plot and command line update in 54 | % iterations (default: 50) 55 | % Plots - Plot type to show during training: "none" or 56 | % "training-progress" (default: "training-progress") 57 | % 58 | % See also: p2p.train 59 | 60 | % Copyright 2020 The MathWorks, Inc. 61 | 62 | parser = inputParser(); 63 | 64 | parser.addParameter("ExecutionEnvironment", "auto", ... 65 | @(x) ismember(x, ["auto", "cpu", "gpu"])); 66 | parser.addParameter("InputChannels", 3, ... 67 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 68 | parser.addParameter("OutputChannels", 3, ... 69 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 70 | parser.addParameter("MiniBatchSize", 1, ... 71 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 72 | parser.addParameter("RandXReflection", true, ... 73 | @(x) validateattributes(x, "logical", "scalar")); 74 | parser.addParameter("PreSize", [286, 286], ... 75 | @(x) validateattributes(x, "numeric", ["positive", "integer"])); 76 | parser.addParameter("InputSize", [256, 256], ... 77 | @(x) validateattributes(x, "numeric", ["positive", "integer"])); 78 | parser.addParameter("ARange", 255, ... 79 | @(x) validateattributes(x, "numeric", "positive")); 80 | parser.addParameter("BRange", 255, ... 81 | @(x) validateattributes(x, "numeric", "positive")); 82 | parser.addParameter("ResumeFrom", [], ... 83 | @(x) validateattributes(x, ["char", "string"], "scalartext")); 84 | parser.addParameter("GLearnRate", 0.0002, ... 85 | @(x) validateattributes(x, "numeric", "scalar")); 86 | parser.addParameter("GBeta1", 0.5, ... 87 | @(x) validateattributes(x, "numeric", "scalar")); 88 | parser.addParameter("GBeta2", 0.999, ... 89 | @(x) validateattributes(x, "numeric", "scalar")); 90 | parser.addParameter("DLearnRate", 0.0002, ... 91 | @(x) validateattributes(x, "numeric", "scalar")); 92 | parser.addParameter("DBeta1", 0.5, ... 
93 | @(x) validateattributes(x, "numeric", "scalar")); 94 | parser.addParameter("DBeta2", 0.999, ... 95 | @(x) validateattributes(x, "numeric", "scalar")); 96 | parser.addParameter("MaxEpochs", 200, ... 97 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 98 | parser.addParameter("CheckpointPath", "checkpoints", ... 99 | @(x) validateattributes(x, ["char", "string"], "scalartext")); 100 | parser.addParameter("DRelLearnRate", 0.5, ... 101 | @(x) validateattributes(x, "numeric", "scalar")); 102 | parser.addParameter("Lambda", 100, ... 103 | @(x) validateattributes(x, "numeric", "scalar")); 104 | parser.addParameter("GDepth", 8, ... 105 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 106 | parser.addParameter("DDepth", 4, ... 107 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 108 | parser.addParameter("Verbose", true, ... 109 | @(x) validateattributes(x, "logical", "scalar")); 110 | parser.addParameter("VerboseFrequency", 50, ... 111 | @(x) validateattributes(x, "numeric", ["scalar","integer","positive"])); 112 | parser.addParameter("Plots", "training-progress", ... 113 | @(x) ismember(x, ["none", "training-progress"])); 114 | 115 | parser.parse(varargin{:}); 116 | options = parser.Results; 117 | 118 | % Convert the path to char so isempty checks work (isempty("") is false, isempty('') is true). 119 | options.CheckpointPath = convertStringsToChars(options.CheckpointPath); 120 | end -------------------------------------------------------------------------------- /code/+p2p/translate.m: -------------------------------------------------------------------------------- 1 | function translatedImage = translate(p2pModel, inputImage, varargin) 2 | % translate Apply a generator to an image.
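%
%   Example (a sketch; assumes p2pModel is the struct returned by p2p.train
%   and that the image already matches the generator's input size;
%   "myLabelImage.png" is a hypothetical file name):
%       inputImage = imread("myLabelImage.png");
%       out = p2p.translate(p2pModel, inputImage, "ExecutionEnvironment", "cpu");
%       imshow(out)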
3 | % 4 | % Args: 5 | % p2pModel - Struct containing a pix2pix generator, as returned 6 | % by p2p.train 7 | % inputImage - Input image to be translated 8 | % 9 | % translate also accepts the following Name-Value pairs: 10 | % 11 | % ExecutionEnvironment - Processor to use for image translation: 12 | % "auto", "cpu", or "gpu" (default: "auto") 13 | % ARange - Maximum numeric value of input image, used 14 | % for input scaling (default: 255) 15 | % 16 | % Returns: 17 | % translatedImage - Image translated by the generator model 18 | % 19 | % Note: 20 | % The input image must be a suitable size for the generator model 21 | % 22 | % See also: p2p.train 23 | 24 | % Copyright 2020 The MathWorks, Inc. 25 | 26 | options = parseInputs(varargin{:}); 27 | 28 | inputClass = class(inputImage); 29 | 30 | networkInput = prepImageForNetwork(inputImage, options); 31 | out = tanh(p2pModel.g.forward(networkInput)); 32 | 33 | % Map the tanh output from [-1, 1] to [0, 1], then match the input's class 34 | translatedImage = (extractdata(out) + 1)/2; 35 | if strcmp(inputClass, "uint8") 36 | translatedImage = uint8(255*translatedImage); 37 | else 38 | translatedImage = cast(translatedImage, "like", inputImage); 39 | end 40 | 41 | end 42 | 43 | function options = parseInputs(varargin) 44 | % Parse Name-Value pair arguments 45 | parser = inputParser(); 46 | parser.addParameter("ExecutionEnvironment", "auto", ... 47 | @(x) ismember(x, ["auto", "cpu", "gpu"])); 48 | parser.addParameter("ARange", 255, ... 49 | @(x) validateattributes(x, "numeric", "positive")); 50 | 51 | parser.parse(varargin{:}); 52 | options = parser.Results; 53 | end 54 | 55 | function networkInput = prepImageForNetwork(inputImage, options) 56 | % Cast to single, scale to [-1, 1], and move to the GPU as appropriate 57 | networkInput = 2*single(inputImage)/options.ARange - 1; 58 | if (options.ExecutionEnvironment == "auto" && canUseGPU) || ...
59 | options.ExecutionEnvironment == "gpu" 60 | networkInput = gpuArray(networkInput); 61 | end 62 | networkInput = dlarray(networkInput, 'SSCB'); 63 | end -------------------------------------------------------------------------------- /docs/getting_started.mlx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/docs/getting_started.mlx -------------------------------------------------------------------------------- /docs/labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/docs/labels.png -------------------------------------------------------------------------------- /docs/output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/docs/output.jpg -------------------------------------------------------------------------------- /docs/target.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/docs/target.jpg -------------------------------------------------------------------------------- /docs/training.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/docs/training.gif -------------------------------------------------------------------------------- /facades.rights: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-ShareAlike 4.0 International Public License 2 | By exercising the Licensed Rights (defined below), You accept 
and agree to be bound by the terms and conditions of this Creative Commons Attribution-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 3 | 4 | Section 1 – Definitions. 5 | 6 | Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 7 | Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 8 | BY-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License. 9 | Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 
10 | Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 11 | Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 12 | License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution and ShareAlike. 13 | Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 14 | Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 15 | Licensor means the individual(s) or entity(ies) granting rights under this Public License. 16 | Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 17 | Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 18 | You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 
19 | Section 2 – Scope. 20 | 21 | License grant. 22 | Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 23 | reproduce and Share the Licensed Material, in whole or in part; and 24 | produce, reproduce, and Share Adapted Material. 25 | Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 26 | Term. The term of this Public License is specified in Section 6(a). 27 | Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 28 | Downstream recipients. 29 | Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 30 | Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 31 | No downstream restrictions. 
You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 32 | No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 33 | Other rights. 34 | 35 | Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 36 | Patent and trademark rights are not licensed under this Public License. 37 | To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties. 38 | Section 3 – License Conditions. 39 | 40 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 41 | 42 | Attribution. 
43 | 44 | If You Share the Licensed Material (including in modified form), You must: 45 | 46 | retain the following if it is supplied by the Licensor with the Licensed Material: 47 | identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 48 | a copyright notice; 49 | a notice that refers to this Public License; 50 | a notice that refers to the disclaimer of warranties; 51 | a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 52 | indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 53 | indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 54 | You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 55 | If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 56 | ShareAlike. 57 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 58 | 59 | The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-SA Compatible License. 60 | You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 
61 | You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 62 | Section 4 – Sui Generis Database Rights. 63 | 64 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 65 | 66 | for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database; 67 | if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 68 | You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 69 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 70 | Section 5 – Disclaimer of Warranties and Limitation of Liability. 71 | 72 | Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You. 
73 | To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You. 74 | The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 75 | Section 6 – Term and Termination. 76 | 77 | This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 78 | Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 79 | 80 | automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 81 | upon express reinstatement by the Licensor. 82 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 83 | For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 84 | Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 85 | Section 7 – Other Terms and Conditions. 86 | 87 | The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 
88 | Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 89 | Section 8 – Interpretation. 90 | 91 | For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 92 | To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 93 | No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 94 | Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 
-------------------------------------------------------------------------------- /install.m: -------------------------------------------------------------------------------- 1 | function install() 2 | addpath('code'); 3 | addpath('tests'); 4 | end -------------------------------------------------------------------------------- /runAllTests.m: -------------------------------------------------------------------------------- 1 | function runAllTests() 2 | import matlab.unittest.* 3 | import matlab.unittest.plugins.* 4 | import matlab.unittest.plugins.codecoverage.* 5 | 6 | try 7 | artifacts = "artifacts"; 8 | mkdir(artifacts); 9 | 10 | junitResults = fullfile(artifacts, "junit"); 11 | mkdir(junitResults); 12 | 13 | covResults = fullfile(artifacts, "coverage"); 14 | mkdir(covResults); 15 | 16 | % Add folders to path 17 | install(); 18 | 19 | % Assemble test suite 20 | suite = TestSuite.fromPackage('tests', 'IncludingSubpackages', true); 21 | runner = TestRunner.withTextOutput; 22 | 23 | % Add test results publishing plugin 24 | xmlFile = fullfile(junitResults, "testResults.xml"); 25 | p = XMLPlugin.producingJUnitFormat(xmlFile); 26 | runner.addPlugin(p); 27 | 28 | % Add code coverage 29 | covFile = fullfile(covResults, "codeCoverage.xml"); 30 | p = CodeCoveragePlugin.forPackage('p2p',... 31 | 'IncludingSubPackages', true,... 32 | 'Producing', CoberturaFormat(covFile)); 33 | runner.addPlugin(p); 34 | 35 | % run the tests 36 | runner.run(suite); 37 | 38 | % exit with success 39 | exit(0); 40 | 41 | catch err 42 | % If there is an error then print report and exit 43 | err.getReport 44 | exit(1); 45 | end 46 | end -------------------------------------------------------------------------------- /tests/+tests/DatastoreTests.m: -------------------------------------------------------------------------------- 1 | classdef DatastoreTests < tests.WithWorkingDirectory 2 | % Tests for the custom PairedImageDatastore 3 | 4 | % Copyright 2020 The MathWorks, Inc.
5 | 6 | properties (TestParameter) 7 | miniBatchSize = {1, 2, 3} 8 | end 9 | 10 | methods (Test) 11 | function testReadData(test, miniBatchSize) 12 | ds = p2p.data.PairedImageDatastore(fullfile(test.Resources, 'imagePairs', 'A'), ... 13 | fullfile(test.Resources, 'imagePairs', 'B'), miniBatchSize); 14 | data = ds.read(); 15 | test.verifyEqual(size(data), [miniBatchSize, 2]); 16 | test.verifyClass(data, 'table'); 17 | end 18 | 19 | function testReset(test) 20 | ds = p2p.data.PairedImageDatastore(fullfile(test.Resources, 'imagePairs', 'A'), ... 21 | fullfile(test.Resources, 'imagePairs', 'B'), 1); 22 | data1 = ds.read(); 23 | ds.reset(); 24 | data2 = ds.read(); 25 | 26 | test.verifyEqual(data1, data2); 27 | end 28 | 29 | function testUnpairable(test) 30 | makeDatastore = @() p2p.data.PairedImageDatastore(fullfile(test.Resources, 'badImagePairs', 'A'), ... 31 | fullfile(test.Resources, 'badImagePairs', 'B'), 1); 32 | test.verifyError(makeDatastore, 'p2p:datastore:notMatched') 33 | end 34 | 35 | function testSetMiniBatchSize(test) 36 | newMiniBatchSize = 3; 37 | ds = p2p.data.PairedImageDatastore(fullfile(test.Resources, 'imagePairs', 'A'), ... 38 | fullfile(test.Resources, 'imagePairs', 'B'), 1); 39 | ds.MiniBatchSize = newMiniBatchSize; 40 | data = ds.read(); 41 | test.verifyEqual(size(data), [newMiniBatchSize, 2]); 42 | end 43 | 44 | function testOptionalArgs(test) 45 | cropSize = [64, 64]; 46 | ds = p2p.data.PairedImageDatastore(fullfile(test.Resources, 'imagePairs', 'A'), ... 47 | fullfile(test.Resources, 'imagePairs', 'B'), 1, ... 48 | "PreSize", [128, 128], ... 49 | "CropSize", cropSize,... 50 | "RandXReflection", false); 51 | 52 | data = ds.read(); 53 | 54 | test.verifyEqual(size(data{1,1}{1}), [cropSize, 3]); 55 | test.verifyEqual(size(data{1,2}{1}), [cropSize, 3]); 56 | end 57 | 58 | function testAugmenter(test) 59 | ds = p2p.data.PairedImageDatastore(fullfile(test.Resources, 'imagePairs', 'A'), ... 60 | fullfile(test.Resources, 'imagePairs', 'B'), 1, ... 
61 | "RandRotation", [0, 360]); 62 | data1 = ds.read(); 63 | ds.reset(); 64 | data2 = ds.read(); 65 | 66 | test.verifyNotEqual(data1, data2); 67 | end 68 | 69 | end 70 | 71 | end -------------------------------------------------------------------------------- /tests/+tests/DiscriminatorTest.m: -------------------------------------------------------------------------------- 1 | classdef DiscriminatorTest < matlab.unittest.TestCase 2 | % Tests for the discriminator 3 | 4 | % Copyright 2020 The MathWorks, Inc. 5 | 6 | properties (TestParameter) 7 | miniBatchSize = {1, 2, 3} 8 | end 9 | 10 | 11 | methods (Test) 12 | function testForwardOutputSize(testCase, miniBatchSize) 13 | inputSize = [256, 256]; 14 | inputChannels = 3; 15 | inputData = dlarray(zeros([inputSize, inputChannels, miniBatchSize], 'single'), 'SSCB'); 16 | expectedSize = [16, 16, 1, miniBatchSize]; 17 | 18 | d = p2p.networks.discriminator(inputSize, inputChannels); 19 | output = extractdata(d.forward(inputData)); 20 | 21 | testCase.verifyEqual(size(output(:,:,1,1)), expectedSize(1:2)); 22 | testCase.verifyEqual(size(output, 3), expectedSize(3)); 23 | % Verify batch dimension separately in case it is 1 24 | testCase.verifyEqual(size(output, 4), expectedSize(4)); 25 | end 26 | end 27 | end -------------------------------------------------------------------------------- /tests/+tests/GeneratorTest.m: -------------------------------------------------------------------------------- 1 | classdef GeneratorTest < matlab.unittest.TestCase 2 | % Tests for the generator 3 | 4 | % Copyright 2020 The MathWorks, Inc. 
5 | 6 | properties (TestParameter) 7 | miniBatchSize = {1, 2, 3} 8 | inputChannels = {1, 3} 9 | end 10 | 11 | 12 | methods (Test) 13 | function testForwardOutputSize(testCase, inputChannels, miniBatchSize) 14 | inputSize = [256, 256]; 15 | outputChannels = 3; 16 | inputData = dlarray(zeros([inputSize, inputChannels, miniBatchSize], 'single'), 'SSCB'); 17 | expectedSize = [256, 256, outputChannels, miniBatchSize]; 18 | 19 | g = p2p.networks.generator(inputSize, inputChannels, outputChannels); 20 | output = extractdata(g.forward(inputData)); 21 | 22 | testCase.verifyEqual(size(output(:,:,:,1)), expectedSize(1:3)); 23 | % Verify batch dimension separately in case it is 1 24 | testCase.verifyEqual(size(output, 4), expectedSize(4)); 25 | end 26 | 27 | end 28 | end -------------------------------------------------------------------------------- /tests/+tests/InstanceNormTest.m: -------------------------------------------------------------------------------- 1 | classdef InstanceNormTest < matlab.unittest.TestCase 2 | % Test instance normalization layer 3 | 4 | % Copyright 2020 The MathWorks, Inc. 5 | 6 | methods (Test) 7 | function testCreate(testCase) 8 | name = 'InstanceNorm1'; 9 | expectedScale = 1; 10 | expectedOffset = 0; 11 | layer = p2p.networks.instanceNormalizationLayer(name); 12 | 13 | testCase.verifyEqual(layer.Name, name); 14 | testCase.verifyEqual(layer.Scale, expectedScale); 15 | testCase.verifyEqual(layer.Offset, expectedOffset); 16 | end 17 | 18 | function testForward(testCase) 19 | layer = p2p.networks.instanceNormalizationLayer('test'); 20 | X = ones(10, 11, 3, 5); 21 | Y = layer.predict(X); 22 | 23 | testCase.verifyEqual(size(Y), size(X)); 24 | testCase.verifyTrue(all(Y == 0, 'all')); 25 | end 26 | 27 | function testForwardChannels(testCase) 28 | % test normalisation along channels 29 | layer = p2p.networks.instanceNormalizationLayer('test'); 30 | X = cat(3, 1*ones(10, 11, 1, 5), ... 31 | 2*ones(10, 11, 1, 5), ...
                3*ones(10, 11, 1, 5));
            Y = layer.predict(X);

            testCase.verifyEqual(size(Y), size(X));
            testCase.verifyTrue(all(Y == 0, 'all'));
        end

        function testBatchNormEquivDlarray(testCase)
            % behaviour should be the same as batch norm for batchsize = 1
            nChannels = 9;
            X = dlarray(rand(13, 15, nChannels, 1), 'SSCB');
            expected = batchnorm(X, zeros(nChannels, 1), ones(nChannels,1));

            layer = p2p.networks.instanceNormalizationLayer('test');
            Y = layer.predict(X);

            testCase.verifyEqual(extractdata(Y), extractdata(expected), ...
                'RelTol', 1e-5);
        end

        function testForwardChannelsBatch(testCase)
            % test normalisation along channels and batch dim
            layer = p2p.networks.instanceNormalizationLayer('test');
            oneX = cat(3, 1*ones(10, 11, 1, 1), ...
                2*ones(10, 11, 1, 1), ...
                3*ones(10, 11, 1, 1));
            X = cat(4, oneX, 2*oneX, 3*oneX);
            Y = layer.predict(X);

            testCase.verifyEqual(size(Y), size(X));
            testCase.verifyTrue(all(Y == 0, 'all'));
        end

        function testForwardOffset(testCase)
            layer = p2p.networks.instanceNormalizationLayer('test');
            offset = 100;
            layer.Offset = offset;
            X = ones(10, 11, 3, 5);
            Y = layer.predict(X);

            testCase.verifyEqual(size(Y), size(X));
            testCase.verifyTrue(all(Y == offset, 'all'));
        end

        function testLayerChecks(testCase)
            results = checkLayer(p2p.networks.instanceNormalizationLayer('test'), [10, 22, 4, 12]);
            testCase.verifyTrue(all([results.Passed]));
        end
    end

end
--------------------------------------------------------------------------------
/tests/+tests/TrainingPlotTest.m:
--------------------------------------------------------------------------------
classdef TrainingPlotTest < matlab.unittest.TestCase
    % Test of training plots

    % Copyright 2020 The MathWorks, Inc.
    properties (TestParameter)
        inChannels = {1, 2, 3, 4}
        outChannels = {1, 1, 3, 1}
    end

    methods (Test, ParameterCombination='sequential')
        function testPrepForPlot(test, inChannels, outChannels)
            h = 100;
            w = 200;
            c = inChannels;
            im = dlarray(zeros(h, w, c, "single"), "SSCB");

            imOut = p2p.vis.TrainingPlot.prepForPlot(im);

            test.assertEqual(size(imOut, [1, 2]), [h, w], ...
                "Height and width should be preserved");
            test.assertEqual(size(imOut, 3), outChannels, ...
                "Channels not as expected");
        end

    end
end
--------------------------------------------------------------------------------
/tests/+tests/TrainingTest.m:
--------------------------------------------------------------------------------
classdef TrainingTest < tests.WithWorkingDirectory
    % Full example training test

    % Copyright 2020 The MathWorks, Inc.


    methods (Test)
        function testTrainAndTranslate(test)
            dataFolder = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
            filesA = imageDatastore(fullfile(dataFolder,'trainingLabels')).Files(1:5);
            filesB = imageDatastore(fullfile(dataFolder,'trainingImages')).Files(1:5);
            testImage = imread(filesA{1});

            options = p2p.trainingOptions("InputChannels", 1, ...
                "OutputChannels", 1, ...
                "MaxEpochs", 1, ...
                "GDepth", 2, ...
                "DDepth", 2, ...
                "Plots", "none");

            p2pModel = p2p.train(filesA, filesB, options);

            translatedImage = p2p.translate(p2pModel, testImage);

            test.verifyEqual(size(translatedImage), size(testImage));
        end

        function testTrainAndResume(test)
            dataFolder = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
            filesA = imageDatastore(fullfile(dataFolder,'trainingLabels')).Files(1:5);
            filesB = imageDatastore(fullfile(dataFolder,'trainingImages')).Files(1:5);
            testImage = imread(filesA{1});

            options = p2p.trainingOptions("InputChannels", 1, ...
                "OutputChannels", 1, ...
                "MaxEpochs", 1, ...
                "GDepth", 2, ...
                "DDepth", 2, ...
                "CheckpointPath", "temp", ...
                "Plots", "none");

            p2p.train(filesA, filesB, options);

            checkpointFile = dir("temp/**/*.mat");
            test.verifyEqual(numel(checkpointFile), 1, ...
                "One checkpoint should be saved");
            checkpointFilepath = fullfile(checkpointFile.folder, checkpointFile.name);
            options.ResumeFrom = checkpointFilepath;

            % Already done 1 epoch, so this shouldn't do any more
            p2p.train(filesA, filesB, options);
            checkpointFile = dir("temp/**/*.mat");
            test.verifyEqual(numel(checkpointFile), 1, ...
                "Shouldn't do any more epochs");

            options.MaxEpochs = 2;
            p2pModel = p2p.train(filesA, filesB, options);
            checkpointFile = dir("temp/**/*.mat");
            test.verifyEqual(numel(checkpointFile), 2, ...
                "Should have 2 checkpoints");

            translatedImage = p2p.translate(p2pModel, testImage);
            test.verifyEqual(size(translatedImage), size(testImage));
        end

        function testWithPlot(test)
            dataFolder = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
            filesA = imageDatastore(fullfile(dataFolder,'trainingLabels')).Files(1:5);
            filesB = imageDatastore(fullfile(dataFolder,'trainingImages')).Files(1:5);

            options = p2p.trainingOptions("InputChannels", 1, ...
                "OutputChannels", 1, ...
                "MaxEpochs", 2, ...
                "GDepth", 2, ...
                "DDepth", 2, ...
                "Plots", "training-progress");

            p2p.train(filesA, filesB, options);

            % Just check that the training figure was created
            f = findobj(groot, "Tag", "p2p.vis.TrainingPlot");
            test.verifyEqual(numel(f), 1, ...
                "1 training plot should be created");
            close(f);
        end
    end

end
--------------------------------------------------------------------------------
/tests/+tests/WithWorkingDirectory.m:
--------------------------------------------------------------------------------
classdef WithWorkingDirectory < matlab.unittest.TestCase
    % Test fixture for using a working directory with test resources

    % Copyright 2020 The MathWorks, Inc.

    properties (GetAccess = public, SetAccess = private)
        Root (1,1) string
        Resources (1,1) string
    end

    methods (TestMethodSetup)
        function initializeWorkingDirWithResources(this)
            fixture = matlab.unittest.fixtures.WorkingFolderFixture();
            this.applyFixture(fixture);
            this.Root = fixture.Folder;
            this.Resources = fullfile(this.Root, 'resources');
            copyfile(...
                fullfile(testRoot(), 'resources'), ...
                this.Resources ...
            );
        end

    end

end
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/A/cmp_b0003 - Copy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/A/cmp_b0003 - Copy.jpg
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/A/cmp_b0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/A/cmp_b0003.jpg
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/A/cmp_b0005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/A/cmp_b0005.jpg
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/A/cmp_b0011.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/A/cmp_b0011.jpg
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/B/cmp_b0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/B/cmp_b0003.jpg
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/B/cmp_b0005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/B/cmp_b0005.jpg
--------------------------------------------------------------------------------
/tests/resources/badImagePairs/B/cmp_b0011.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/badImagePairs/B/cmp_b0011.jpg
--------------------------------------------------------------------------------
/tests/resources/imagePairs/A/cmp_b0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/imagePairs/A/cmp_b0003.jpg
--------------------------------------------------------------------------------
/tests/resources/imagePairs/A/cmp_b0005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/imagePairs/A/cmp_b0005.jpg
--------------------------------------------------------------------------------
/tests/resources/imagePairs/A/cmp_b0011.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/imagePairs/A/cmp_b0011.jpg
--------------------------------------------------------------------------------
/tests/resources/imagePairs/B/cmp_b0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/imagePairs/B/cmp_b0003.jpg
--------------------------------------------------------------------------------
/tests/resources/imagePairs/B/cmp_b0005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/imagePairs/B/cmp_b0005.jpg
--------------------------------------------------------------------------------
/tests/resources/imagePairs/B/cmp_b0011.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pix2pix/dec1f666edf0858acee990701ea601376f756435/tests/resources/imagePairs/B/cmp_b0011.jpg
--------------------------------------------------------------------------------
/tests/testRoot.m:
--------------------------------------------------------------------------------
function path = testRoot()

    % Copyright 2020 The MathWorks, Inc.
    path = fileparts(mfilename('fullpath'));
end
--------------------------------------------------------------------------------
/uninstall.m:
--------------------------------------------------------------------------------
function uninstall()
    rmpath('code');
    rmpath('tests');
end
--------------------------------------------------------------------------------