├── AudioSeparation.py ├── Image2Image.py ├── Image2ImageGrid.py ├── ImagePairs.py ├── ImagePairsGrid.py ├── LICENSE ├── PairedMNIST.py ├── README.md ├── Utils.py ├── analysis ├── __init__.py ├── imagepairs │ ├── ImagePairs Results.csv │ ├── ImagePairs_cityscapes_1000_joint_GAN_big_gens.png │ ├── ImagePairs_cityscapes_1000_joint_GAN_gens.png │ ├── ImagePairs_cityscapes_1000_joint_factorGAN_gens.png │ ├── ImagePairs_cityscapes_100_joint_GAN_big_gens.png │ ├── ImagePairs_cityscapes_100_joint_GAN_gens.png │ ├── ImagePairs_cityscapes_100_joint_factorGAN_gens.png │ ├── ImagePairs_cityscapes_All_joint_GAN_big_gens.png │ ├── ImagePairs_cityscapes_All_joint_GAN_gens.png │ ├── ImagePairs_cityscapes_All_joint_factorGAN_gens.png │ ├── ImagePairs_edges2shoes_1000_joint_GAN_big_gens.png │ ├── ImagePairs_edges2shoes_1000_joint_GAN_gens.png │ ├── ImagePairs_edges2shoes_1000_joint_factorGAN_gens.png │ ├── ImagePairs_edges2shoes_100_joint_GAN_big_gens.png │ ├── ImagePairs_edges2shoes_100_joint_GAN_gens.png │ ├── ImagePairs_edges2shoes_100_joint_GAN_gens_8.png │ ├── ImagePairs_edges2shoes_100_joint_factorGAN_gens.png │ ├── ImagePairs_edges2shoes_100_joint_factorGAN_gens_8.png │ ├── ImagePairs_edges2shoes_All_joint_GAN_big_gens.png │ ├── ImagePairs_edges2shoes_All_joint_GAN_gens.png │ ├── ImagePairs_edges2shoes_All_joint_factorGAN_gens.png │ ├── PlotExamples.py │ ├── PlotLS.py │ ├── __init__.py │ └── imagepairs_LS.pdf ├── img2img_cityscapes │ ├── Image2Image_Cityscapes.csv │ ├── PlotExamples.py │ ├── PlotL2Acc.py │ ├── __init__.py │ ├── cityscapes.pdf │ ├── cityscapes_1000_joint_GAN_gens.png │ ├── cityscapes_1000_joint_factorGAN_gens.png │ ├── cityscapes_100_joint_GAN_gens.png │ ├── cityscapes_100_joint_GAN_gens_small.png │ ├── cityscapes_100_joint_factorGAN_gens.png │ ├── cityscapes_100_joint_factorGAN_gens_small.png │ ├── cityscapes_all_joint_GAN_gens.png │ └── cityscapes_all_joint_factorGAN_gens.png ├── mnist │ ├── PairedMNIST Results Diff.csv │ ├── PairedMNIST Results FID.csv │ ├── PlotDep.py │ ├── PlotExamples.py │ ├── PlotFID.py │ ├── __init__.py │ ├── mnist_dep.pdf │ ├── mnist_examples.pdf │ └── mnist_fid.pdf └── source_separation │ ├── PlotSDR.py │ ├── __init__.py │ ├── sdr.csv │ └── sdr.pdf ├── datasets ├── AudioSeparationDataset.py ├── CropDataset.py ├── DoubleMNISTDataset.py ├── GeneratorInputDataset.py ├── InfiniteDataSampler.py ├── TransformDataSampler.py ├── __init__.py └── image2image │ ├── __init__.py │ ├── aligned_dataset.py │ ├── base_dataset.py │ ├── download_image2image.sh │ └── image_folder.py ├── eval ├── Cityscapes.py ├── FID.py ├── LS.py ├── SourceSeparation.py ├── Visualisation.py └── __init__.py ├── factorgan_conditional.png ├── factorgan_unconditional.png ├── models ├── MNISTClassifier.py ├── SpectralNorm.py ├── __init__.py ├── discriminators │ ├── ConvDiscriminator.py │ ├── FCDiscriminator.py │ └── __init__.py └── generators │ ├── ConvGenerator.py │ ├── FCGenerator.py │ ├── Unet.py │ └── __init__.py ├── requirements.txt └── training ├── AdversarialTraining.py ├── DiscriminatorTraining.py ├── MNIST.py ├── TrainingOptions.py └── __init__.py /AudioSeparation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import os 4 | from torch.utils.data import DataLoader 5 | 6 | import training.TrainingOptions 7 | import training.AdversarialTraining 8 | import Utils 9 | from datasets.GeneratorInputDataset import GeneratorInputDataset 10 | from datasets.InfiniteDataSampler import InfiniteDataSampler 11 | from 
datasets.TransformDataSampler import TransformDataSampler 12 | from datasets.AudioSeparationDataset import MUSDBDataset 13 | from eval.SourceSeparation import produce_musdb_source_estimates 14 | from models.generators.Unet import Unet 15 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair 16 | from models.discriminators.ConvDiscriminator import ConvDiscriminator 17 | 18 | def set_paths(opt): 19 | # Set up paths and create folders 20 | opt.experiment_path = os.path.join(opt.out_path, "AudioSeparation", opt.experiment_name) 21 | opt.gen_path = os.path.join(opt.experiment_path, "gen") 22 | opt.log_path = os.path.join(opt.experiment_path, "logs") 23 | opt.estimates_path = os.path.join(opt.experiment_path, "source_estimates") 24 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path, opt.estimates_path]) 25 | 26 | def train(opt): 27 | Utils.set_seeds(opt) 28 | device = Utils.get_device(opt.cuda) 29 | set_paths(opt) 30 | 31 | if opt.num_joint_songs > 100: 32 | print("ERROR: Cannot train with " + str(opt.num_joint_songs) + " samples, dataset has only size of 100") 33 | return 34 | 35 | # Partition into paired and unpaired songs 36 | idx = [i for i in range(100)] 37 | random.shuffle(idx) 38 | 39 | # Joint samples 40 | dataset_train = MUSDBDataset(opt, idx[:opt.num_joint_songs], "paired") 41 | train_joint = InfiniteDataSampler(DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)) 42 | 43 | if opt.factorGAN == 1: 44 | # For marginals, take full dataset 45 | mix_dataset = MUSDBDataset(opt, idx, "mix") 46 | 47 | acc_dataset = MUSDBDataset(opt, idx, "accompaniment") 48 | acc_loader = InfiniteDataSampler(DataLoader(acc_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)) 49 | 50 | vocal_dataset = MUSDBDataset(opt, idx, "vocals") 51 | vocal_loader = InfiniteDataSampler(DataLoader(vocal_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)) 52 | else: # For normal GAN, take only few joint songs 53 | mix_dataset = MUSDBDataset(opt, idx[:opt.num_joint_songs], "mix") 54 | 55 | # SETUP GENERATOR MODEL 56 | G = Unet(opt, opt.generator_channels, 1, 1).to(device) # 1 input channel (mixture), 1 output channel (mask) 57 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 58 | G_opt = Utils.create_optim(G.parameters(), opt) 59 | 60 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network 61 | G_input_data = DataLoader(GeneratorInputDataset(mix_dataset, G_noise), num_workers=int(opt.workers), 62 | batch_size=opt.batchSize, shuffle=True, drop_last=True) 63 | G_inputs = InfiniteDataSampler(G_input_data) 64 | G_filled_outputs = TransformDataSampler(InfiniteDataSampler(G_inputs), G, device) 65 | 66 | # SETUP DISCRIMINATOR(S) 67 | crop_mix = lambda x: x[:, 1:, :, :] # Only keep sources, not mixture for dep discs 68 | if opt.factorGAN == 1: 69 | # Setup marginal disc networks 70 | D_voc = ConvDiscriminator(opt.input_height, opt.input_width, 1, opt.disc_channels).to(device) 71 | D_acc = ConvDiscriminator(opt.input_height, opt.input_width, 1, opt.disc_channels).to(device) 72 | 73 | D_acc_setup = DiscriminatorSetup("D_acc", D_acc, Utils.create_optim(D_acc.parameters(), opt), acc_loader, 74 | G_filled_outputs, crop_fake=lambda x : x[:,1:2,:,:]) 75 | 76 | D_voc_setup = 
DiscriminatorSetup("D_voc", D_voc, Utils.create_optim(D_voc.parameters(), opt), vocal_loader, 77 | G_filled_outputs, crop_fake=lambda x : x[:,2:3,:,:]) 78 | # Marginal discriminator 79 | D_setups = [D_acc_setup, D_voc_setup] 80 | 81 | # If our dep discriminators are only defined on classifier probabilites, integrate classification into discriminator network as first step 82 | if opt.use_real_dep_disc == 1: 83 | DP = ConvDiscriminator(opt.input_height, opt.input_width, 2, opt.disc_channels, spectral_norm=(opt.lipschitz_p == 1)).to(device) 84 | else: 85 | DP = lambda x : 0 86 | 87 | DQ = ConvDiscriminator(opt.input_height, opt.input_width, 2, opt.disc_channels).to(device) 88 | 89 | # Dependency discriminators 90 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, 1) # Randomly mixes different sources together (e.g. accompaniment from one song with vocals from another) 91 | 92 | if opt.use_real_dep_disc: 93 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt), 94 | train_joint, shuffle_batch_func, crop_func=crop_mix) 95 | else: 96 | DP_setup = None 97 | 98 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ, Utils.create_optim(DQ.parameters(), opt), 99 | G_filled_outputs, shuffle_batch_func, crop_func=crop_mix) 100 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)] 101 | else: 102 | D = ConvDiscriminator(opt.input_height, opt.input_width, 2, opt.disc_channels).to(device) 103 | 104 | D_setup = DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), 105 | train_joint, G_filled_outputs, crop_real=crop_mix, crop_fake=crop_mix) 106 | D_setups = [D_setup] 107 | D_dep_setups = [] 108 | 109 | # RUN TRAINING 110 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path) 111 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G")) 112 | 113 | def eval(opt): 114 | Utils.set_seeds(opt) 115 | device = Utils.get_device(opt.cuda) 116 | set_paths(opt) 117 | 118 | # GENERATOR 119 | # SETUP GENERATOR MODEL 120 | G = Unet(opt, opt.generator_channels, 1, 1).to(device) # 1 input channel (mixture), 1 output channel (mask) 121 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 122 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model))) 123 | G.eval() 124 | 125 | #EVALUATE BY PRODUCING SOURCE ESTIMATE AUDIO AND SDR METRICS 126 | produce_musdb_source_estimates(opt, G, G_noise, opt.estimates_path, subsets="test") 127 | 128 | 129 | def get_opt(): 130 | # COLLECT ALL CMD ARGUMENTS 131 | parser = training.TrainingOptions.get_parser() 132 | 133 | parser.add_argument('--musdb_path', type=str, help="Path to MUSDB dataset") 134 | parser.add_argument('--preprocessed_dataset_path', type=str, help="Path to where the preprocessed dataset should be saved") 135 | 136 | parser.add_argument('--num_joint_songs', type=int, default=100, 137 | help="Number of songs from which joint observations are available for training normal gan/dependency discriminators") 138 | parser.add_argument('--hop_size', type=int, default=256, 139 | help="Hop size of FFT") 140 | parser.add_argument('--fft_size', type=int, default=512, 141 | help="Size of FFT") 142 | parser.add_argument('--sample_rate', type=int, default=22050, 143 | help="Resample input audio to this sample rate") 144 | parser.add_argument('--generator_channels', type=int, default=32, 145 | help="Number of intial feature channels used in G. 
32 was used in the paper") 146 | parser.add_argument('--disc_channels', type=int, default=32, 147 | help="Number of intial feature channels used in each discriminator") 148 | 149 | opt = parser.parse_args() 150 | print(opt) 151 | 152 | opt.input_height = opt.fft_size // 2 # No. of freq bins for model 153 | opt.input_width = opt.input_height // 2 # No of time frames 154 | 155 | print("Activating generator mask and sigmoid non-linearity for mask") 156 | opt.generator_mask = 1 # Use a mask for Unet output 157 | opt.generator_activation = "sigmoid" # Use sigmoid output for mask 158 | 159 | return opt 160 | 161 | if __name__ == "__main__": 162 | opt = get_opt() 163 | 164 | if not opt.eval: 165 | train(opt) 166 | eval(opt) -------------------------------------------------------------------------------- /Image2Image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | 4 | import os 5 | from torch.utils.data import DataLoader, Subset 6 | 7 | import training.TrainingOptions 8 | import training.AdversarialTraining 9 | import Utils 10 | from datasets.GeneratorInputDataset import GeneratorInputDataset 11 | from datasets.InfiniteDataSampler import InfiniteDataSampler 12 | from datasets.TransformDataSampler import TransformDataSampler 13 | from datasets.image2image import get_aligned_dataset 14 | from eval.Cityscapes import get_L2, get_pixel_acc 15 | from eval.Visualisation import generate_images 16 | from models.generators.Unet import Unet 17 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair 18 | from models.discriminators.ConvDiscriminator import ConvDiscriminator 19 | from datasets.CropDataset import CropDataset 20 | 21 | def set_paths(opt): 22 | # Set up paths and create folders 23 | opt.experiment_path = os.path.join(opt.out_path, "Image2Image_" + str(opt.dataset), opt.experiment_name) 24 | opt.gen_path = os.path.join(opt.experiment_path, "gen") 25 | opt.log_path = os.path.join(opt.experiment_path, "logs") 26 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path]) 27 | 28 | def train(opt): 29 | Utils.set_seeds(opt) 30 | device = Utils.get_device(opt.cuda) 31 | set_paths(opt) 32 | 33 | # DATA 34 | dataset = get_aligned_dataset(opt, "train") 35 | nc = dataset.A_nc + dataset.B_nc 36 | 37 | if opt.num_joint_samples > len(dataset): 38 | print("WARNING: Cannot train with " + str(opt.num_joint_samples) + " samples, dataset has only size of " + str(len(dataset))+ ". 
Using full dataset!") 39 | opt.num_joint_samples = len(dataset) 40 | 41 | # Joint samples 42 | dataset_train = Subset(dataset, range(opt.num_joint_samples)) 43 | train_joint = InfiniteDataSampler( 44 | DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)) 45 | 46 | if opt.factorGAN == 1: 47 | # For marginals, take full dataset and crop 48 | a_dataset = CropDataset(dataset, lambda x : x[0:dataset.A_nc, :, :]) 49 | b_dataset = CropDataset(dataset, lambda x : x[dataset.A_nc:, :, :]) 50 | train_b = InfiniteDataSampler(DataLoader(b_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)) 51 | 52 | generator_input_data = a_dataset 53 | else: 54 | generator_input_data = CropDataset(dataset_train, lambda x : x[0:dataset.A_nc, :, :]) 55 | 56 | # SETUP GENERATOR MODEL 57 | G = Unet(opt, opt.generator_channels, dataset.A_nc, dataset.B_nc).to(device) 58 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 59 | G_opt = Utils.create_optim(G.parameters(), opt) 60 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network 61 | G_input_data = DataLoader(GeneratorInputDataset(generator_input_data, G_noise), num_workers=int(opt.workers), 62 | batch_size=opt.batchSize, shuffle=True, drop_last=True) 63 | G_inputs = InfiniteDataSampler(G_input_data) 64 | G_filled_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device) 65 | 66 | # SETUP DISCRIMINATOR(S) 67 | if opt.factorGAN == 1: 68 | # Setup disc networks 69 | D2 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.B_nc).to(device) 70 | # If our dep discriminators are only defined on classifier probabilites, integrate classification into discriminator network as first step 71 | if opt.use_real_dep_disc == 1: 72 | DP = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, spectral_norm=(opt.lipschitz_p == 1)).to(device) 73 | else: 74 | DP = lambda x : 0 75 | 76 | DQ = ConvDiscriminator(opt.loadSize, opt.loadSize, nc).to(device) 77 | 78 | # Prepare discriminators for training method 79 | # Marginal discriminator 80 | D_setups = [DiscriminatorSetup("D2", D2, Utils.create_optim(D2.parameters(), opt), 81 | train_b, G_filled_outputs, crop_fake=lambda x: x[:, dataset.A_nc:, :, :])] 82 | # Dependency discriminators 83 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, [dataset.A_nc]) 84 | if opt.use_real_dep_disc: 85 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt), 86 | train_joint, shuffle_batch_func) 87 | else: 88 | DP_setup = None 89 | 90 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ, Utils.create_optim(DQ.parameters(), opt), 91 | G_filled_outputs, shuffle_batch_func) 92 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)] 93 | else: 94 | D = ConvDiscriminator(opt.loadSize, opt.loadSize, nc).to(device) 95 | print(sum(p.numel() for p in D.parameters())) 96 | 97 | D_setup = DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), 98 | train_joint, G_filled_outputs) 99 | D_setups = [D_setup] 100 | D_dep_setups = [] 101 | 102 | # RUN TRAINING 103 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path) 104 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G")) 105 | 106 | def eval(opt): 107 | Utils.set_seeds(opt) 108 | device = Utils.get_device(opt.cuda) 109 | set_paths(opt) 110 | 111 | # 
DATASET 112 | dataset = get_aligned_dataset(opt, "val") 113 | input_dataset = CropDataset(dataset, lambda x: x[0:dataset.A_nc, :, :]) 114 | 115 | # GENERATOR 116 | G = Unet(opt, opt.generator_channels, dataset.A_nc, dataset.B_nc).to(device) 117 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 118 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model))) 119 | G.eval() 120 | 121 | # EVALUATE: Generate some images using test set and noise as conditional input 122 | G_input_data = DataLoader(GeneratorInputDataset(input_dataset, G_noise), num_workers=int(opt.workers), 123 | batch_size=opt.batchSize, shuffle=False) 124 | G_inputs = InfiniteDataSampler(G_input_data) 125 | 126 | generate_images(G, G_inputs, opt.gen_path, 100, device, lambda x : Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc)) 127 | 128 | # EVALUATE for Cityscapes 129 | if opt.dataset == "cityscapes": 130 | writer = SummaryWriter(opt.log_path) 131 | val_input_data = DataLoader(dataset, num_workers=int(opt.workers),batch_size=opt.batchSize) 132 | 133 | pixel_error = get_pixel_acc(opt, device, G, val_input_data, G_noise) 134 | print("VALIDATION PERFORMANCE Pixel: " + str(pixel_error)) 135 | writer.add_scalar("val_pix", pixel_error) 136 | 137 | L2_error = get_L2(opt, device, G, val_input_data, G_noise) 138 | print("VALIDATION PERFORMANCE L2: " + str(L2_error)) 139 | writer.add_scalar("val_L2", L2_error) 140 | 141 | 142 | def get_opt(): 143 | # COLLECT ALL CMD ARGUMENTS 144 | parser = training.TrainingOptions.get_parser() 145 | 146 | parser.add_argument('--dataset', type=str, default="cityscapes", 147 | help="Dataset to train on - can be cityscapes or edges2shoes (but other img2img datasets can be integrated easily") 148 | parser.add_argument('--num_joint_samples', type=int, default=1000, 149 | help="Number of joint observations available for training normal gan/dependency discriminators") 150 | parser.add_argument('--loadSize', type=int, default=128, 151 | help="Dimensions (no. of pixels) the dataset images are resampled to") 152 | parser.add_argument('--generator_channels', type=int, default=32, 153 | help="Number of intial feature channels used in G. 
64 was used in the paper") 154 | 155 | opt = parser.parse_args() 156 | print(opt) 157 | 158 | # Set generator to sigmoid output 159 | opt.generator_activation = "sigmoid" 160 | 161 | # Square images => loadSize determines width and height 162 | opt.input_width = opt.loadSize 163 | opt.input_height = opt.loadSize 164 | 165 | return opt 166 | 167 | if __name__ == "__main__": 168 | opt = get_opt() 169 | 170 | if not opt.eval: 171 | train(opt) 172 | eval(opt) -------------------------------------------------------------------------------- /Image2ImageGrid.py: -------------------------------------------------------------------------------- 1 | # This script will train image segmentation models in different dataset configurations as in the paper 2 | import Image2Image 3 | 4 | opt = Image2Image.get_opt() 5 | 6 | for dataset_name in ["cityscapes", "edges2shoes"]: # Iterate over datasets, more datasets could be added here like maps 7 | opt.dataset = dataset_name 8 | for num_joint_samples in [100, 1000, 10000]: # Try for different amount of paired samples 9 | # Apply settings 10 | print(str(num_joint_samples) + " joint samples") 11 | opt.num_joint_samples = num_joint_samples 12 | 13 | print("Training GAN") 14 | opt.experiment_name = str(num_joint_samples) + "_joint_GAN" 15 | opt.factorGAN = 0 16 | Image2Image.train(opt) 17 | 18 | print("Training factorGAN") 19 | opt.experiment_name = str(num_joint_samples) + "_joint_factorGAN" 20 | opt.factorGAN = 1 21 | Image2Image.train(opt) -------------------------------------------------------------------------------- /ImagePairs.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import torch 3 | import os 4 | from torch.utils.data import DataLoader, Subset 5 | 6 | import training.TrainingOptions 7 | import training.AdversarialTraining 8 | import Utils 9 | from datasets.GeneratorInputDataset import GeneratorInputDataset 10 | from datasets.InfiniteDataSampler import InfiniteDataSampler 11 | from datasets.TransformDataSampler import TransformDataSampler 12 | from datasets.image2image import get_aligned_dataset 13 | from eval import LS 14 | from eval.Visualisation import generate_images 15 | from models.generators.ConvGenerator import ConvGenerator 16 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair 17 | from models.discriminators.ConvDiscriminator import ConvDiscriminator 18 | from datasets.CropDataset import CropDataset 19 | 20 | def set_paths(opt): 21 | # Set up paths and create folders 22 | opt.experiment_path = os.path.join(opt.out_path, "ImagePairs", opt.dataset, opt.experiment_name) 23 | opt.gen_path = os.path.join(opt.experiment_path, "gen") 24 | opt.log_path = os.path.join(opt.experiment_path, "logs") 25 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path]) 26 | 27 | def train(opt): 28 | Utils.set_seeds(opt) 29 | device = Utils.get_device(opt.cuda) 30 | set_paths(opt) 31 | 32 | # DATA 33 | dataset = get_aligned_dataset(opt, "train") 34 | nc = dataset.A_nc + dataset.B_nc 35 | 36 | # Warning if desired number of joint samples is larger than dataset, in that case, use whole dataset as paired 37 | if opt.num_joint_samples > len(dataset): 38 | print("WARNING: Cannot train with " + str(opt.num_joint_samples) + " samples, dataset has only size of " + str(len(dataset))+ ". 
Using full dataset!") 39 | opt.num_joint_samples = len(dataset) 40 | 41 | # Joint samples 42 | dataset_train = Subset(dataset, range(opt.num_joint_samples)) 43 | train_joint = InfiniteDataSampler( 44 | DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)) 45 | 46 | if opt.factorGAN == 1: 47 | # For marginals, take full dataset and crop 48 | train_a = InfiniteDataSampler(DataLoader(CropDataset(dataset, lambda x : x[0:dataset.A_nc, :, :]), 49 | num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True)) 50 | train_b = InfiniteDataSampler(DataLoader(CropDataset(dataset, lambda x : x[dataset.A_nc:, :, :]), 51 | num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True)) 52 | 53 | # SETUP GENERATOR MODEL 54 | G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device) 55 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 56 | G_opt = Utils.create_optim(G.parameters(), opt) 57 | 58 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network 59 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers), 60 | batch_size=opt.batchSize, shuffle=True) 61 | G_inputs = InfiniteDataSampler(G_input_data) 62 | G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device) 63 | 64 | # SETUP DISCRIMINATOR(S) 65 | if opt.factorGAN == 1: 66 | # Setup disc networks 67 | D1 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.A_nc, opt.disc_channels).to(device) 68 | D2 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.B_nc, opt.disc_channels).to(device) 69 | # If our dep discriminators are only defined on classifier probabilites, integrate classification into discriminator network as first step 70 | if opt.use_real_dep_disc == 1: 71 | DP = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels, spectral_norm=(opt.lipschitz_p == 1)).to(device) 72 | else: 73 | DP = lambda x : 0 74 | 75 | DQ = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels).to(device) 76 | print(sum(p.numel() for p in D1.parameters())) 77 | 78 | # Prepare discriminators for training method 79 | # Marginal discriminators 80 | D1_setup = DiscriminatorSetup("D1", D1, Utils.create_optim(D1.parameters(), opt), 81 | train_a, G_outputs, crop_fake=lambda x : x[:, 0:dataset.A_nc, :, :]) 82 | D2_setup = DiscriminatorSetup("D2", D2, Utils.create_optim(D2.parameters(), opt), 83 | train_b, G_outputs, crop_fake=lambda x : x[:, dataset.A_nc:, :, :]) 84 | D_setups = [D1_setup, D2_setup] 85 | 86 | # Dependency discriminators 87 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, [dataset.A_nc]) 88 | if opt.use_real_dep_disc: 89 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt), 90 | train_joint, shuffle_batch_func) 91 | else: 92 | DP_setup = None 93 | 94 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ,Utils.create_optim(DQ.parameters(), opt), 95 | G_outputs, shuffle_batch_func) 96 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)] 97 | else: 98 | D = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels).to(device) 99 | print(sum(p.numel() for p in D.parameters())) 100 | D_setups = [DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), train_joint, G_outputs)] 101 | D_dep_setups = [] 102 | 103 | # RUN TRAINING 104 | training.AdversarialTraining.train(opt, G, 
G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path) 105 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G")) 106 | 107 | def eval(opt): 108 | device = Utils.get_device(opt.cuda) 109 | set_paths(opt) 110 | 111 | # Get test dataset 112 | dataset = get_aligned_dataset(opt, "val") 113 | nc = dataset.A_nc + dataset.B_nc 114 | 115 | # SETUP GENERATOR MODEL 116 | G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device) 117 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 118 | 119 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network 120 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers), 121 | batch_size=opt.batchSize, shuffle=True) 122 | G_inputs = InfiniteDataSampler(G_input_data) 123 | G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device) 124 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model))) 125 | G.eval() 126 | 127 | # EVALUATE 128 | # GENERATE EXAMPLES 129 | generate_images(G, G_inputs, opt.gen_path, 1000, device, lambda x: Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc)) 130 | 131 | # COMPUTE LS DISTANCE 132 | # Partition into test train and test test 133 | test_train_samples = int(0.8 * float(len(dataset))) 134 | test_test_samples = len(dataset) - test_train_samples 135 | print("VALIDATION SAMPLES: " + str(test_train_samples)) 136 | print("TEST SAMPLES: " + str(test_test_samples)) 137 | real_test_train_loader = DataLoader(Subset(dataset, range(test_train_samples)), num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True) 138 | real_test_test_loader = DataLoader(Subset(dataset, range(test_train_samples, len(dataset))), num_workers=int(opt.workers), batch_size=opt.batchSize) 139 | 140 | # Initialise classifier 141 | classifier_factory = lambda : ConvDiscriminator(opt.loadSize, opt.loadSize, nc, filters=opt.ls_channels, spectral_norm=False).to(device) 142 | # Compute metric 143 | losses = LS.compute_ls_metric(classifier_factory, real_test_train_loader, real_test_test_loader, G_outputs, opt.ls_runs, device) 144 | 145 | # WRITE RESULTS INTO CSV FOR LATER ANALYSIS 146 | file_existed = os.path.exists(os.path.join(opt.experiment_path, "LS.csv")) 147 | with open(os.path.join(opt.experiment_path, "LS.csv"), "a") as csv_file: 148 | writer = csv.writer(csv_file) 149 | model = "factorGAN" if opt.factorGAN else "gan" 150 | if not file_existed: 151 | writer.writerow(["LS", "Model", "Samples", "Dataset", "Samples_Validation","Samples_Test"]) 152 | for val in losses: 153 | writer.writerow([val, model, opt.num_joint_samples, opt.dataset, test_train_samples, test_test_samples]) 154 | 155 | def get_opt(): 156 | # COLLECT ALL CMD ARGUMENTS 157 | parser = training.TrainingOptions.get_parser() 158 | 159 | parser.add_argument('--dataset', type=str, default="edges2shoes", 160 | help="Dataset to train on - can be cityscapes or edges2shoes (but other img2img datasets can be integrated easily") 161 | parser.add_argument('--num_joint_samples', type=int, default=1000, 162 | help="Number of joint observations available for training normal gan/dependency discriminators") 163 | parser.add_argument('--loadSize', type=int, default=64, 164 | help="Dimensions (no. 
of pixels) the dataset images are resampled to")
165 |     parser.add_argument('--generator_channels', type=int, default=64,
166 |                         help="Number of initial feature channels used in G. 64 was used in the paper")
167 |     parser.add_argument('--disc_channels', type=int, default=32,
168 |                         help="Number of initial feature channels used in each discriminator")
169 | 
170 | 
171 |     # LS distance eval settings
172 |     parser.add_argument('--ls_runs', type=int, default=10,
173 |                         help="Number of LS Discriminator training runs for evaluation")
174 |     parser.add_argument('--ls_channels', type=int, default=16,
175 |                         help="Number of initial feature channels used for LS discriminator. 16 in the paper")
176 | 
177 | 
178 |     opt = parser.parse_args()
179 |     print(opt)
180 | 
181 |     # Set generator to sigmoid output
182 |     opt.generator_activation = "sigmoid"
183 | 
184 |     return opt
185 | 
186 | if __name__ == "__main__":
187 |     opt = get_opt()
188 | 
189 |     if not opt.eval:
190 |         train(opt)
191 |     eval(opt)
--------------------------------------------------------------------------------
/ImagePairsGrid.py:
--------------------------------------------------------------------------------
1 | # Run ImagePairs experiment for multiple datasets and number of paired samples, using GAN and FactorGAN
2 | 
3 | import ImagePairs
4 | opt = ImagePairs.get_opt()
5 | 
6 | for dataset_name in ["cityscapes", "edges2shoes"]:
7 |     opt.dataset = dataset_name
8 |     for num_joint_samples in [100, 1000, 10000]:
9 | 
10 |         # Apply settings
11 |         print(str(num_joint_samples) + " joint samples")
12 |         opt.num_joint_samples = num_joint_samples
13 | 
14 |         print("Training GAN")
15 |         opt.experiment_name = str(num_joint_samples) + "_joint_GAN"
16 |         opt.factorGAN = 0
17 |         ImagePairs.train(opt)
18 | 
19 |         print("Training factorGAN")
20 |         opt.experiment_name = str(num_joint_samples) + "_joint_factorGAN"
21 |         opt.factorGAN = 1
22 |         ImagePairs.train(opt)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Daniel Stoller
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /PairedMNIST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | from torchvision import datasets 4 | import os 5 | from torch.utils.data import DataLoader 6 | import training.TrainingOptions 7 | import training.AdversarialTraining 8 | import Utils 9 | from datasets.GeneratorInputDataset import GeneratorInputDataset 10 | from datasets.InfiniteDataSampler import InfiniteDataSampler 11 | from datasets.TransformDataSampler import TransformDataSampler 12 | from datasets.DoubleMNISTDataset import DoubleMNISTDataset 13 | from eval import Visualisation, FID 14 | from training import MNIST 15 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair 16 | from models.discriminators.FCDiscriminator import FCDiscriminator 17 | from models.generators.FCGenerator import FCGenerator 18 | from datasets.CropDataset import CropDataset 19 | import numpy as np 20 | 21 | def set_paths(opt): 22 | # PATHS 23 | opt.experiment_path = os.path.join(opt.out_path, "PairedMNIST", opt.experiment_name) 24 | opt.gen_path = os.path.join(opt.experiment_path, "gen") 25 | opt.log_path = os.path.join(opt.experiment_path, "logs") 26 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path]) 27 | 28 | def predict_digits_batch(classifier, two_digit_input): 29 | ''' 30 | Takes MNIST classifier and paired-MNIST sample and gives digit label probabilities for both 31 | :param classifier: MNIST classifier model 32 | :param two_digit_input: Paired MNIST sample 33 | :return: 20-dim. vector containing 2*10 digit label probabilities for upper and lower digit 34 | ''' 35 | if len(two_digit_input.shape) == 2: # B, X 36 | two_digit_input = two_digit_input.view(-1, 1, 56, 28) 37 | elif len(two_digit_input.shape) == 3: # B, H, W 38 | two_digit_input = two_digit_input.unsqueeze(1) 39 | 40 | upper_digit = two_digit_input[:, :, :28, :] 41 | lower_digit = two_digit_input[:, :, 28:, :] 42 | 43 | probs = torch.cat([classifier(upper_digit), classifier(lower_digit)], dim=1) 44 | return probs 45 | 46 | def get_class_prob_matrix(G, G_inputs, classifier, num_samples, device): 47 | ''' 48 | Build matrix of digit label combination frequencies (10x10) 49 | :param G: Generator model 50 | :param G_inputs: Input data sampler for generator (noise) 51 | :param classifier: MNIST classifier model 52 | :param num_samples: Number of samples to draw from generator to estimate c_Q 53 | :param device: Device to use 54 | :return: Normalised frequency of digit combination occurrences (10x10 matrix) 55 | ''' 56 | it = 0 57 | joint_class_probs = np.zeros([10, 10]) 58 | while(True): 59 | input_batch = next(G_inputs) 60 | # Get generator samples 61 | input_batch = [item.to(device) for item in input_batch] 62 | gen_batch = G(input_batch) 63 | # Feed through classifier 64 | digit_preds = predict_digits_batch(classifier, gen_batch) 65 | for pred in digit_preds: 66 | upper_pred = torch.argmax(pred[:10]) 67 | lower_pred = torch.argmax(pred[10:]) 68 | joint_class_probs[upper_pred, lower_pred] += 1 69 | 70 | it += 1 71 | if it >= num_samples: 72 | return joint_class_probs / np.sum(joint_class_probs) 73 | 74 | def train(opt): 75 | print("Using " + str(opt.num_joint_samples) + " joint samples!") 76 | Utils.set_seeds(opt) 77 | device = Utils.get_device(opt.cuda) 78 | 79 | # DATA 80 | MNIST_dim = 784 81 | dataset = 
datasets.MNIST('datasets', train=True, download=True) 82 | 83 | # Create partitions of stacked MNIST 84 | dataset_joint = DoubleMNISTDataset(dataset, range(opt.num_joint_samples),same_digit_prob=opt.mnist_same_digit_prob) 85 | train_joint = InfiniteDataSampler(DataLoader(dataset_joint, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True)) 86 | if opt.factorGAN == 1: 87 | # For marginals, take full dataset and crop it 88 | full_dataset = DoubleMNISTDataset(dataset, None, same_digit_prob=opt.mnist_same_digit_prob) 89 | train_x1 = InfiniteDataSampler(DataLoader(CropDataset(full_dataset, lambda x : x[:MNIST_dim]), num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True)) 90 | train_x2 = InfiniteDataSampler(DataLoader(CropDataset(full_dataset, lambda x : x[MNIST_dim:]), num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True)) 91 | 92 | # SETUP GENERATOR MODEL 93 | G = FCGenerator(opt, 2*MNIST_dim).to(device) 94 | G.train() 95 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 96 | G_opt = Utils.create_optim(G.parameters(), opt) 97 | 98 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network 99 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers), 100 | batch_size=opt.batchSize, shuffle=True) 101 | G_inputs = InfiniteDataSampler(G_input_data) 102 | G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device) 103 | 104 | # SETUP DISCRIMINATOR(S) 105 | if opt.factorGAN == 1: 106 | # Setup disc networks 107 | D1 = FCDiscriminator(MNIST_dim).to(device) 108 | D2 = FCDiscriminator(MNIST_dim).to(device) 109 | # If our dep discriminators are only defined on classifier probabilites, integrate classification into discriminator network as first step 110 | if opt.use_real_dep_disc == 1: 111 | DP = FCDiscriminator(2 * MNIST_dim,spectral_norm=(opt.lipschitz_p == 1)).to(device) 112 | else: 113 | DP = lambda x : 0 114 | 115 | DQ = FCDiscriminator(2 * MNIST_dim).to(device) 116 | 117 | # Prepare discriminators for training method 118 | # Marginal discriminators 119 | D1_setup = DiscriminatorSetup("D1", D1, Utils.create_optim(D1.parameters(), opt), 120 | train_x1, G_outputs, crop_fake=lambda x: x[:, :MNIST_dim]) 121 | D2_setup = DiscriminatorSetup("D2", D2, Utils.create_optim(D2.parameters(), opt), 122 | train_x2, G_outputs, crop_fake=lambda x: x[:, MNIST_dim:]) 123 | D_setups = [D1_setup, D2_setup] 124 | 125 | # Dependency discriminators 126 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, marginal_index=MNIST_dim) 127 | 128 | if opt.use_real_dep_disc: 129 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt), train_joint, shuffle_batch_func) 130 | else: 131 | DP_setup = None 132 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ, Utils.create_optim(DQ.parameters(), opt), G_outputs, shuffle_batch_func) 133 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)] 134 | else: 135 | D = FCDiscriminator(2*MNIST_dim).to(device) 136 | D_setups = [DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), train_joint, G_outputs)] 137 | D_dep_setups = [] 138 | 139 | # RUN TRAINING 140 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path) 141 | torch.save(G.state_dict(), os.path.join(opt.out_path, "G")) 142 | 143 | def eval(opt): 144 | print("EVALUATING MNIST MODEL...") 145 | 
MNIST_dim = 784 146 | device = Utils.get_device(opt.cuda) 147 | 148 | # Train and save a digit classification model, needed for factorGAN variants and evaluation 149 | classifier = MNIST.main(opt) 150 | classifier.to(device) 151 | classifier.eval() 152 | 153 | # SETUP GENERATOR MODEL 154 | G = FCGenerator(opt, 2 * MNIST_dim).to(device) 155 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz)) 156 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network 157 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers), 158 | batch_size=opt.batchSize, shuffle=True) 159 | G_inputs = InfiniteDataSampler(G_input_data) 160 | 161 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model))) 162 | G.eval() 163 | 164 | # EVALUATE: Save images to eyeball them + FID for marginals + Class probability correlations 165 | writer = SummaryWriter(opt.log_path) 166 | 167 | test_mnist = datasets.MNIST('datasets', train=False, download=True) 168 | test_dataset = DoubleMNISTDataset(test_mnist, None, same_digit_prob=opt.mnist_same_digit_prob) 169 | test_dataset_loader = DataLoader(test_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True) 170 | transform_func = lambda x: x.view(-1, 1, 56, 28) 171 | Visualisation.generate_images(G, G_inputs, opt.gen_path, len(test_dataset), device, transform_func) 172 | 173 | crop_upper = lambda x: x[:, :, :28, :] 174 | crop_lower = lambda x: x[:, :, 28:, :] 175 | fid_upper = FID.evaluate_MNIST(opt, classifier, test_dataset_loader, opt.gen_path, device,crop_real=crop_upper,crop_fake=crop_upper) 176 | fid_lower = FID.evaluate_MNIST(opt, classifier, test_dataset_loader, opt.gen_path, device,crop_real=crop_lower,crop_fake=crop_lower) 177 | print("FID Upper Digit: " + str(fid_upper)) 178 | print("FID Lower Digit: " + str(fid_lower)) 179 | writer.add_scalar("FID_lower", fid_lower) 180 | writer.add_scalar("FID_upper", fid_upper) 181 | 182 | # ESTIMATE QUALITY OF DEPENDENCY MODELLING 183 | # cp(...) = cq(...) ideally for all inputs on the test set if dependencies are perfectly modelled. 
So compute the average absolute difference between cp and cq over all digit combinations
184 |     # Get joint distribution of real class indices in the data
185 |     test_dataset = DoubleMNISTDataset(test_mnist, None,
186 |                                       same_digit_prob=opt.mnist_same_digit_prob, deterministic=True, return_labels=True)
187 |     test_it = DataLoader(test_dataset)
188 |     real_class_probs = np.zeros((10, 10))
189 |     for sample in test_it:
190 |         _, d1, d2 = sample
191 |         real_class_probs[d1, d2] += 1
192 |     real_class_probs /= np.sum(real_class_probs)
193 | 
194 |     # Compute marginal distribution of real class indices from joint one
195 |     real_class_probs_upper = np.sum(real_class_probs, axis=1) # a
196 |     real_class_probs_lower = np.sum(real_class_probs, axis=0) # b
197 |     real_class_probs_marginal = real_class_probs_upper * np.reshape(real_class_probs_lower, [-1, 1])
198 | 
199 |     # Get joint distribution of class indices on generated data (using classifier predictions)
200 |     fake_class_probs = get_class_prob_matrix(G, G_inputs, classifier, len(test_dataset), device)
201 |     # Compute marginal distribution of class indices on generated data
202 |     fake_class_probs_upper = np.sum(fake_class_probs, axis=1)
203 |     fake_class_probs_lower = np.sum(fake_class_probs, axis=0)
204 |     fake_class_probs_marginal = fake_class_probs_upper * np.reshape(fake_class_probs_lower, [-1, 1])
205 | 
206 |     # Compute average of |cp(...) - cq(...)|
207 |     cp = np.divide(real_class_probs, real_class_probs_marginal + 0.001)
208 |     cq = np.divide(fake_class_probs, fake_class_probs_marginal + 0.001)
209 | 
210 |     diff_metric = np.mean(np.abs(cp - cq))
211 | 
212 |     print("Dependency cp/cq diff metric: " + str(diff_metric))
213 |     writer.add_scalar("Diff-Dep", diff_metric)
214 | 
215 |     return fid_upper, fid_lower
216 | 
217 | def get_opt():
218 |     # COLLECT ALL CMD ARGUMENTS
219 |     parser = training.TrainingOptions.get_parser()
220 | 
221 |     parser.add_argument('--mnist_same_digit_prob', type=float, default=0.4,
222 |                         help="Probability of same digits occurring together. 0.1 means digits are paired independently, 1.0 means always same digits, 0.0 never same digits")
223 |     parser.add_argument('--num_joint_samples', type=int, default=50,
224 |                         help="Number of joint observations available for training normal gan/dependency discriminators")
225 | 
226 |     opt = parser.parse_args()
227 |     # Set generator to sigmoid output
228 |     opt.generator_activation = "sigmoid"
229 |     print(opt)
230 |     return opt
231 | 
232 | if __name__ == "__main__":
233 |     opt = get_opt()
234 | 
235 |     set_paths(opt)
236 | 
237 |     if not opt.eval:
238 |         train(opt)
239 |     eval(opt)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | 
3 | This repository implements "FactorGAN" as described in the paper
4 | 
5 | "Training Generative Adversarial Networks from Incomplete Observations using Factorised Discriminators"
6 | 
7 | ## FactorGAN - Quick introduction
8 | 
9 | Consider training a GAN to solve some generation task, where a sample can be naturally divided into multiple parts.
10 | As one example, we will use images of shoes along with a drawing of their edges (see diagram below).
11 | 
12 | With a normal GAN for this task, the generator just outputs a pair of images, which the discriminator then analyses as a whole.
13 | For training the discriminator, you use real shoe-edge image pairs from a "paired" dataset.
14 | 
15 | But what if you only have a few of these paired samples, but many more individual shoe OR edge images?
16 | You cannot use these for GAN training to improve the quality of your shoes and edges further, as the discriminator needs shoe-edge image pairs.
17 | 
18 | That is where the FactorGAN comes in. As shown below, it makes use of all available data (shoe-edge pairs, shoes alone, edges alone):
19 | 
20 | ![FactorGAN for unconditional generation](factorgan_unconditional.png)
21 | 
22 | To achieve this, FactorGAN uses four discriminators:
23 | * a discriminator to judge the generator's shoe quality
24 | * another to do the same for the edges
25 | * two "dependency discriminators" to ensure the generator outputs edge maps that fit their respective shoe images:
26 |     * The "real dependency" discriminator tries to distinguish real paired examples from real ones where each shoe was randomly assigned to an edge map (by "shuffling" the real batch), thereby having to learn which edge maps go with which shoes.
27 |     * The "fake dependency" discriminator does the same for generator samples.
28 | The real dependency discriminator is the only component that needs paired samples for training, while the other components can make use of the extra available shoes and edge images.
29 | 
30 | Training works by alternating between a) updating all discriminators (individually) and b) updating the generator.
31 | Amazingly, we can update the generator just like in a normal GAN, by simply adding the unnormalised discriminator outputs and using the result for the generator loss (see the short code sketch at the end of this introduction).
32 | This combined output can be proven to approximate the same probability for real and fake inputs as estimated by a normal GAN discriminator.
33 | 
34 | In our experiments, the FactorGAN provides very good output quality even with just very few paired samples, as the shoe and edge discriminators trained on the additional unpaired samples help the generator to output realistic shoes and edge maps.
35 | 
36 | This principle can also be used for conditional generation (aka prediction tasks).
37 | Let's take image segmentation as an example:
38 | 
39 | ![FactorGAN for conditional generation](factorgan_conditional.png)
40 | 
41 | The generator now acts as a segmentation model, predicting the segmentation from a given city scene.
42 | In contrast to a normal conditional GAN, whose discriminator requires the scene along with its segmentation as "paired" input, here we use
43 | * a discriminator acting only on real and fake segmentations, trainable with individual scenes and segmentation maps, ensuring the predicted segmentation is "realistic" on its own, irrespective of the scene it was predicted from
44 | * a fake dependency discriminator that distinguishes (real scene, fake segmentation) pairs from their shuffled variants, to learn how the generator output corresponds to the input
45 | * a real dependency discriminator that distinguishes (real scene, real segmentation) pairs from their shuffled variants.
46 | 
47 | We perform segmentation experiments on the Cityscapes dataset, treating the samples as unpaired (like the CycleGAN).
48 | But we find that adding as few as 25 paired samples yields substantially higher segmentation accuracy than the CycleGAN - suggesting that the FactorGAN fills a gap between fully unsupervised and fully supervised methods by making efficient use of both paired and unpaired samples.
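To make this combined generator update concrete, here is a minimal, self-contained sketch with toy fully-connected networks and made-up dimensions - it is not the repository's actual training code (see ```training/AdversarialTraining.py``` and ```training/DiscriminatorTraining.py``` for the real implementation). Following the factorisation used in the paper, the output of the fake (q) dependency discriminator enters the sum with a negative sign:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup (hypothetical sizes): each joint sample consists of two parts with 8 dimensions each.
dim1, dim2 = 8, 8
G = nn.Linear(16, dim1 + dim2)                                 # generator
D1, D2 = nn.Linear(dim1, 1), nn.Linear(dim2, 1)                # marginal discriminators
DP, DQ = nn.Linear(dim1 + dim2, 1), nn.Linear(dim1 + dim2, 1)  # real/fake dependency discriminators

fake = G(torch.randn(4, 16))                                   # batch of 4 generated joint samples

# Add the unnormalised (pre-sigmoid) outputs of all discriminators to approximate
# the logit a single joint discriminator would assign to the generated sample.
d_joint = D1(fake[:, :dim1]) + D2(fake[:, dim1:]) + DP(fake) - DQ(fake)

# Standard generator loss computed on the combined output.
g_loss = F.binary_cross_entropy_with_logits(d_joint, torch.ones_like(d_joint))
g_loss.backward()
```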
49 | 50 | # Requirements 51 | 52 | * Python 3.6 53 | * Pip for installing Python packages 54 | * [libsnd](http://www.mega-nerd.com/libsndfile/) library installed 55 | * [wget](https://www.gnu.org/software/wget/) installed for downloading the datasets automatically 56 | * GPU is optional, but strongly recommended to avoid long computation times 57 | 58 | # Installation 59 | 60 | Install the required packages as listed in ``requirements.txt``. 61 | To ensure existing packages do not interfere with the installation, it is best to create a virtual environment with ``virtualenv`` first and then install the packages separately into that environment. 62 | Easy installation can be performed using 63 | ``` 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | ## Dataset download 68 | 69 | ### Cityscapes and Edges2Shoes 70 | 71 | For experiments involving Cityscapes or Edges2Shoes data, you need to download these datasets first. 72 | To do this, change to the ``datasets/image2image`` subfolder in your commandline using 73 | 74 | ``` 75 | cd datasets/image2image 76 | ``` 77 | 78 | and then simply execute 79 | 80 | ``` 81 | ./download_image2image.sh cityscapes 82 | ``` 83 | 84 | or 85 | 86 | ``` 87 | ./download_image2image.sh edges2shoes 88 | ``` 89 | 90 | ### MUSDB18 (audio separation) 91 | 92 | For audio source separation experiments, you will need to download the [MUSDB18 dataset](https://sigsep.github.io/datasets/musdb.html) from [Zenodo](https://zenodo.org/record/1117372) manually, since it requires requesting access, and extract it to a folder of your choice. 93 | 94 | When running the training script, you can point to the MUSDB dataset folder by giving its path as a command-line parameter. 95 | 96 | # Running experiments 97 | 98 | To run the experiments, execute the script corresponding to the particular application, from the root directory of the repository: 99 | * ```PairedMNIST.py```: Paired MNIST experiments 100 | * ```ImagePairs.py```: Generation of image pairs (Cityscapes, Edges2Shoes) 101 | * ```Image2Image.py```: Used for image segmentation (Cityscapes) 102 | * ```AudioSeparation.py```: For vocal separation 103 | 104 | Each experiment in the paper can be replicated by specifying the experimental parameters via the commandline. 105 | 106 | Firstly, there is a set of parameters shared between all experiments, which are described in ```training/TrainingOptions.py```. 107 | The most important ones are: 108 | * ```--cuda```: Activate GPU training flag 109 | * ```--experiment_name```: Provide a string to name the experiment, which will be used to name the output folder 110 | * ```--out_path```: Provide the output folder where results and logs are saved. All output will usually be in ```out_path/TASKNAME/experiment_name```. 111 | * ```--eval```: Append this flag if you only want to perform model evaluation for an already trained model. CAUTION: ``experiment_name`` path as well as network parameters have to be set correctly (like the one used during training) to ensure this works correctly. 112 | * ```--factorGAN```: Provide a 0 to use the normal GAN, 1 for FactorGAN 113 | * ```--use_real_dep_disc```: Provide a 0 to not use a p-dependency discriminator, 1 for the full FactorGAN 114 | 115 | Every experiment also has specific extra commandline parameters which are explained in the code file for each experiment. 
116 | 
117 | ## Examples
118 | 
119 | Train a GAN on PairedMNIST with 1000 joint samples, using GPU, and save results in ```out/PairedMNIST/1000_samples_GAN```:
120 | ```
121 | python PairedMNIST.py --cuda --factorGAN 0 --num_joint_samples 1000 --experiment_name "1000_samples_GAN"
122 | ```
123 | 
124 | Train a FactorGAN to generate scene-segmentation image pairs with 100 joint samples on the Cityscapes dataset, using GPU, and save results in ```out/ImagePairs/cityscapes/100_samples_factorGAN```:
125 | ```
126 | python ImagePairs.py --cuda --dataset "cityscapes" --num_joint_samples 100 --factorGAN 1 --experiment_name "100_samples_factorGAN"
127 | ```
128 | 
129 | Train a FactorGAN for image segmentation with 25 joint samples on the Cityscapes dataset, using GPU, and save results in ```out/Image2Image_cityscapes/25_samples_factorGAN```:
130 | ```
131 | python Image2Image.py --cuda --dataset "cityscapes" --num_joint_samples 25 --factorGAN 1 --experiment_name "25_samples_factorGAN"
132 | ```
133 | 
134 | # Analysing and plotting results
135 | 
136 | The ```analysis``` subfolder contains
137 | * some of the results obtained during our experiments, such as performance metrics
138 | * all scripts that were used to produce the figures
139 | * the figures used in the paper
140 | 
141 | Some of these scripts require the full output of an experiment, so the experiment needs to be run first; the path to the resulting output folder can then be inserted into the script. These outputs are not included in the repository directly as they can be quite large.
142 | 
--------------------------------------------------------------------------------
/Utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import librosa
4 | import numpy as np
5 | import torch
6 | import random
7 | import math
8 | 
9 | def create_image_pair(x, ch1, ch2):
10 |     '''
11 |     Concatenates two images horizontally (that are stored in x, using 3 or 1 channels for each image)
12 |     :param x: Pair of images (along channel dimension)
13 |     :param ch1: Number of channels for image 1
14 |     :param ch2: Number of channels for image 2
15 |     :return: Horizontally stacked image pair
16 |     '''
17 |     assert(ch1 == 3 or ch1 == 1)
18 |     assert(ch2 == 3 or ch2 == 1)
19 |     assert(x.shape[1] == ch1 + ch2)
20 | 
21 |     repeat_left = 3 if ch1 == 1 else 1
22 |     repeat_right = 3 if ch2 == 1 else 1
23 |     return torch.cat([x[:, :ch1, :, :].repeat(1, repeat_left, 1, 1), x[:, ch1:, :, :].repeat(1,repeat_right,1,1)], dim=3)
24 | 
25 | def is_square(integer):
26 |     '''
27 |     Check if number is a square of another number
28 |     :param integer: Number to be checked
29 |     :return: Whether number is square of another number
30 |     '''
31 |     root = math.sqrt(integer)
32 |     if int(root + 0.5) ** 2 == integer:
33 |         return True
34 |     else:
35 |         return False
36 | 
37 | def shuffle_batch_dims(batch, marginal_index, dim=1):
38 |     '''
39 |     Shuffles groups of dimensions of a batch of samples so that groups are drawn independently of each other
40 |     :param batch: Input batch to be shuffled
41 |     :param marginal_index: If list: List of indices that denote the boundaries of the groups to be shuffled, excluding 0 and batch.shape[1].
42 |     If int: Each group has this many dimensions, batch.shape[1] must be divisible by this number. If None: Input batch needs to have groups as dimensions: [Num_samples, Group1_dim, ...
GroupN_dim] 43 | :return: Shuffled batch 44 | ''' 45 | 46 | if isinstance(batch, torch.Tensor): 47 | out = batch.clone() 48 | else: 49 | out = batch.copy() 50 | 51 | if isinstance(marginal_index, int): 52 | assert (batch.shape[dim] % marginal_index == 0) 53 | marginal_index = [(x+1)*marginal_index for x in range(int(batch.shape[1] / marginal_index) - 1)] 54 | if isinstance(marginal_index, list): 55 | groups = marginal_index + [batch.shape[dim]] 56 | for group_idx in range(len(groups)-1): # Shuffle each group, except the first one 57 | dim_start = groups[group_idx] 58 | dim_end = groups[group_idx+1] 59 | ordering = np.random.permutation(batch.shape[0]) 60 | if dim == 1: 61 | out[:,dim_start:dim_end] = batch[ordering, dim_start:dim_end] 62 | elif dim == 2: 63 | out[:, :, dim_start:dim_end] = batch[ordering, :, dim_start:dim_end] 64 | elif dim == 3: 65 | out[:, :, :, dim_start:dim_end] = batch[ordering, :, :, dim_start:dim_end] 66 | else: 67 | raise NotImplementedError 68 | else: 69 | raise NotImplementedError 70 | 71 | return out 72 | 73 | def load(path, sr=22050, mono=True, offset=0.0, duration=None, dtype=np.float32): 74 | # ALWAYS output (n_frames, n_channels) audio 75 | y, orig_sr = librosa.load(path, sr, mono, offset, duration, dtype) 76 | if len(y.shape) == 1: 77 | y = np.expand_dims(y, axis=0) 78 | return y.T, orig_sr 79 | 80 | def shuffle_batch_image_quadrants(batch): 81 | ''' 82 | Given an input batch of square images, shuffle the four quadrants independently across examples 83 | :param batch: Input batch of square images 84 | :return: Shuffled square images 85 | ''' 86 | input_shape = batch.shape 87 | if len(batch.shape) == 2: # [batch, dim] shape means we have to reshape 88 | # Check if data can be shaped into square image, if it is not already 89 | dim = int(batch.shape[1]) 90 | root = int(math.sqrt(dim) + 0.5) 91 | assert(root ** 2 == dim) 92 | elif len(batch.shape) > 2: 93 | # Check if last two dimensions are the same size N, and reshape to [-1, C, N, N] 94 | assert(batch.shape[-2] == batch.shape[-1]) 95 | root = batch.shape[-1] 96 | else: 97 | raise SyntaxError 98 | assert(root % 2 == 0) # Image should be splittable in half 99 | q = root // 2 # Length/width of each quadrant 100 | 101 | # Change to [B, C, N, N] shape 102 | if isinstance(batch, torch.Tensor): 103 | batch_reshape = batch.view((batch.shape[0], -1, root, root)) 104 | out = batch_reshape.clone() 105 | else: 106 | batch_reshape = np.reshape(batch, (batch.shape[0], -1, root, root)) 107 | out = batch_reshape.copy() 108 | 109 | # Shuffle the four quadrants of the square image around 110 | for row in range(2): 111 | for col in range(2): 112 | if row == 0 and col == 0: continue # Do not need to shuffle first quadrant, if we shuffle all the others across the batch 113 | 114 | ordering = np.random.permutation(batch.shape[0]) 115 | out[:, :, row*q:(row+1)*q, col*q:(col+1)*q] = batch_reshape[ordering, :, row*q:(row+1)*q, col*q:(col+1)*q] 116 | 117 | # Reshape to the shape of the original input 118 | if isinstance(batch, torch.Tensor): 119 | out = out.view(input_shape) 120 | else: 121 | out = np.reshape(out, input_shape) 122 | 123 | return out 124 | 125 | def compute_spectrogram(audio, fft_size, hop_size): 126 | ''' 127 | Compute magnitude spectrogram for audio signal 128 | :param audio: Audio input signal 129 | :param fft_size: FFT Window size (samples) 130 | :param hop_size: Hop size (samples) for STFT 131 | :return: Magnitude spectrogram 132 | ''' 133 | stft = librosa.core.stft(audio, fft_size, hop_size) 134 | mag, ph = 
librosa.core.magphase(stft) 135 | 136 | return normalise_spectrogram(mag), ph 137 | 138 | def normalise_spectrogram(mag, cut_last_freq=True): 139 | ''' 140 | Normalise audio spectrogram with log-normalisation 141 | :param mag: Magnitude spectrogram to be normalised 142 | :param cut_last_freq: Whether to cut the highest frequency bin so that the number of bins is a power of 2 143 | :return: Normalised spectrogram 144 | ''' 145 | if cut_last_freq: 146 | # Throw away last freq bin to make the number of freq bins a power of 2 147 | out = mag[:-1,:] 148 | else: out = mag 149 | # Normalize with log1p 150 | out = np.log1p(out) 151 | return out 152 | 153 | def normalise_spectrogram_torch(mag): 154 | return torch.log1p(mag) 155 | 156 | def denormalise_spectrogram(mag, pad_freq=True): 157 | ''' 158 | Reverses the normalisation performed in the "normalise_spectrogram" function 159 | :param mag: Normalised magnitudes 160 | :param pad_freq: Whether to append a frequency bin as highest frequency with 0 as energy 161 | :return: Reconstructed spectrogram 162 | ''' 163 | out = np.expm1(mag) 164 | 165 | if pad_freq: 166 | out = np.pad(out, [(0,1), (0, 0)], mode="constant") 167 | 168 | return out 169 | 170 | def denormalise_spectrogram_torch(mag): 171 | return torch.expm1(mag) 172 | 173 | def spectrogramToAudioFile(magnitude, fftWindowSize, hopSize, phaseIterations=10, phase=None, length=None): 174 | ''' 175 | Computes an audio signal from the given magnitude spectrogram, and optionally an initial phase. 176 | Griffin-Lim is executed to recover/refine the phase from the magnitude spectrogram. 177 | :param magnitude: Magnitudes to be converted to audio 178 | :param fftWindowSize: Size of FFT window used to create magnitudes 179 | :param hopSize: Hop size in frames used to create magnitudes 180 | :param phaseIterations: Number of Griffin-Lim iterations to recover the phase 181 | :param phase: If given, starts ISTFT with this particular phase matrix 182 | :param length: If given, audio signal is clipped/padded to this number of frames 183 | :return: Reconstructed audio signal 184 | ''' 185 | if phase is not None: 186 | if phaseIterations > 0: 187 | # Refine audio given initial phase with a number of iterations 188 | return reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations, phase, length) 189 | # reconstructing the new complex matrix 190 | stftMatrix = magnitude * np.exp(phase * 1j) # magnitude * e^(j*phase) 191 | audio = librosa.istft(stftMatrix, hop_length=hopSize, length=length) 192 | else: 193 | audio = reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations) 194 | return audio 195 | 196 | def reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations=10, initPhase=None, length=None): 197 | ''' 198 | Griffin-Lim algorithm for reconstructing the phase for a given magnitude spectrogram, optionally with a given 199 | initial phase.
200 | :param magnitude: Magnitudes to be converted to audio 201 | :param fftWindowSize: Size of FFT window used to create magnitudes 202 | :param hopSize: Hop size in frames used to create magnitudes 203 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase 204 | :param initPhase: If given, starts reconstruction with this particular phase matrix 205 | :param length: If given, audio signal is clipped/padded to this number of frames 206 | :return: 207 | ''' 208 | for i in range(phaseIterations): 209 | if i == 0: 210 | if initPhase is None: 211 | reconstruction = np.random.random_sample(magnitude.shape) + 1j * (2 * np.pi * np.random.random_sample(magnitude.shape) - np.pi) 212 | else: 213 | reconstruction = np.exp(initPhase * 1j) # e^(j*phase), so that angle => phase 214 | else: 215 | reconstruction = librosa.stft(audio, fftWindowSize, hopSize) 216 | spectrum = magnitude * np.exp(1j * np.angle(reconstruction)) 217 | if i == phaseIterations - 1: 218 | audio = librosa.istft(spectrum, hopSize, length=length) 219 | else: 220 | audio = librosa.istft(spectrum, hopSize) 221 | return audio 222 | 223 | def make_dirs(dirs): 224 | if isinstance(dirs, str): 225 | dirs = [dirs] 226 | assert(isinstance(dirs, list)) 227 | for dir in dirs: 228 | if not os.path.exists(dir): 229 | os.makedirs(dir) 230 | 231 | def create_optim(parameters, opt): 232 | return torch.optim.Adam(parameters, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.L2) 233 | 234 | def get_device(cuda): 235 | if torch.cuda.is_available() and not cuda: 236 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 237 | device = torch.device("cuda:0" if cuda else "cpu") 238 | return device 239 | 240 | def set_seeds(opt): 241 | ''' 242 | Set Python, numpy as and torch random seeds to a fixed number 243 | :param opt: Option dictionary containined .seed member value 244 | ''' 245 | if opt.seed is None: 246 | opt.seed = random.randint(1, 10000) 247 | print("Random Seed: ", opt.seed) 248 | random.seed(opt.seed) 249 | torch.manual_seed(opt.seed) 250 | np.random.seed(opt.seed) -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/__init__.py -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs Results.csv: -------------------------------------------------------------------------------- 1 | LS,Model,Paired samples,Dataset,Samples_Validation,Samples_Test 2 | 0.19918266125023365,GAN,100,edges2shoes,160,40 3 | 0.197277020663023,GAN,100,edges2shoes,160,40 4 | 0.18398787640035152,GAN,100,edges2shoes,160,40 5 | 0.20709695480763912,GAN,100,edges2shoes,160,40 6 | 0.1931029949337244,GAN,100,edges2shoes,160,40 7 | 0.1648500096052885,GAN,100,edges2shoes,160,40 8 | 0.2281373254954815,GAN,100,edges2shoes,160,40 9 | 0.1764399316161871,GAN,100,edges2shoes,160,40 10 | 0.19546093232929707,GAN,100,edges2shoes,160,40 11 | 0.24289108626544476,GAN,100,edges2shoes,160,40 12 | 0.4614328145980835,GAN,1000,edges2shoes,160,40 13 | 0.4914138615131378,GAN,1000,edges2shoes,160,40 14 | 0.4868180975317955,GAN,1000,edges2shoes,160,40 15 | 0.4894701614975929,GAN,1000,edges2shoes,160,40 16 | 0.5159566402435303,GAN,1000,edges2shoes,160,40 17 | 0.4466138854622841,GAN,1000,edges2shoes,160,40 18 | 0.496683269739151,GAN,1000,edges2shoes,160,40 19 | 
0.4755656234920025,GAN,1000,edges2shoes,160,40 20 | 0.5014562234282494,GAN,1000,edges2shoes,160,40 21 | 0.4484795890748501,GAN,1000,edges2shoes,160,40 22 | 0.4587823525071144,GAN,49825,edges2shoes,160,40 23 | 0.49065449088811874,GAN,49825,edges2shoes,160,40 24 | 0.47518329322338104,GAN,49825,edges2shoes,160,40 25 | 0.5073818042874336,GAN,49825,edges2shoes,160,40 26 | 0.4914328381419182,GAN,49825,edges2shoes,160,40 27 | 0.4772345498204231,GAN,49825,edges2shoes,160,40 28 | 0.46360666304826736,GAN,49825,edges2shoes,160,40 29 | 0.45902401581406593,GAN,49825,edges2shoes,160,40 30 | 0.47111673653125763,GAN,49825,edges2shoes,160,40 31 | 0.48572762310504913,GAN,49825,edges2shoes,160,40 32 | 0.49867895245552063,FactorGAN,100,edges2shoes,160,40 33 | 0.48741042986512184,FactorGAN,100,edges2shoes,160,40 34 | 0.4798458069562912,FactorGAN,100,edges2shoes,160,40 35 | 0.4852375388145447,FactorGAN,100,edges2shoes,160,40 36 | 0.5021589547395706,FactorGAN,100,edges2shoes,160,40 37 | 0.5132415220141411,FactorGAN,100,edges2shoes,160,40 38 | 0.5196625292301178,FactorGAN,100,edges2shoes,160,40 39 | 0.4916936084628105,FactorGAN,100,edges2shoes,160,40 40 | 0.4872872121632099,FactorGAN,100,edges2shoes,160,40 41 | 0.5147524625062943,FactorGAN,100,edges2shoes,160,40 42 | 0.5059081017971039,FactorGAN,1000,edges2shoes,160,40 43 | 0.5109081342816353,FactorGAN,1000,edges2shoes,160,40 44 | 0.47647980600595474,FactorGAN,1000,edges2shoes,160,40 45 | 0.5170110575854778,FactorGAN,1000,edges2shoes,160,40 46 | 0.49985283613204956,FactorGAN,1000,edges2shoes,160,40 47 | 0.50772500410676,FactorGAN,1000,edges2shoes,160,40 48 | 0.5248940289020538,FactorGAN,1000,edges2shoes,160,40 49 | 0.525466576218605,FactorGAN,1000,edges2shoes,160,40 50 | 0.5174819976091385,FactorGAN,1000,edges2shoes,160,40 51 | 0.53387051820755,FactorGAN,1000,edges2shoes,160,40 52 | 0.5184591487050056,FactorGAN,49825,edges2shoes,160,40 53 | 0.5095947757363319,FactorGAN,49825,edges2shoes,160,40 54 | 0.49991364777088165,FactorGAN,49825,edges2shoes,160,40 55 | 0.5085523054003716,FactorGAN,49825,edges2shoes,160,40 56 | 0.4887695945799351,FactorGAN,49825,edges2shoes,160,40 57 | 0.523505188524723,FactorGAN,49825,edges2shoes,160,40 58 | 0.4910292327404022,FactorGAN,49825,edges2shoes,160,40 59 | 0.5000504329800606,FactorGAN,49825,edges2shoes,160,40 60 | 0.48224231973290443,FactorGAN,49825,edges2shoes,160,40 61 | 0.49068504571914673,FactorGAN,49825,edges2shoes,160,40 62 | 0.06610106024891138,GAN,100,cityscapes,400,100 63 | 0.0486236484721303,GAN,100,cityscapes,400,100 64 | 0.06520971702411771,GAN,100,cityscapes,400,100 65 | 0.043228451162576675,GAN,100,cityscapes,400,100 66 | 0.06359246838837862,GAN,100,cityscapes,400,100 67 | 0.03645658027380705,GAN,100,cityscapes,400,100 68 | 0.04610483441501856,GAN,100,cityscapes,400,100 69 | 0.05799236707389355,GAN,100,cityscapes,400,100 70 | 0.049013398587703705,GAN,100,cityscapes,400,100 71 | 0.06162051111459732,GAN,100,cityscapes,400,100 72 | 0.041525376960635185,GAN,1000,cityscapes,400,100 73 | 0.07696561142802238,GAN,1000,cityscapes,400,100 74 | 0.05571409035474062,GAN,1000,cityscapes,400,100 75 | 0.06277178600430489,GAN,1000,cityscapes,400,100 76 | 0.03282815124839544,GAN,1000,cityscapes,400,100 77 | 0.05692122969776392,GAN,1000,cityscapes,400,100 78 | 0.05593732185661793,GAN,1000,cityscapes,400,100 79 | 0.03852173686027527,GAN,1000,cityscapes,400,100 80 | 0.03605717979371548,GAN,1000,cityscapes,400,100 81 | 0.04344829358160496,GAN,1000,cityscapes,400,100 82 | 0.0877392329275608,GAN,2975,cityscapes,400,100 83 | 
0.08541159704327583,GAN,2975,cityscapes,400,100 84 | 0.08468282781541348,GAN,2975,cityscapes,400,100 85 | 0.06158710457384586,GAN,2975,cityscapes,400,100 86 | 0.10693855583667755,GAN,2975,cityscapes,400,100 87 | 0.08687366731464863,GAN,2975,cityscapes,400,100 88 | 0.10995074361562729,GAN,2975,cityscapes,400,100 89 | 0.09155567362904549,GAN,2975,cityscapes,400,100 90 | 0.08998952526599169,GAN,2975,cityscapes,400,100 91 | 0.07705152779817581,GAN,2975,cityscapes,400,100 92 | 0.15152879618108273,FactorGAN,100,cityscapes,400,100 93 | 0.1284797377884388,FactorGAN,100,cityscapes,400,100 94 | 0.13163253664970398,FactorGAN,100,cityscapes,400,100 95 | 0.18368030712008476,FactorGAN,100,cityscapes,400,100 96 | 0.12400678172707558,FactorGAN,100,cityscapes,400,100 97 | 0.1800699234008789,FactorGAN,100,cityscapes,400,100 98 | 0.14700616523623466,FactorGAN,100,cityscapes,400,100 99 | 0.2166460081934929,FactorGAN,100,cityscapes,400,100 100 | 0.1466676127165556,FactorGAN,100,cityscapes,400,100 101 | 0.13666629791259766,FactorGAN,100,cityscapes,400,100 102 | 0.24428238719701767,FactorGAN,1000,cityscapes,400,100 103 | 0.220618337392807,FactorGAN,1000,cityscapes,400,100 104 | 0.20961547270417213,FactorGAN,1000,cityscapes,400,100 105 | 0.20297425985336304,FactorGAN,1000,cityscapes,400,100 106 | 0.24880168214440346,FactorGAN,1000,cityscapes,400,100 107 | 0.20047229900956154,FactorGAN,1000,cityscapes,400,100 108 | 0.19510486349463463,FactorGAN,1000,cityscapes,400,100 109 | 0.213859923183918,FactorGAN,1000,cityscapes,400,100 110 | 0.1974785439670086,FactorGAN,1000,cityscapes,400,100 111 | 0.28764569014310837,FactorGAN,1000,cityscapes,400,100 112 | 0.28121019154787064,FactorGAN,2975,cityscapes,400,100 113 | 0.24645909294486046,FactorGAN,2975,cityscapes,400,100 114 | 0.22664135321974754,FactorGAN,2975,cityscapes,400,100 115 | 0.24045272171497345,FactorGAN,2975,cityscapes,400,100 116 | 0.26552461832761765,FactorGAN,2975,cityscapes,400,100 117 | 0.2997002340853214,FactorGAN,2975,cityscapes,400,100 118 | 0.21766618266701698,FactorGAN,2975,cityscapes,400,100 119 | 0.28905053064227104,FactorGAN,2975,cityscapes,400,100 120 | 0.21604740619659424,FactorGAN,2975,cityscapes,400,100 121 | 0.19727909564971924,FactorGAN,2975,cityscapes,400,100 122 | 0.09539508633315563,GAN (big),2975,cityscapes,400,100 123 | 0.18587369099259377,GAN (big),2975,cityscapes,400,100 124 | 0.1633717156946659,GAN (big),2975,cityscapes,400,100 125 | 0.16083178855478764,GAN (big),2975,cityscapes,400,100 126 | 0.16951332241296768,GAN (big),2975,cityscapes,400,100 127 | 0.1349962204694748,GAN (big),2975,cityscapes,400,100 128 | 0.15365711972117424,GAN (big),2975,cityscapes,400,100 129 | 0.1875893995165825,GAN (big),2975,cityscapes,400,100 130 | 0.12070108018815517,GAN (big),2975,cityscapes,400,100 131 | 0.1576228328049183,GAN (big),2975,cityscapes,400,100 132 | 0.12126125209033489,GAN (big),1000,cityscapes,400,100 133 | 0.1332952231168747,GAN (big),1000,cityscapes,400,100 134 | 0.09272367879748344,GAN (big),1000,cityscapes,400,100 135 | 0.18972866609692574,GAN (big),1000,cityscapes,400,100 136 | 0.1441386379301548,GAN (big),1000,cityscapes,400,100 137 | 0.14152801036834717,GAN (big),1000,cityscapes,400,100 138 | 0.05023533571511507,GAN (big),1000,cityscapes,400,100 139 | 0.19023562967777252,GAN (big),1000,cityscapes,400,100 140 | 0.1459544152021408,GAN (big),1000,cityscapes,400,100 141 | 0.1345924697816372,GAN (big),1000,cityscapes,400,100 142 | 0.04006412858143449,GAN (big),100,cityscapes,400,100 143 | 0.03009739425033331,GAN (big),100,cityscapes,400,100 
144 | 0.04413636680692434,GAN (big),100,cityscapes,400,100 145 | 0.021442324155941606,GAN (big),100,cityscapes,400,100 146 | 0.03707707207649946,GAN (big),100,cityscapes,400,100 147 | 0.036953238770365715,GAN (big),100,cityscapes,400,100 148 | 0.04017030727118254,GAN (big),100,cityscapes,400,100 149 | 0.03742054011672735,GAN (big),100,cityscapes,400,100 150 | 0.0391128808259964,GAN (big),100,cityscapes,400,100 151 | 0.04093080945312977,GAN (big),100,cityscapes,400,100 152 | 0.4937850534915924,GAN (big),49825,edges2shoes,160,40 153 | 0.4875243380665779,GAN (big),49825,edges2shoes,160,40 154 | 0.4851335920393467,GAN (big),49825,edges2shoes,160,40 155 | 0.5300062000751495,GAN (big),49825,edges2shoes,160,40 156 | 0.5167667642235756,GAN (big),49825,edges2shoes,160,40 157 | 0.48705537244677544,GAN (big),49825,edges2shoes,160,40 158 | 0.5067377276718616,GAN (big),49825,edges2shoes,160,40 159 | 0.4927330017089844,GAN (big),49825,edges2shoes,160,40 160 | 0.5186118818819523,GAN (big),49825,edges2shoes,160,40 161 | 0.49117034673690796,GAN (big),49825,edges2shoes,160,40 162 | 0.44656917452812195,GAN (big),1000,edges2shoes,160,40 163 | 0.43908343836665154,GAN (big),1000,edges2shoes,160,40 164 | 0.47615355253219604,GAN (big),1000,edges2shoes,160,40 165 | 0.4346071891486645,GAN (big),1000,edges2shoes,160,40 166 | 0.44736265018582344,GAN (big),1000,edges2shoes,160,40 167 | 0.43023814633488655,GAN (big),1000,edges2shoes,160,40 168 | 0.49190129339694977,GAN (big),1000,edges2shoes,160,40 169 | 0.4424554482102394,GAN (big),1000,edges2shoes,160,40 170 | 0.4625251926481724,GAN (big),1000,edges2shoes,160,40 171 | 0.39753028750419617,GAN (big),1000,edges2shoes,160,40 172 | 0.11263203108683228,GAN (big),100,edges2shoes,160,40 173 | 0.10219825804233551,GAN (big),100,edges2shoes,160,40 174 | 0.11926222685724497,GAN (big),100,edges2shoes,160,40 175 | 0.15174391120672226,GAN (big),100,edges2shoes,160,40 176 | 0.12145963963121176,GAN (big),100,edges2shoes,160,40 177 | 0.13108434528112411,GAN (big),100,edges2shoes,160,40 178 | 0.1380934789776802,GAN (big),100,edges2shoes,160,40 179 | 0.16618123278021812,GAN (big),100,edges2shoes,160,40 180 | 0.10596461221575737,GAN (big),100,edges2shoes,160,40 181 | 0.10804288182407618,GAN (big),100,edges2shoes,160,40 182 | -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_big_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_big_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_1000_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_factorGAN_gens.png -------------------------------------------------------------------------------- 
/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_big_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_big_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_100_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_100_joint_factorGAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_big_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_big_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_cityscapes_All_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_All_joint_factorGAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_big_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_big_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_factorGAN_gens.png -------------------------------------------------------------------------------- 
/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_big_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_big_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens_8.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens_8.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_big_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_big_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/ImagePairs_edges2shoes_All_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_factorGAN_gens.png -------------------------------------------------------------------------------- /analysis/imagepairs/PlotExamples.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script to plot ImagePair experiment generator examples. 3 | Point the RESULTS_PATH to the folder where experiment logs are saved (that contains an "ImagePairs" folder) 4 | The script will automatically plot generator examples for each experiment that is found in the folder. 
5 | The number of images to show for each model can be set by NUM_EXAMPLES. 6 | NUM_COLS sets the number of columns (along horizontal axis) to use for plotting. 7 | ''' 8 | 9 | import glob 10 | 11 | import imageio as imageio 12 | import numpy as np 13 | import torch 14 | import torchvision 15 | import os 16 | 17 | RESULTS_PATH = "/mnt/windaten/Results/factorGAN/" 18 | DATASET_FOLDERS = ["cityscapes", "edges2shoes"] 19 | OUT_PATH = "" 20 | NUM_EXAMPLES = 16 21 | NUM_COLS = 4 22 | 23 | for dataset in DATASET_FOLDERS: 24 | for experiment_path in glob.glob(os.path.join(RESULTS_PATH, "ImagePairs", dataset, "*")): 25 | model = os.path.basename(experiment_path) 26 | gan_paths = [os.path.join(RESULTS_PATH, "ImagePairs", dataset, model, "gen", "gen_" + str(i) + ".png") for i in range(NUM_EXAMPLES)] 27 | gan_imgs = list() 28 | 29 | for file in gan_paths: 30 | gan_imgs.append(imageio.imread(file)) 31 | gan_imgs = torch.from_numpy(np.transpose(np.stack(gan_imgs), [0, 3, 1, 2])) 32 | gan_imgs = torchvision.utils.make_grid(gan_imgs, nrow=NUM_COLS, padding=10, pad_value=255.0).permute(1, 2, 0) 33 | 34 | imageio.imwrite(os.path.join(OUT_PATH, "ImagePairs_" + dataset + "_" + model + "_gens.png"), gan_imgs) -------------------------------------------------------------------------------- /analysis/imagepairs/PlotLS.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | 5 | sns.set(font_scale=1.1, rc={'text.usetex' : True}) 6 | 7 | df = pd.read_csv("ImagePairs Results.csv", delimiter=",") 8 | df.loc[df["Paired samples"] > 1000, "Paired samples"] = "All" 9 | 10 | g = sns.catplot(x="Paired samples", y="LS", hue="Model", col="Dataset", data=df, kind="bar", ci=95, height=3, aspect=2)#, hue_order=["FactorGAN", "GAN", "GAN (big)"]) 11 | 12 | g.axes[0][0].set_title("Edges2Shoes") 13 | g.axes[0][1].set_title("Cityscapes") 14 | 15 | #plt.tight_layout() 16 | plt.savefig("imagepairs_LS.pdf") -------------------------------------------------------------------------------- /analysis/imagepairs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/__init__.py -------------------------------------------------------------------------------- /analysis/imagepairs/imagepairs_LS.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/imagepairs_LS.pdf -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/Image2Image_Cityscapes.csv: -------------------------------------------------------------------------------- 1 | Paired samples,Perf,Model,Metric 2 | 25,0.0409,GAN,MSE 3 | 100,0.03447,GAN,MSE 4 | 500,0.02419,GAN,MSE 5 | 1000,0.01954,GAN,MSE 6 | 3000,0.02029,GAN,MSE 7 | 25,0.02665,FactorGAN,MSE 8 | 100,0.02688,FactorGAN,MSE 9 | 500,0.02116,FactorGAN,MSE 10 | 1000,0.02038,FactorGAN,MSE 11 | 3000,0.0196,FactorGAN,MSE 12 | 25,53.98,GAN,Accuracy 13 | 100,64.41,GAN,Accuracy 14 | 500,73.26,GAN,Accuracy 15 | 1000,79.24,GAN,Accuracy 16 | 3000,78.18,GAN,Accuracy 17 | 25,71.66,FactorGAN,Accuracy 18 | 100,71.98,FactorGAN,Accuracy 19 | 500,77.41,FactorGAN,Accuracy 20 | 1000,77.95,FactorGAN,Accuracy 21 | 3000,78.66,FactorGAN,Accuracy 
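A minimal sketch of how the results table above (```Image2Image_Cityscapes.csv```) could be loaded and summarised with pandas, independently of the seaborn plotting script that follows. It is illustrative only and not one of the bundled analysis scripts; it assumes pandas is installed, that it is run from the repository root, and that the column names match the header shown above (Paired samples, Perf, Model, Metric):

```
import pandas as pd

# Load the results table shipped with the repository (path assumes the repository root as working directory)
df = pd.read_csv("analysis/img2img_cityscapes/Image2Image_Cityscapes.csv")

# The table holds one value per (Metric, Model, Paired samples) combination,
# so a pivot gives a compact side-by-side comparison of GAN and FactorGAN
# for both the MSE and the accuracy metric.
table = df.pivot_table(index="Paired samples", columns=["Metric", "Model"], values="Perf")
print(table.round(3))
```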
-------------------------------------------------------------------------------- /analysis/img2img_cityscapes/PlotExamples.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Visualise Cityscapes image segmentation predictions. Requires experiments folders to be existing with the correct names first! 3 | ''' 4 | import glob 5 | 6 | import imageio as imageio 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | import os 11 | 12 | NUM_ROWS = 3 13 | NUM_COLUMNS = 2 14 | 15 | BASEDIR = "../../out/Image2Image_cityscapes/" 16 | EXPERIMENTS = [os.path.basename(p) for p in glob.glob(os.path.join(BASEDIR, "*"))] 17 | 18 | for experiment in EXPERIMENTS: 19 | gan_paths = [os.path.join(BASEDIR, experiment, "gen", "gen_" + str(i) + ".png") for i in range(NUM_ROWS*NUM_COLUMNS)] 20 | gan_imgs = list() 21 | 22 | for file in gan_paths: 23 | gan_imgs.append(imageio.imread(file)) 24 | gan_imgs = torch.from_numpy(np.transpose(np.stack(gan_imgs), [0, 3, 1, 2])) 25 | gan_imgs = torchvision.utils.make_grid(gan_imgs, nrow=2, padding=10, pad_value=255.0).permute(1, 2, 0) 26 | 27 | imageio.imwrite("cityscapes_" + experiment + "_gens.png", gan_imgs) -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/PlotL2Acc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Plot L2 and MSE accuracy of image segmentation models (GAN vs FactorGAN) 3 | ''' 4 | 5 | import seaborn as sns 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | sns.set(font_scale=1.1, rc={'text.usetex' : True}) 10 | 11 | df = pd.read_csv("Image2Image_Cityscapes.csv", delimiter=",") 12 | g = sns.catplot("Paired samples", "Perf", hue="Model", col="Metric", data=df, kind="bar", sharey=False, height=3, aspect=2) 13 | g.axes[0][0].set_ylabel("Mean squared error") 14 | g.axes[0][0].set_title("") 15 | g.axes[0][1].set_title("") 16 | g.axes[0][1].set_ylabel("Accuracy (\%)") 17 | 18 | g.fig.subplots_adjust(top=0.95, wspace=0.15) 19 | 20 | plt.savefig("cityscapes.pdf") -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/__init__.py -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes.pdf -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_1000_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_1000_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_1000_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_1000_joint_factorGAN_gens.png 
-------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens_small.png -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens.png -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens_small.png -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_all_joint_GAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_all_joint_GAN_gens.png -------------------------------------------------------------------------------- /analysis/img2img_cityscapes/cityscapes_all_joint_factorGAN_gens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_all_joint_factorGAN_gens.png -------------------------------------------------------------------------------- /analysis/mnist/PairedMNIST Results Diff.csv: -------------------------------------------------------------------------------- 1 | Paired samples;Lambda;Diff;Model 2 | 10;0.1;4.142;GAN 3 | 20;0.1;2.296;GAN 4 | 50;0.1;1.03;GAN 5 | 100;0.1;0.746;GAN 6 | 1000;0.1;0.2392;GAN 7 | 5000;0.1;0.2024;GAN 8 | 10000;0.1;0.2072;GAN 9 | 20000;0.1;0.1986;GAN 10 | 60000;0.1;0.2293;GAN 11 | 10;0.9;0.6944;GAN 12 | 20;0.9;0.3906;GAN 13 | 50;0.9;0.4542;GAN 14 | 100;0.9;0.3664;GAN 15 | 1000;0.9;0.6986;GAN 16 | 5000;0.9;0.6833;GAN 17 | 10000;0.9;0.6688;GAN 18 | 20000;0.9;0.723;GAN 19 | 60000;0.9;0.7067;GAN 20 | 10;0.1;0.4388;FactorGAN 21 | 20;0.1;0.381;FactorGAN 22 | 50;0.1;0.306;FactorGAN 23 | 100;0.1;0.3545;FactorGAN 24 | 1000;0.1;0.2586;FactorGAN 25 | 5000;0.1;0.2501;FactorGAN 26 | 10000;0.1;0.2225;FactorGAN 27 | 20000;0.1;0.2162;FactorGAN 28 | 60000;0.1;0.2025;FactorGAN 29 | 10;0.9;1.249;FactorGAN 30 | 20;0.9;1.174;FactorGAN 31 | 50;0.9;1.224;FactorGAN 32 | 100;0.9;1.208;FactorGAN 33 | 1000;0.9;0.815;FactorGAN 34 | 5000;0.9;0.513;FactorGAN 35 | 10000;0.9;0.4168;FactorGAN 36 | 20000;0.9;0.3421;FactorGAN 37 | 60000;0.9;0.3271;FactorGAN 38 | 
10;0.1;0.267;FactorGAN-no-cp 39 | 20;0.1;0.267;FactorGAN-no-cp 40 | 50;0.1;0.267;FactorGAN-no-cp 41 | 100;0.1;0.267;FactorGAN-no-cp 42 | 1000;0.1;0.267;FactorGAN-no-cp 43 | 5000;0.1;0.267;FactorGAN-no-cp 44 | 10000;0.1;0.267;FactorGAN-no-cp 45 | 20000;0.1;0.267;FactorGAN-no-cp 46 | 60000;0.1;0.267;FactorGAN-no-cp 47 | 10;0.9;1.488;FactorGAN-no-cp 48 | 20;0.9;1.488;FactorGAN-no-cp 49 | 50;0.9;1.488;FactorGAN-no-cp 50 | 100;0.9;1.488;FactorGAN-no-cp 51 | 1000;0.9;1.488;FactorGAN-no-cp 52 | 5000;0.9;1.488;FactorGAN-no-cp 53 | 10000;0.9;1.488;FactorGAN-no-cp 54 | 20000;0.9;1.488;FactorGAN-no-cp 55 | 60000;0.9;1.488;FactorGAN-no-cp 56 | -------------------------------------------------------------------------------- /analysis/mnist/PairedMNIST Results FID.csv: -------------------------------------------------------------------------------- 1 | Paired samples;Lambda;FID;Model 2 | 10;0.1;154.45;GAN 3 | 20;0.1;83.97;GAN 4 | 50;0.1;37.24;GAN 5 | 100;0.1;22.58;GAN 6 | 1000;0.1;4.137;GAN 7 | 5000;0.1;3.0715;GAN 8 | 10000;0.1;2.752;GAN 9 | 20000;0.1;2.953;GAN 10 | 60000;0.1;2.615;GAN 11 | 10;0.9;144.05;GAN 12 | 20;0.9;62.66;GAN 13 | 50;0.9;34.265;GAN 14 | 100;0.9;18.52;GAN 15 | 1000;0.9;5.9085;GAN 16 | 5000;0.9;2.8295;GAN 17 | 10000;0.9;2.9965;GAN 18 | 20000;0.9;2.48;GAN 19 | 60000;0.9;2.129;GAN 20 | 10;0.1;7.7045;FactorGAN 21 | 20;0.1;6.793;FactorGAN 22 | 50;0.1;3.2665;FactorGAN 23 | 100;0.1;4.341;FactorGAN 24 | 1000;0.1;3.466;FactorGAN 25 | 5000;0.1;2.5155;FactorGAN 26 | 10000;0.1;2.747;FactorGAN 27 | 20000;0.1;2.814;FactorGAN 28 | 60000;0.1;2.861;FactorGAN 29 | 10;0.9;10.815;FactorGAN 30 | 20;0.9;5.4855;FactorGAN 31 | 50;0.9;4.057;FactorGAN 32 | 100;0.9;4.4495;FactorGAN 33 | 1000;0.9;3.0935;FactorGAN 34 | 5000;0.9;1.382;FactorGAN 35 | 10000;0.9;1.5345;FactorGAN 36 | 20000;0.9;1.3675;FactorGAN 37 | 60000;0.9;1.142;FactorGAN 38 | 10;0.1;2.827;FactorGAN-no-cp 39 | 20;0.1;2.827;FactorGAN-no-cp 40 | 50;0.1;2.827;FactorGAN-no-cp 41 | 100;0.1;2.827;FactorGAN-no-cp 42 | 1000;0.1;2.827;FactorGAN-no-cp 43 | 5000;0.1;2.827;FactorGAN-no-cp 44 | 10000;0.1;2.827;FactorGAN-no-cp 45 | 20000;0.1;2.827;FactorGAN-no-cp 46 | 60000;0.1;2.827;FactorGAN-no-cp 47 | 10;0.9;2.598;FactorGAN-no-cp 48 | 20;0.9;2.598;FactorGAN-no-cp 49 | 50;0.9;2.598;FactorGAN-no-cp 50 | 100;0.9;2.598;FactorGAN-no-cp 51 | 1000;0.9;2.598;FactorGAN-no-cp 52 | 5000;0.9;2.598;FactorGAN-no-cp 53 | 10000;0.9;2.598;FactorGAN-no-cp 54 | 20000;0.9;2.598;FactorGAN-no-cp 55 | 60000;0.9;2.598;FactorGAN-no-cp 56 | -------------------------------------------------------------------------------- /analysis/mnist/PlotDep.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Plot d_dep dependency metric for PairedMNIST experiment 3 | ''' 4 | 5 | import seaborn as sns 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | sns.set(font_scale=1.1, rc={'text.usetex' : True}) 10 | 11 | fig, ax = plt.subplots(figsize=(6,3.5)) 12 | 13 | df = pd.read_csv("PairedMNIST Results Diff.csv", delimiter=";") 14 | 15 | # Get no-cp performances 16 | low_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.1)]["Diff"].as_matrix()[0] 17 | high_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.9)]["Diff"].as_matrix()[0] 18 | 19 | # Filter facgan-no-cp 20 | df = df[df["Model"] != "FactorGAN-no-cp"] 21 | 22 | df["Model"] = df["Model"] + ", $\lambda$=" + df["Lambda"].apply(lambda x: str(x)) 23 | 24 | ax.axhline(y=low_lambda_nodep, c='black', linestyle='--', label="FactorGAN-no-cp, 
$\lambda=0.1$", alpha=0.8) 25 | ax.axhline(y=high_lambda_nodep, c='gray', linestyle='--', label="FactorGAN-no-cp, $\lambda=0.9$", alpha=0.8) 26 | 27 | sns.barplot("Paired samples", "Diff", hue="Model", data=df, ax=ax) 28 | 29 | ax.set_ylabel("Dependency metric $d_{dep}$") 30 | 31 | # Sort legend 32 | handles, labels = ax.get_legend_handles_labels() 33 | handles = handles[2:] + handles[0:2] 34 | labels = labels[2:] + labels[0:2] 35 | ax.legend(handles, labels) 36 | 37 | #plt.legend() 38 | fig.tight_layout() 39 | fig.savefig("mnist_dep.pdf") -------------------------------------------------------------------------------- /analysis/mnist/PlotExamples.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Plot PairedMNIST generated examples 3 | ''' 4 | 5 | import os 6 | 7 | import imageio as imageio 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | 13 | fig, axes = plt.subplots(2, 3) 14 | root_path = "../../out/PairedMNIST" 15 | 16 | for idx, samples in enumerate([100, 5000, 20000]): 17 | for model_idx, model in enumerate(["GAN", "factorGAN"]): 18 | gan_paths = [os.path.join(root_path, str(samples) + "_joint_0.9_samedigit_" + model, "gen", "gen_" + str(i) + ".png") for i in range(28)] 19 | 20 | gan_imgs = list() 21 | for file in gan_paths: 22 | gan_imgs.append(imageio.imread(file)) 23 | gan_imgs = torch.from_numpy(np.transpose(np.stack(gan_imgs), [0, 3, 1, 2])) 24 | gan_imgs = torchvision.utils.make_grid(gan_imgs, nrow=7, padding=5, pad_value=255.0).permute(1, 2, 0) 25 | 26 | axes[model_idx][idx].imshow(gan_imgs) 27 | axes[model_idx][idx].axis("off") 28 | 29 | plt.subplots_adjust(wspace=0.1, hspace=.1) 30 | 31 | axes[0][0].text(-0.3,0.5, "GAN", size=12, ha="center", transform=axes[0][0].transAxes) 32 | axes[1][0].text(-0.3,0.5, "factorGAN", size=12, ha="center", transform=axes[1][0].transAxes) 33 | 34 | axes[1][0].text(0.5,-0.1, "100", size=12, ha="center", transform=axes[1][0].transAxes) 35 | axes[1][1].text(0.5,-0.1, "5000", size=12, ha="center", transform=axes[1][1].transAxes) 36 | axes[1][2].text(0.5,-0.1, "20000", size=12, ha="center", transform=axes[1][2].transAxes) 37 | 38 | plt.savefig("mnist_examples.pdf", bbox_inches="tight") -------------------------------------------------------------------------------- /analysis/mnist/PlotFID.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Plot FID values for PairedMNIST experiment 3 | ''' 4 | 5 | import seaborn as sns 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | sns.set(font_scale=1.1, rc={'text.usetex' : True}) 10 | 11 | fig, ax = plt.subplots(figsize=(6,3.5)) 12 | # PLOT FID 13 | df = pd.read_csv("PairedMNIST Results FID.csv", delimiter=";") 14 | 15 | # Get no-cp performances 16 | low_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.1)]["FID"].as_matrix()[0] 17 | high_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.9)]["FID"].as_matrix()[0] 18 | 19 | # Filter facgan-no-cp 20 | df = df[df["Model"] != "FactorGAN-no-cp"] 21 | 22 | # Combine model with lambda 23 | df["Model"] = df["Model"] + ", $\lambda$=" + df["Lambda"].apply(lambda x: str(x)) 24 | 25 | sns.barplot("Paired samples", "FID", hue="Model", data=df, ax=ax) 26 | ax.set_yscale("log") 27 | 28 | ax.axhline(y=low_lambda_nodep, c='black', linestyle='dashed', label="FactorGAN-no-cp, $\lambda=0.1$", alpha=0.8) 29 | ax.axhline(y=high_lambda_nodep, c='gray', linestyle='dashed',
label="FactorGAN-no-cp, $\lambda=0.9$", alpha=0.8) 30 | 31 | # PLOT 32 | # Sort legend 33 | handles, labels = ax.get_legend_handles_labels() 34 | handles = handles[2:] + handles[0:2] 35 | labels = labels[2:] + labels[0:2] 36 | ax.legend(handles, labels) # bbox_to_anchor=(1.04,0.5), loc="center left", borderaxespad=0) 37 | 38 | #ax.get_legend().remove() 39 | fig.tight_layout() 40 | fig.savefig("mnist_fid.pdf") -------------------------------------------------------------------------------- /analysis/mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/__init__.py -------------------------------------------------------------------------------- /analysis/mnist/mnist_dep.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/mnist_dep.pdf -------------------------------------------------------------------------------- /analysis/mnist/mnist_examples.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/mnist_examples.pdf -------------------------------------------------------------------------------- /analysis/mnist/mnist_fid.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/mnist_fid.pdf -------------------------------------------------------------------------------- /analysis/source_separation/PlotSDR.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Plot SDR values for source separation experiment 3 | ''' 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import seaborn as sns 8 | 9 | sns.set(font_scale=1.1, rc={'text.usetex' : True}) 10 | 11 | df = pd.read_csv("sdr.csv", delimiter=",") 12 | 13 | g = sns.catplot("Songs", "Mean SDR", hue="Model", col="Source", data=df, kind="bar", height=3, aspect=2, sharey=False) 14 | 15 | g.axes[0][0].set_title("Vocals") 16 | g.axes[0][1].set_title("Accompaniment") 17 | 18 | plt.savefig("sdr.pdf") -------------------------------------------------------------------------------- /analysis/source_separation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/source_separation/__init__.py -------------------------------------------------------------------------------- /analysis/source_separation/sdr.csv: -------------------------------------------------------------------------------- 1 | Songs,Model,Median SDR,MAD SDR,Mean SDR,SDR Std,Source 2 | 10,GAN,1.983965,1.45067,0.0410285857142858,10.689719135518489,Vocals 3 | 20,GAN,2.08704,1.42766,0.2823284285714284,9.80545661853073,Vocals 4 | 50,GAN,1.4713150000000002,1.2996700000000003,0.07224964285714285,9.248452231130631,Vocals 5 | 10,FactorGAN,3.00948,1.645375,1.7975211571428573,8.113460729156188,Vocals 6 | 20,FactorGAN,3.156565,1.6046550000000002,1.8749713857142856,8.30691717643957,Vocals 7 | 50,FactorGAN,3.4388,1.559725,2.0191757,8.65596928418218,Vocals 8 | 10,GAN,6.17731,1.540635,6.183414271428571,2.625664373284633,Accompaniment 9 | 
20,GAN,5.991875,1.5996899999999998,6.233455328571428,2.8197917995690553,Accompaniment 10 | 50,GAN,5.7209900000000005,1.7076749999999996,5.969218571428572,2.8519183877568577,Accompaniment 11 | 10,FactorGAN,6.679995,1.7517950000000004,6.891838142857144,2.9684880821473953,Accompaniment 12 | 20,FactorGAN,6.799595,1.7324549999999999,6.9449088,2.9332867432936025,Accompaniment 13 | 50,FactorGAN,7.022645,1.7381000000000002,7.1121528,2.8895338557354373,Accompaniment -------------------------------------------------------------------------------- /analysis/source_separation/sdr.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/source_separation/sdr.pdf -------------------------------------------------------------------------------- /datasets/AudioSeparationDataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path 4 | from multiprocessing import Pool 5 | 6 | import musdb 7 | import numpy as np 8 | import soundfile 9 | import torch 10 | from torch.utils.data.dataset import Dataset 11 | 12 | import Utils 13 | 14 | def preprocess_song(item): 15 | idx, song, out_path, sample_rate, input_width, mode, fft_size, hop_size = item 16 | 17 | if not os.path.exists(os.path.join(out_path, str(idx))): 18 | os.makedirs(os.path.join(out_path, str(idx))) 19 | 20 | length = np.Inf 21 | if mode == "paired" or mode == "mix": 22 | mix_audio, _ = Utils.load(song["mix"], sr=sample_rate, mono=True) 23 | mix, _ = Utils.compute_spectrogram(np.squeeze(mix_audio, 1), fft_size, hop_size) 24 | length = min(mix.shape[1], length) 25 | 26 | if mode == "paired" or mode == "accompaniment": 27 | accompaniment_audio, _ = Utils.load(song["accompaniment"], sr=sample_rate, mono=True) 28 | accompaniment, _ = Utils.compute_spectrogram(np.squeeze(accompaniment_audio, 1), fft_size, hop_size) 29 | length = min(accompaniment.shape[1], length) 30 | 31 | if mode == "paired" or mode == "vocals": 32 | vocals_audio, _ = Utils.load(song["vocals"], sr=sample_rate, mono=True) 33 | vocals, _ = Utils.compute_spectrogram(np.squeeze(vocals_audio, 1), fft_size, hop_size) 34 | length = min(vocals.shape[1], length) 35 | 36 | sample_num = 0 37 | for start_pos in range(0, length - input_width, input_width // 2): 38 | sample = list() 39 | if mode == "paired" or mode == "mix": 40 | sample.append(mix[:, start_pos:start_pos + input_width]) 41 | 42 | if mode == "paired" or mode == "accompaniment": 43 | sample.append(accompaniment[:, start_pos:start_pos + input_width]) 44 | 45 | if mode == "paired" or mode == "vocals": 46 | sample.append(vocals[:, start_pos:start_pos + input_width]) 47 | 48 | # Write current snippet 49 | sample = np.stack(sample, axis=0) 50 | np.save(os.path.join(out_path, str(idx), str(sample_num) + ".npy"), sample) 51 | sample_num += 1 52 | 53 | class MUSDBDataset(Dataset): 54 | def __init__(self, opt, song_idx, mode): 55 | self.opt = opt 56 | self.mode = mode 57 | # Load MUSDB/convert to wav 58 | dataset = getMUSDB(opt.musdb_path)[0] 59 | 60 | self.out_path = os.path.join(opt.preprocessed_dataset_path, mode) 61 | 62 | if not os.path.exists(self.out_path): 63 | # Preprocess audio into spectrogram and write into each sample into a numpy file 64 | p = Pool(10) #multiprocessing.cpu_count()) 65 | p.map(preprocess_song, [(curr_song_idx, song, self.out_path, opt.sample_rate, opt.input_width, mode, opt.fft_size, opt.hop_size) for 
curr_song_idx, song in enumerate(dataset)]) 66 | 67 | # Select songs to use for training 68 | file_list = list() 69 | for idx in song_idx: 70 | npy_files = glob.glob(os.path.join(self.out_path, str(idx), "*.npy")) 71 | file_list.extend(npy_files) 72 | 73 | self.dataset = file_list 74 | 75 | def __getitem__(self, index): 76 | return self.npy_loader(self.dataset[index]) 77 | 78 | def __len__(self): 79 | return len(self.dataset) 80 | 81 | def npy_loader(self, path): 82 | sample = torch.from_numpy(np.load(path)) 83 | return sample 84 | 85 | def getMUSDB(database_path): 86 | mus = musdb.DB(root_dir=database_path, is_wav=False) 87 | 88 | subsets = list() 89 | 90 | for subset in ["train", "test"]: 91 | tracks = mus.load_mus_tracks(subset) 92 | samples = list() 93 | 94 | # Go through tracks 95 | for track in tracks: 96 | # Skip track if mixture is already written, assuming this track is done already 97 | track_path = track.path[:-4] 98 | mix_path = track_path + "_mix.wav" 99 | acc_path = track_path + "_accompaniment.wav" 100 | if os.path.exists(mix_path): 101 | print("WARNING: Skipping track " + mix_path + " since it exists already") 102 | 103 | # Add paths and then skip 104 | paths = {"mix" : mix_path, "accompaniment" : acc_path} 105 | paths.update({key : track_path + "_" + key + ".wav" for key in ["bass", "drums", "other", "vocals"]}) 106 | 107 | samples.append(paths) 108 | 109 | continue 110 | 111 | rate = track.rate 112 | 113 | # Go through each instrument 114 | paths = dict() 115 | stem_audio = dict() 116 | for stem in ["bass", "drums", "other", "vocals"]: 117 | path = track_path + "_" + stem + ".wav" 118 | audio = track.targets[stem].audio 119 | soundfile.write(path, audio, rate, "PCM_16") 120 | stem_audio[stem] = audio 121 | paths[stem] = path 122 | 123 | # Add other instruments to form accompaniment 124 | acc_audio = np.clip(sum([stem_audio[key] for key in stem_audio.keys() if key != "vocals"]), -1.0, 1.0) 125 | soundfile.write(acc_path, acc_audio, rate, "PCM_16") 126 | paths["accompaniment"] = acc_path 127 | 128 | # Create mixture 129 | mix_audio = track.audio 130 | soundfile.write(mix_path, mix_audio, rate, "PCM_16") 131 | paths["mix"] = mix_path 132 | 133 | diff_signal = np.abs(mix_audio - acc_audio - stem_audio["vocals"]) 134 | print("Maximum absolute deviation from source additivity constraint: " + str(np.max(diff_signal)))# Check if acc+vocals=mix 135 | print("Mean absolute deviation from source additivity constraint: " + str(np.mean(diff_signal))) 136 | 137 | samples.append(paths) 138 | 139 | subsets.append(samples) 140 | 141 | return subsets -------------------------------------------------------------------------------- /datasets/CropDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | 3 | class CropDataset(Dataset): 4 | def __init__(self, dataset, crop_func): 5 | self.dataset = dataset 6 | self.crop_func = crop_func 7 | 8 | def __getitem__(self, index): 9 | sample = self.dataset[index] 10 | return self.crop_func(sample) 11 | 12 | def __len__(self): 13 | return len(self.dataset) -------------------------------------------------------------------------------- /datasets/DoubleMNISTDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import torch 3 | import numpy as np 4 | 5 | class DoubleMNISTDataset(Dataset): 6 | def __init__(self, mnist_dataset, idx_range, deterministic=True, same_digit_prob=0.9, 
return_labels=False): 7 | if idx_range is None: 8 | idx_range = range(len(mnist_dataset.data)) 9 | self.data = mnist_dataset.data[idx_range].clone().detach().float() / 255.0 10 | self.labels = mnist_dataset.targets[idx_range].clone().detach() 11 | 12 | self.idx_range = range(len(idx_range)) 13 | 14 | self.same_digit_prob = same_digit_prob 15 | self.deterministic = deterministic 16 | 17 | self.idx_digits = list() 18 | self.idx_digits_proportion = list() 19 | for digit_num in range(10): 20 | indices = [i for i in self.idx_range if self.labels[i] == digit_num] 21 | self.idx_digits.append(indices) 22 | self.idx_digits_proportion.append(float(len(indices))) 23 | self.idx_digits_proportion /= np.sum(self.idx_digits_proportion) 24 | 25 | self.second_digits = np.full(len(self.idx_range), -1) 26 | 27 | self.return_labels = return_labels 28 | 29 | def __getitem__(self, index): 30 | first_digit, first_digit_label = self.data[index], self.labels[index] 31 | if self.deterministic and self.second_digits[index] != -1: 32 | second_digit_idx = self.second_digits[index] 33 | else: 34 | # Draw same digit with given higher probability 35 | same_digit = (np.random.rand() < self.same_digit_prob) 36 | 37 | if same_digit: 38 | # Draw from pool of same-digit numbers 39 | second_digit_label = first_digit_label 40 | else: 41 | # Draw from pool of different-digit numbers 42 | second_digit_label = np.random.choice(range(10), p=self.idx_digits_proportion) 43 | second_digit_idx = self.idx_digits[second_digit_label.item()][np.random.randint(0, len(self.idx_digits[second_digit_label]))] 44 | 45 | # Finally, save second digit choice if deterministic 46 | if self.deterministic: 47 | self.second_digits[index] = second_digit_idx 48 | 49 | sample = torch.cat([first_digit.view(-1), self.data[second_digit_idx].view(-1)]) 50 | if self.return_labels: 51 | return (sample, first_digit_label, second_digit_label) 52 | else: 53 | return sample 54 | 55 | 56 | def __len__(self): 57 | return len(self.idx_range) 58 | 59 | def get_digits(self, digit_label): 60 | return sum([self.idx_digits[d] for d in range(10) if d != digit_label], []) -------------------------------------------------------------------------------- /datasets/GeneratorInputDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import numpy as np 3 | 4 | class GeneratorInputDataset(Dataset): 5 | def __init__(self, cond_dataset, noise_distribution): 6 | assert(cond_dataset != None or noise_distribution != None) 7 | self.cond_dataset = cond_dataset 8 | self.noise = noise_distribution 9 | 10 | def __getitem__(self, index): 11 | output = list() 12 | 13 | if self.cond_dataset != None: 14 | # Sample input for generator from other dataset 15 | output.append(self.cond_dataset[np.random.randint(0, len(self.cond_dataset))]) 16 | 17 | if self.noise != None: 18 | # Sample noise from noise_distribution, if we are using noise in the conditional generator network 19 | output.append(self.noise.sample()) 20 | 21 | # Get generator input 22 | return output 23 | 24 | def __len__(self): 25 | if self.cond_dataset != None: 26 | return len(self.cond_dataset) 27 | else: 28 | return 10000 -------------------------------------------------------------------------------- /datasets/InfiniteDataSampler.py: -------------------------------------------------------------------------------- 1 | class InfiniteDataSampler(object): 2 | def __init__(self, data_loader): 3 | ''' 4 | Can be used to infinitely loop over a 
dataset 5 | :param data_loader: Data loader object for a dataset 6 | ''' 7 | self.data_loader = data_loader 8 | 9 | if not hasattr(data_loader, "__next__"): 10 | self.data_iter = iter(data_loader) 11 | else: 12 | self.data_iter = data_loader 13 | 14 | def next(self): 15 | return self.__next__() 16 | 17 | def __next__(self): 18 | try: 19 | data = next(self.data_iter) 20 | except StopIteration: 21 | # StopIteration is thrown if dataset ends 22 | # reinitialize data loader 23 | self.data_iter = iter(self.data_loader) 24 | data = next(self.data_iter) 25 | 26 | return data -------------------------------------------------------------------------------- /datasets/TransformDataSampler.py: -------------------------------------------------------------------------------- 1 | class TransformDataSampler(object): 2 | def __init__(self, data_loader, transform, transform_device): 3 | # create dataloader-iterator 4 | self.data_loader = data_loader 5 | 6 | if not hasattr(data_loader, "__next__"): 7 | self.data_iter = iter(data_loader) 8 | else: 9 | self.data_iter = data_loader 10 | 11 | self.transform = transform 12 | self.device = transform_device 13 | 14 | def next(self): 15 | return self.__next__() 16 | 17 | def __next__(self): 18 | data = next(self.data_iter) 19 | 20 | # Put transform input to proper device first 21 | if self.device != None: 22 | if isinstance(data, list): 23 | data = [item.to(self.device) for item in data] 24 | else: 25 | data = data.to(self.device) 26 | 27 | # Transform batch of samples 28 | data = self.transform(data).detach() 29 | 30 | return data -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/image2image/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from datasets.image2image.base_dataset import BaseDataset 3 | import os 4 | 5 | def get_aligned_dataset(opt, subset): 6 | opt.dataset_mode = "aligned" 7 | opt.dataroot = os.path.join("datasets", "image2image", opt.dataset) 8 | opt.phase = subset 9 | opt.direction = "AtoB" 10 | opt.no_flip = True 11 | dataset = create_dataset(opt) 12 | return dataset 13 | 14 | def find_dataset_using_name(dataset_name): 15 | # Given the option --dataset_mode [datasetname], 16 | # the file "data/datasetname_dataset.py" 17 | # will be imported. 18 | dataset_filename = "datasets.image2image." + dataset_name + "_dataset" 19 | datasetlib = importlib.import_module(dataset_filename) 20 | 21 | # In the file, the class called DatasetNameDataset() will 22 | # be instantiated. It has to be a subclass of BaseDataset, 23 | # and it is case-insensitive. 24 | dataset = None 25 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 26 | for name, cls in datasetlib.__dict__.items(): 27 | if name.lower() == target_dataset_name.lower() \ 28 | and issubclass(cls, BaseDataset): 29 | dataset = cls 30 | 31 | if dataset is None: 32 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 
% (dataset_filename, target_dataset_name)) 33 | exit(0) 34 | 35 | return dataset 36 | 37 | def get_option_setter(dataset_name): 38 | dataset_class = find_dataset_using_name(dataset_name) 39 | return dataset_class.modify_commandline_options 40 | 41 | 42 | def create_dataset(opt): 43 | dataset = find_dataset_using_name(opt.dataset_mode) 44 | instance = dataset() 45 | instance.initialize(opt) 46 | 47 | # Add fixed no. of channel information for A and B 48 | if opt.dataset == "edges2shoes": 49 | dataset.A_nc = 1 50 | dataset.B_nc = 3 51 | else: 52 | dataset.A_nc = 3 53 | dataset.B_nc = 3 54 | 55 | print("dataset [%s] was created" % (instance.name())) 56 | return instance -------------------------------------------------------------------------------- /datasets/image2image/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from datasets.image2image.base_dataset import BaseDataset 6 | from datasets.image2image.image_folder import make_dataset 7 | from PIL import Image 8 | 9 | 10 | class AlignedDataset(BaseDataset): 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def initialize(self, opt): 16 | self.opt = opt 17 | self.root = opt.dataroot 18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 19 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 20 | #assert(opt.resize_or_crop == 'resize_and_crop') 21 | 22 | def __getitem__(self, index): 23 | AB_path = self.AB_paths[index] 24 | AB = Image.open(AB_path).convert('RGB') 25 | w, h = AB.size 26 | #assert(self.opt.loadSize >= self.opt.fineSize) 27 | w2 = int(w / 2) 28 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 29 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 30 | A = transforms.ToTensor()(A) 31 | B = transforms.ToTensor()(B) 32 | 33 | if self.opt.direction == 'BtoA': 34 | input_nc = self.B_nc 35 | output_nc = self.A_nc 36 | else: 37 | input_nc = self.A_nc 38 | output_nc = self.B_nc 39 | 40 | if (not self.opt.no_flip) and random.random() < 0.5: 41 | idx = [i for i in range(A.size(2) - 1, -1, -1)] 42 | idx = torch.LongTensor(idx) 43 | A = A.index_select(2, idx) 44 | B = B.index_select(2, idx) 45 | 46 | if input_nc == 1: # RGB to gray 47 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 48 | A = tmp.unsqueeze(0) 49 | 50 | if output_nc == 1: # RGB to gray 51 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 52 | B = tmp.unsqueeze(0) 53 | 54 | #return {'A': A, 'B': B, 55 | # 'A_paths': AB_path, 'B_paths': AB_path} 56 | #print(time.time() - t) 57 | return torch.cat([A,B]) 58 | 59 | def __len__(self): 60 | return len(self.AB_paths) 61 | 62 | def name(self): 63 | return 'AlignedDataset' 64 | -------------------------------------------------------------------------------- /datasets/image2image/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | osize = [opt.loadSize, opt.loadSize] 26 | transform_list = [transforms.Resize(osize, Image.BICUBIC), transforms.ToTensor()] 27 | return transforms.Compose(transform_list) 28 | 29 | 30 | # just modify the width and height to be multiple of 4 31 | def __adjust(img): 32 | ow, oh = img.size 33 | 34 | # the size needs to be a multiple of this number, 35 | # because going through generator network may change img size 36 | # and eventually cause size mismatch error 37 | mult = 4 38 | if ow % mult == 0 and oh % mult == 0: 39 | return img 40 | w = (ow - 1) // mult 41 | w = (w + 1) * mult 42 | h = (oh - 1) // mult 43 | h = (h + 1) * mult 44 | 45 | if ow != w or oh != h: 46 | __print_size_warning(ow, oh, w, h) 47 | 48 | return img.resize((w, h), Image.BICUBIC) 49 | 50 | 51 | def __scale_width(img, target_width): 52 | ow, oh = img.size 53 | 54 | # the size needs to be a multiple of this number, 55 | # because going through generator network may change img size 56 | # and eventually cause size mismatch error 57 | mult = 4 58 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult 59 | if (ow == target_width and oh % mult == 0): 60 | return img 61 | w = target_width 62 | target_height = int(target_width * oh / ow) 63 | m = (target_height - 1) // mult 64 | h = (m + 1) * mult 65 | 66 | if target_height != h: 67 | __print_size_warning(target_width, target_height, w, h) 68 | 69 | return img.resize((w, h), Image.BICUBIC) 70 | 71 | 72 | def __print_size_warning(ow, oh, w, h): 73 | if not hasattr(__print_size_warning, 'has_printed'): 74 | print("The image size needs to be a multiple of 4. " 75 | "The loaded image size was (%d, %d), so it was adjusted to " 76 | "(%d, %d). 
This adjustment will be done to all images " 77 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 78 | __print_size_warning.has_printed = True 79 | -------------------------------------------------------------------------------- /datasets/image2image/download_image2image.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then 4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps" 5 | exit 1 6 | fi 7 | 8 | echo "Specified [$FILE]" 9 | 10 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz 11 | TAR_FILE=./$FILE.tar.gz 12 | TARGET_DIR=./$FILE/ 13 | wget -N $URL -O $TAR_FILE 14 | mkdir -p $TARGET_DIR 15 | tar -zxvf $TAR_FILE -C ./ 16 | rm $TAR_FILE -------------------------------------------------------------------------------- /datasets/image2image/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /eval/Cityscapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import namedtuple 3 | 4 | # THIS IS TAKEN FROM https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 5 | #-------------------------------------------------------------------------------- 6 | # Definitions 7 | 
#-------------------------------------------------------------------------------- 8 | 9 | # a label and all meta information 10 | Label = namedtuple( 'Label' , [ 11 | 12 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 13 | # We use them to uniquely name a class 14 | 15 | 'id' , # An integer ID that is associated with this label. 16 | # The IDs are used to represent the label in ground truth images 17 | # An ID of -1 means that this label does not have an ID and thus 18 | # is ignored when creating ground truth images (e.g. license plate). 19 | # Do not modify these IDs, since exactly these IDs are expected by the 20 | # evaluation server. 21 | 22 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 23 | # ground truth images with train IDs, using the tools provided in the 24 | # 'preparation' folder. However, make sure to validate or submit results 25 | # to our evaluation server using the regular IDs above! 26 | # For trainIds, multiple labels might have the same ID. Then, these labels 27 | # are mapped to the same class in the ground truth images. For the inverse 28 | # mapping, we use the label that is defined first in the list below. 29 | # For example, mapping all void-type classes to the same ID in training, 30 | # might make sense for some approaches. 31 | # Max value is 255! 32 | 33 | 'category' , # The name of the category that this label belongs to 34 | 35 | 'categoryId' , # The ID of this category. Used to create ground truth images 36 | # on category level. 37 | 38 | 'hasInstances', # Whether this label distinguishes between single instances or not 39 | 40 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 41 | # during evaluations or not 42 | 43 | 'color' , # The color of this label 44 | ] ) 45 | 46 | 47 | #-------------------------------------------------------------------------------- 48 | # A list of all labels 49 | #-------------------------------------------------------------------------------- 50 | 51 | # Please adapt the train IDs as appropriate for your approach. 52 | # Note that you might want to ignore labels with ID 255 during training. 53 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 54 | # Make sure to provide your results using the original IDs and not the training IDs. 55 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
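# A minimal usage sketch (an illustration added here, not part of the upstream Cityscapes helper script): once the labels list below is defined, the metadata is typically consumed through small lookup tables, e.g.
#     name2label    = { label.name  : label         for label in labels }
#     color2trainId = { label.color : label.trainId for label in labels }
# get_pixel_acc() further below relies on the color and ignoreInEval fields in the same way.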
56 | 57 | labels = [ 58 | # name id trainId category catId hasInstances ignoreInEval color 59 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 60 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 61 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 62 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 63 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 64 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 65 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 66 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 67 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 68 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 69 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 70 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 71 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 72 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 73 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 74 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 75 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 76 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 77 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 78 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 79 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 80 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 81 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 82 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 83 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 84 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 85 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 86 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 87 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 88 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 89 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 90 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 91 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 92 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 93 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 94 | ] 95 | #END OF CITYSCAPES SCRIPTS 96 | 97 | def get_L2(opt, device, G, val_data, G_noise): 98 | ''' 99 | Compute average L2 loss for a generator model on validation data 100 | :param opt: Option dict 101 | :param device: Device to use 102 | :param G: Generator model 103 | :param val_data: Validation data 104 | :param G_noise: Noise input to generator 105 | :return: Average L2 loss over the whole validation data 106 | ''' 107 | G.eval() 108 | batch_losses = list() 109 | for val_batch in val_data: 110 | # Get next batch and sample generator noise 111 | val_batch = val_batch.to(device) 112 | noise = 
G_noise.sample_n(opt.batchSize).to(device) 113 | 114 | # Get fake samples from generator 115 | fake_sample = G([val_batch[:,:3,:,:], noise]) 116 | 117 | # Evaluate outputs 118 | L2 = ((fake_sample - val_batch)[:,3:,:,:] ** 2).mean() 119 | batch_losses.append(L2.detach().cpu().numpy()) 120 | return np.mean(np.array(batch_losses)) 121 | 122 | def get_pixel_acc(opt, device, G, val_data, G_noise): 123 | ''' 124 | Get pixel-wise classification accuracy for Cityscapes dataset 125 | :param opt: Option dictionary 126 | :param device: Device to use 127 | :param G: Generator model 128 | :param val_data: Validation dataset 129 | :param G_noise: Generator input noise 130 | :return: Average pixel-wise accuracy 131 | ''' 132 | all_labels_colors = np.array([l.color for l in labels]) 133 | 134 | eval_labels = [l for l in labels if not l.ignoreInEval] 135 | eval_labels_colors = np.array([l.color for l in eval_labels]) 136 | 137 | G.eval() 138 | batch_losses = list() 139 | for val_batch in val_data: 140 | label_batch = val_batch[:,3:,:,:].detach().cpu().numpy() * 255.0 141 | 142 | val_batch = val_batch.to(device) 143 | noise = G_noise.sample_n(opt.batchSize).to(device) 144 | 145 | # Get fake samples from generator 146 | fake_sample = G([val_batch[:,:3,:,:], noise]).detach().cpu().numpy() 147 | fake_pred = fake_sample[:,3:,:,:] * 255.0 148 | 149 | # EVALUATE MODEL OUTPUTS 150 | # Assign colour of closest label to the prediction image 151 | dist_imag = list() 152 | for l in eval_labels: 153 | label_imag = np.zeros(fake_pred.shape) 154 | label_imag[:,0,:,:] = l.color[0] 155 | label_imag[:, 1, :, :] = l.color[1] 156 | label_imag[:, 2, :, :] = l.color[2] 157 | dist_imag.append(np.sqrt(np.sum(np.square(fake_pred - label_imag), axis=1))) 158 | 159 | dist_imag = np.array(dist_imag) 160 | pred_idx = np.argmin(dist_imag, axis=0) 161 | pred_sample = np.transpose(eval_labels_colors[pred_idx], [0, 3, 1, 2]).astype(np.float32) 162 | 163 | # Assign colour of closest label to label image (this is necessary due to image compression artifacts) 164 | dist_imag = list() 165 | for l in labels: 166 | label_imag = np.zeros(fake_pred.shape) 167 | label_imag[:, 0, :, :] = l.color[0] 168 | label_imag[:, 1, :, :] = l.color[1] 169 | label_imag[:, 2, :, :] = l.color[2] 170 | dist_imag.append(np.sqrt(np.sum(np.square(label_batch - label_imag), axis=1))) 171 | 172 | dist_imag = np.array(dist_imag) 173 | label_idx = np.argmin(dist_imag, axis=0) 174 | label_batch = np.transpose(all_labels_colors[label_idx], [0, 3, 1, 2]).astype(np.float32) 175 | 176 | # Create a mask to remove those pixels from evaluation with a ground truth label that is not part of the evaluation 177 | valid_mask = np.ones([pred_sample.shape[0], pred_sample.shape[2], pred_sample.shape[3]]) 178 | for l in labels: # Go through all ignored labels 179 | if l.ignoreInEval: 180 | # Create an image with just one colour in it equal to the colour of this label 181 | label_imag = np.zeros(fake_pred.shape) 182 | label_imag[:, 0, :, :] = l.color[0] 183 | label_imag[:, 1, :, :] = l.color[1] 184 | label_imag[:, 2, :, :] = l.color[2] 185 | 186 | # Set mask to 0 where the label image has the same colour as the label to ignore 187 | valid_mask[np.sum(np.abs(label_batch - label_imag), axis=1) < 1e-5] = 0.0 188 | 189 | # Compute accuracy 190 | correct_preds = np.sum(np.abs(pred_sample - label_batch), axis=1) < 1e-5 191 | correct = np.sum(np.logical_and(correct_preds, valid_mask != 0.0), axis=(1,2)) 192 | total = np.sum(valid_mask != 0.0, axis=(1,2)) 193 | 194 | 
batch_losses.append(np.mean(correct.astype(np.float32) / total.astype(np.float32))) 195 | return np.mean(np.array(batch_losses)) -------------------------------------------------------------------------------- /eval/FID.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Calculates the Frechet Inception Distance (FID) to evaluate GANs 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | When run as a stand-alone program, it compares the distribution of 7 | images that are stored as PNG/JPEG at a specified location with a 8 | distribution given by summary statistics (in pickle format). 9 | The FID is calculated by assuming that X_1 and X_2 are the activations of 10 | the pool_3 layer of the inception net for generated samples and real world 11 | samples respectively. 12 | See --help to see further details. 13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead of Tensorflow 14 | 15 | Copyright 2018 Institute of Bioinformatics, JKU Linz 16 | Licensed under the Apache License, Version 2.0 (the "License"); 17 | you may not use this file except in compliance with the License. 18 | You may obtain a copy of the License at 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | Unless required by applicable law or agreed to in writing, software 21 | distributed under the License is distributed on an "AS IS" BASIS, 22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | See the License for the specific language governing permissions and 24 | limitations under the License. 25 | ''' 26 | 27 | import glob 28 | import os 29 | 30 | import torch 31 | import numpy as np 32 | from imageio import imread 33 | from scipy import linalg 34 | from torch.utils.data import TensorDataset, DataLoader 35 | 36 | def get_activations(dataset, model, device, crop=None, verbose=False): 37 | """Calculates the activations of the pool_3 layer for all images. 38 | Params: 39 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 40 | must lie between 0 and 1. 41 | -- model : Instance of inception model 42 | -- batch_size : the images numpy array is split into batches with 43 | batch size batch_size. A reasonable batch size depends 44 | on the hardware. 45 | -- dims : Dimensionality of features returned by Inception 46 | -- cuda : If set to True, use GPU 47 | -- verbose : If set to True and parameter out_step is given, the number 48 | of calculated batches is reported. 49 | Returns: 50 | -- A numpy array of dimension (num images, dims) that contains the 51 | activations of the given tensor when feeding inception with the 52 | query tensor.
53 | """ 54 | model.eval() 55 | 56 | # Compute activations for real data 57 | act = list() 58 | for batch in dataset: 59 | if isinstance(batch, list) or isinstance(batch, tuple): 60 | assert(len(batch) == 1) 61 | batch = batch[0] 62 | 63 | batch = batch.view((batch.shape[0], 1, -1, 28)).to(device) 64 | # Crop batch samples if desired 65 | if crop != None: 66 | batch = crop(batch) 67 | # Get classifier activation 68 | act.append(model(batch, return_hidden=True).detach().cpu().numpy()) 69 | act = np.concatenate(act) 70 | act = np.reshape(act, (-1, act.shape[-1])) 71 | 72 | if verbose: 73 | print(' done') 74 | 75 | return act 76 | 77 | 78 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 79 | """Numpy implementation of the Frechet Distance. 80 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 81 | and X_2 ~ N(mu_2, C_2) is 82 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 83 | Stable version by Dougal J. Sutherland. 84 | Params: 85 | -- mu1 : Numpy array containing the activations of a layer of the 86 | inception net (like returned by the function 'get_predictions') 87 | for generated samples. 88 | -- mu2 : The sample mean over activations, precalculated on a 89 | representative data set. 90 | -- sigma1: The covariance matrix over activations for generated samples. 91 | -- sigma2: The covariance matrix over activations, precalculated on a 92 | representative data set. 93 | Returns: 94 | -- : The Frechet Distance. 95 | """ 96 | 97 | mu1 = np.atleast_1d(mu1) 98 | mu2 = np.atleast_1d(mu2) 99 | 100 | sigma1 = np.atleast_2d(sigma1) 101 | sigma2 = np.atleast_2d(sigma2) 102 | 103 | assert mu1.shape == mu2.shape, \ 104 | 'Training and test mean vectors have different lengths' 105 | assert sigma1.shape == sigma2.shape, \ 106 | 'Training and test covariances have different dimensions' 107 | 108 | diff = mu1 - mu2 109 | 110 | # Product might be almost singular 111 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 112 | if not np.isfinite(covmean).all(): 113 | msg = ('fid calculation produces singular product; ' 114 | 'adding %s to diagonal of cov estimates') % eps 115 | print(msg) 116 | offset = np.eye(sigma1.shape[0]) * eps 117 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 118 | 119 | # Numerical error might give slight imaginary component 120 | if np.iscomplexobj(covmean): 121 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 122 | m = np.max(np.abs(covmean.imag)) 123 | raise ValueError('Imaginary component {}'.format(m)) 124 | covmean = covmean.real 125 | 126 | tr_covmean = np.trace(covmean) 127 | 128 | return (diff.dot(diff) + np.trace(sigma1) + 129 | np.trace(sigma2) - 2 * tr_covmean) 130 | 131 | 132 | def calculate_activation_statistics(dataset, model, device, crop, verbose=False): 133 | """Calculation of the statistics used by the FID. 134 | Params: 135 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 136 | must lie between 0 and 1. 137 | -- model : Instance of inception model 138 | -- batch_size : The images numpy array is split into batches with 139 | batch size batch_size. A reasonable batch size 140 | depends on the hardware. 141 | -- dims : Dimensionality of features returned by Inception 142 | -- cuda : If set to True, use GPU 143 | -- verbose : If set to True and parameter out_step is given, the 144 | number of calculated batches is reported. 145 | Returns: 146 | -- mu : The mean over samples of the activations of the pool_3 layer of 147 | the inception model.
148 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 149 | the inception model. 150 | """ 151 | act = get_activations(dataset, model, device, crop, verbose) 152 | mu = np.mean(act, axis=0) 153 | sigma = np.cov(act, rowvar=False) 154 | return mu, sigma 155 | 156 | 157 | def _compute_statistics_of_path(path, model, batch_size, device, crop): 158 | if path.endswith('.npz'): 159 | f = np.load(path) 160 | m, s = f['mu'][:], f['sigma'][:] 161 | f.close() 162 | else: 163 | files = glob.glob(os.path.join(path, '*.jpg')) + glob.glob(os.path.join(path, '*.png')) 164 | 165 | imgs = np.array([np.mean(imread(str(fn)).astype(np.float32), axis=2) / 255.0 for fn in files]) 166 | imgs = torch.stack([torch.Tensor(i) for i in imgs]) 167 | 168 | dataset = TensorDataset(imgs) 169 | dataset = DataLoader(dataset, batch_size) 170 | 171 | m, s = calculate_activation_statistics(dataset, model, device, crop) 172 | 173 | return m, s 174 | 175 | def calculate_fid_given_paths(model, paths, batch_size, cuda): 176 | """Calculates the FID of two paths""" 177 | for p in paths: 178 | if not os.path.exists(p): 179 | raise RuntimeError('Invalid path: %s' % p) 180 | 181 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 182 | cuda) 183 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 184 | cuda) 185 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 186 | 187 | return fid_value 188 | 189 | def evaluate_MNIST(opt, classifier, real_dataset, generated_path, device, crop_real=None, crop_fake=None): 190 | # Check if we have the real classifier activations pre-computed and saved already 191 | if os.path.exists("MNIST_classifier_stats.npz"): 192 | # Load saved activations for real data 193 | real_stats = np.load("MNIST_classifier_stats.npz") 194 | m1, s1 = real_stats['m'], real_stats['s'] 195 | else: 196 | real_activations = get_activations(real_dataset, classifier, device, crop_real) 197 | 198 | # Compute statistics of activations and save 199 | m1 = np.mean(real_activations, axis=0) 200 | s1 = np.cov(real_activations, rowvar=False) 201 | np.savez("MNIST_classifier_stats.npz", m=m1, s=s1) 202 | 203 | m2, s2 = _compute_statistics_of_path(generated_path, classifier, opt.batchSize, device, crop_fake) 204 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 205 | 206 | return fid_value -------------------------------------------------------------------------------- /eval/LS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def lsgan_loss(model, real_input, fake_input, device): 5 | ''' 6 | Compute LS-GAN loss function 7 | :param model: LS Discriminator model 8 | :param real_input: Real input data 9 | :param fake_input: Fake input data 10 | :param device: Device to use 11 | :return: LS loss 12 | ''' 13 | real_output = model(real_input.to(device)) 14 | fake_output = model(fake_input.to(device)) 15 | 16 | a = 0.0 17 | b = 1.0 18 | 19 | return ((real_output - b)**2).mean() + ((fake_output - a)**2).mean() 20 | 21 | def train_lsgan(model, real_data_loader, fake_data_loader, device): 22 | ''' 23 | Trains LS discriminator model on real and fake data 24 | :param model: LS discriminator model 25 | :param real_data_loader: Real training data 26 | :param fake_data_loader: Fake training data 27 | :param device: Device to use 28 | ''' 29 | model.train() 30 | NUM_EPOCHS = 40 31 | LR = 1e-4 32 | 33 | optim = torch.optim.Adam(model.parameters(), lr=LR) 34 | 35 | for epoch in range(NUM_EPOCHS): 36 | 
print("Epoch " + str(epoch)) 37 | for real_batch in real_data_loader: 38 | fake_batch = next(fake_data_loader) 39 | 40 | optim.zero_grad() 41 | loss = lsgan_loss(model, real_batch, fake_batch, device) 42 | print(loss.item()) 43 | loss.backward() 44 | optim.step() 45 | 46 | print("Finished training LSGAN Disc") 47 | 48 | def lsgan_test_loss(model, real_data_loader, fake_data_loader, device): 49 | ''' 50 | Obtains LS distance for pre-trained LS discriminator on test set 51 | :param model: Pre-trained LS discriminator 52 | :param real_data_loader: Real test data 53 | :param fake_data_loader: Fake test data 54 | :param device: Device to use 55 | :return: LS distance 56 | ''' 57 | model.eval() 58 | 59 | batch_losses = list() 60 | batch_weights = list() 61 | for real_batch in real_data_loader: 62 | fake_batch = next(fake_data_loader) 63 | 64 | batch_losses.append(lsgan_loss(model, real_batch, fake_batch, device).item()) 65 | batch_weights.append(real_batch.shape[0]) 66 | 67 | total_loss = np.sum(np.array(batch_losses) * np.array(batch_weights)) / np.sum(np.array(batch_weights)) 68 | print("LS: " + str(total_loss)) 69 | return total_loss 70 | 71 | def compute_ls_metric(classifier_factory, real_train, real_test, generator_sampler, repetitions, device): 72 | ''' 73 | Computes LS metric for a generator model for evaluation 74 | :param classifier_factory: Function yielding a newly initialized (different each time!) LS classifier when called 75 | :param real_train: Real "test-train" dataset for training the LS discriminator 76 | :param real_test: Real "test-test" dataset for measuring LS distance after training LS discriminator 77 | :param generator_sampler: Generator output dataset 78 | :param repetitions: Number of times the classifier should be trained 79 | :param device: Device to use 80 | :return: List of LS distance metrics obtained for each training run (length "repetitions") 81 | ''' 82 | losses = list() 83 | for _ in range(repetitions): 84 | classifier = classifier_factory() 85 | train_lsgan(classifier, real_train, generator_sampler, device) 86 | losses.append(lsgan_test_loss(classifier, real_test, generator_sampler, device)) 87 | del classifier 88 | return losses -------------------------------------------------------------------------------- /eval/SourceSeparation.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import musdb 3 | import museval 4 | import numpy as np 5 | import torch 6 | import glob 7 | import os 8 | import json 9 | 10 | import Utils 11 | 12 | 13 | def produce_musdb_source_estimates(model_config, model, model_noise, output_path, subsets=None): 14 | ''' 15 | Predicts source estimates for MUSDB for a given model checkpoint and configuration, and evaluates them. 16 | :param model_config: Model configuration of the model to be evaluated 17 | :param model: Trained separator model used to produce the estimates 18 | :return: 19 | ''' 20 | print("Evaluating trained model on MUSDB and saving source estimate audio to " + str(output_path)) 21 | model.eval() 22 | 23 | mus = musdb.DB(root_dir=model_config.musdb_path) 24 | predict_fun = lambda track : predict(track, model_config, model, model_noise, output_path) 25 | assert(mus.test(predict_fun)) 26 | mus.run(predict_fun, estimates_dir=output_path, subsets=subsets) 27 | 28 | 29 | def predict(track, model_config, model, model_noise, results_dir=None): 30 | ''' 31 | Function in accordance with MUSDB evaluation API.
Takes MUSDB track object and computes corresponding source estimates, as well as calls the evaluation script. 32 | The trained model and its input noise distribution are passed in directly, together with the model configuration. 33 | :param track: Track object 34 | :param results_dir: Directory where SDR etc. values should be saved 35 | :return: Source estimates dictionary 36 | ''' 37 | 38 | # Get noise once, use that for all predictions to keep consistency 39 | noise = model_noise.sample() 40 | 41 | # Determine input and output shapes, if we use U-net as separator 42 | sep_input_shape = [1, 1, model_config.input_height, model_config.input_width] # [N, C, H, W] 43 | 44 | print("Testing...") 45 | 46 | mix_audio, orig_sr, mix_channels = track.audio, track.rate, track.audio.shape[1] # Audio has (n_samples, n_channels) shape 47 | separator_preds = predict_track(model_config, model, noise, mix_audio, orig_sr, sep_input_shape, sep_input_shape) 48 | 49 | # Upsample predicted source audio and convert to stereo. Make sure to resample back to the exact number of samples in the original input (with fractional orig_sr/new_sr this causes issues otherwise) 50 | pred_audio = {name : librosa.resample(separator_preds[name], model_config.sample_rate, orig_sr)[:len(mix_audio)] for name in separator_preds.keys()} 51 | 52 | if mix_channels > 1: # Convert to multichannel if mixture input was multichannel by duplicating mono estimate 53 | pred_audio = {name : np.repeat(np.expand_dims(pred_audio[name], 1), mix_channels, axis=1) for name in pred_audio.keys()} 54 | 55 | # Evaluate using museval, if we are currently evaluating MUSDB 56 | if results_dir is not None: 57 | scores = museval.eval_mus_track(track, pred_audio, output_dir=results_dir, win=15, hop=15.0) 58 | 59 | # print nicely formatted mean scores 60 | print(scores) 61 | 62 | return pred_audio 63 | 64 | 65 | def predict_track(model_config, model, model_noise, mix_audio, mix_sr, sep_input_shape, sep_output_shape): 66 | ''' 67 | Outputs source estimates for a given input mixture signal mix_audio [n_frames, n_channels] 68 | It iterates through the track, collecting segment-wise predictions to form the output. 69 | :param model_config: Model configuration dictionary 70 | :param mix_audio: [n_frames, n_channels] audio signal (numpy array). Can have higher sampling rate or channels than the model supports, will be downsampled correspondingly.
71 | :param mix_sr: Sampling rate of mix_audio 72 | :param sep_input_shape: Input shape of separator 73 | :param sep_output_shape: Output shape of separator 74 | :return: 75 | ''' 76 | # Load mixture, convert to mono and downsample it 77 | assert(len(mix_audio.shape) == 2) 78 | 79 | # Prepare mixture 80 | mix_audio = np.mean(mix_audio, axis=1) 81 | mix_audio = librosa.resample(mix_audio, mix_sr, model_config.sample_rate) 82 | 83 | mix_audio = librosa.util.fix_length(mix_audio, len(mix_audio) + model_config.fft_size // 2) 84 | 85 | # Convert to spectrogram 86 | mix_mags, mix_ph = Utils.compute_spectrogram(mix_audio, model_config.fft_size, model_config.hop_size) 87 | 88 | # Preallocate source predictions (same shape as input mixture) 89 | source_time_frames = mix_mags.shape[1] 90 | source_preds = {name : np.zeros(mix_mags.shape, np.float32) for name in ["accompaniment", "vocals"]} 91 | 92 | input_time_frames = sep_input_shape[3] 93 | output_time_frames = sep_output_shape[3] 94 | 95 | # Iterate over mixture magnitudes, fetch network prediction 96 | for source_pos in range(0, source_time_frames, output_time_frames): 97 | # If this output patch would reach over the end of the source spectrogram, set it so we predict the very end of the output, then stop 98 | if source_pos + output_time_frames > source_time_frames: 99 | source_pos = source_time_frames - output_time_frames 100 | 101 | # Prepare mixture excerpt by selecting time interval 102 | mix_part = mix_mags[:, source_pos:source_pos + input_time_frames] 103 | mix_part = np.expand_dims(np.expand_dims(mix_part, axis=0), axis=0) 104 | 105 | device = next(model.parameters()).device 106 | source_parts = model([torch.from_numpy(mix_part).to(device), model_noise.to(device)]).detach().cpu().numpy() 107 | 108 | # Save predictions 109 | source_preds["accompaniment"][:,source_pos:source_pos + output_time_frames] = source_parts[0, 1] 110 | if source_parts[0].shape[0] > 2: 111 | source_preds["vocals"][:, source_pos:source_pos + output_time_frames] = source_parts[0, 2] 112 | else: 113 | source_preds["vocals"][:, source_pos:source_pos + output_time_frames] = source_parts[0, 1] # Copy acc prediction into vocals for acc-only model 114 | 115 | # Convert predictions back to audio signal 116 | for key in source_preds.keys(): 117 | mags = Utils.denormalise_spectrogram(source_preds[key]) 118 | source_preds[key] = Utils.spectrogramToAudioFile(mags, model_config.fft_size, model_config.hop_size, phase=np.angle(mix_ph)) 119 | 120 | return source_preds 121 | 122 | def compute_mean_metrics(json_folder, compute_averages=True, metric="SDR"): 123 | ''' 124 | Computes averages or collects evaluation metrics produced from MUSDB evaluation of a separator 125 | (see "produce_musdb_source_estimates" function), namely the mean, standard deviation, median, and median absolute 126 | deviation (MAD). Function is used to produce the results in the paper. 127 | Averaging ignores NaN values arising from parts where a source is silent 128 | :param json_folder: Path to the folder in which a collection of json files was written by the MUSDB evaluation library, one for each song.
129 | This is the output of the "produce_musdb_source_estimates" function.(By default, this is model_config["estimates_path"] + test or train) 130 | :param compute_averages: Whether to compute the average over all song segments (to get final evaluation measures) or to return the full list of segments 131 | :param metric: Which metric to evaluate (either "SDR", "SIR", "SAR" or "ISR") 132 | :return: IF compute_averages is True, returns a list with length equal to the number of separated sources, with each list element a tuple of (median, MAD, mean, SD). 133 | If it is false, also returns this list, but each element is now a numpy vector containing all segment-wise performance values 134 | ''' 135 | files = glob.glob(os.path.join(json_folder, "*.json")) 136 | inst_list = None 137 | print("Found " + str(len(files)) + " JSON files to evaluate...") 138 | for path in files: 139 | #print(path) 140 | if path.__contains__("test.json"): 141 | print("Found test JSON, skipping...") 142 | continue 143 | 144 | with open(path, "r") as f: 145 | js = json.load(f) 146 | 147 | if inst_list is None: 148 | inst_list = [list() for _ in range(len(js["targets"]))] 149 | 150 | for i in range(len(js["targets"])): 151 | inst_list[i].extend([np.float(f['metrics'][metric]) for f in js["targets"][i]["frames"]]) 152 | 153 | #return np.array(sdr_acc), np.array(sdr_voc) 154 | inst_list = [np.array(perf) for perf in inst_list] 155 | 156 | if compute_averages: 157 | return [(np.nanmedian(perf), np.nanmedian(np.abs(perf - np.nanmedian(perf))), np.nanmean(perf), np.nanstd(perf)) for perf in inst_list] 158 | else: 159 | return inst_list -------------------------------------------------------------------------------- /eval/Visualisation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.utils as vutils 3 | 4 | def generate_images(generator, generator_inputs, out_path, num_images, device, transform=None): 5 | # Create subfolder if it doesn't exist 6 | if not os.path.exists(out_path): 7 | os.makedirs(out_path) 8 | 9 | # Generate images and save 10 | idx = 0 11 | while idx < num_images: 12 | gen_input = next(generator_inputs) 13 | gen_input = [item.to(device) for item in gen_input] 14 | sample_batch = generator(gen_input) 15 | if transform is not None: 16 | sample_batch = transform(sample_batch) 17 | for sample in sample_batch: 18 | save_as_image(sample, os.path.join(out_path, "gen_" + str(idx) + ".png")) 19 | idx += 1 20 | 21 | def save_as_image(tensor, path): 22 | vutils.save_image(tensor, path, normalize=True) -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/eval/__init__.py -------------------------------------------------------------------------------- /factorgan_conditional.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/factorgan_conditional.png -------------------------------------------------------------------------------- /factorgan_unconditional.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/factorgan_unconditional.png 
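A minimal usage sketch for the evaluation helpers above (an illustration only, not a file of the repository); it assumes a trained separator G, its input noise distribution G_noise, and an option object opt that carries the model configuration and an estimates_path attribute:

import os
from eval.SourceSeparation import produce_musdb_source_estimates, compute_mean_metrics

# Write museval JSON files for each test song, then aggregate the segment-wise SDR values per source
produce_musdb_source_estimates(opt, G, G_noise, opt.estimates_path, subsets="test")
for median, mad, mean, sd in compute_mean_metrics(os.path.join(opt.estimates_path, "test"), metric="SDR"):
    print("SDR median {:.2f} (MAD {:.2f}), mean {:.2f} (SD {:.2f})".format(median, mad, mean, sd))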
-------------------------------------------------------------------------------- /models/MNISTClassifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | class MNISTModel(nn.Module): 5 | def __init__(self): 6 | ''' 7 | Simple CNN to use as MNIST classifier 8 | ''' 9 | super(MNISTModel, self).__init__() 10 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 11 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 12 | self.conv2_drop = nn.Dropout2d() 13 | self.fc1 = nn.Linear(320, 50) 14 | self.fc2 = nn.Linear(50, 10) 15 | 16 | def forward(self, x, return_hidden=False): 17 | x = F.leaky_relu(F.avg_pool2d(self.conv1(x), 2)) 18 | x = F.leaky_relu(F.avg_pool2d(self.conv2(x), 2)) 19 | x = x.view(-1, 320) 20 | x = F.leaky_relu(self.fc1(x)) 21 | x = F.dropout(x, training=self.training) 22 | 23 | # If return_hidden flag is given, we return the activations from the last layer instead of the final output 24 | if return_hidden: 25 | return x 26 | 27 | x = self.fc2(x) 28 | if self.training: 29 | return F.log_softmax(x, dim=1) 30 | else: 31 | return F.softmax(x, dim=1) -------------------------------------------------------------------------------- /models/SpectralNorm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The following is taken from 3 | 4 | https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py 5 | 6 | licensed under 7 | 8 | MIT License 9 | 10 | Copyright (c) 2017 Christian Cosgrove 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 
29 | ''' 30 | 31 | import torch 32 | from torch import nn 33 | from torch.nn import Parameter 34 | 35 | def l2normalize(v, eps=1e-12): 36 | return v / (v.norm() + eps) 37 | 38 | class SpectralNorm(nn.Module): 39 | def __init__(self, module, name='weight', power_iterations=1): 40 | super(SpectralNorm, self).__init__() 41 | self.module = module 42 | self.name = name 43 | self.power_iterations = power_iterations 44 | if not self._made_params(): 45 | self._make_params() 46 | 47 | def _update_u_v(self): 48 | u = getattr(self.module, self.name + "_u") 49 | v = getattr(self.module, self.name + "_v") 50 | w = getattr(self.module, self.name + "_bar") 51 | 52 | height = w.data.shape[0] 53 | for _ in range(self.power_iterations): 54 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 55 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 56 | 57 | sigma = u.dot(w.view(height, -1).mv(v)) 58 | setattr(self.module, self.name, w / sigma.expand_as(w)) 59 | 60 | def _made_params(self): 61 | try: 62 | u = getattr(self.module, self.name + "_u") 63 | v = getattr(self.module, self.name + "_v") 64 | w = getattr(self.module, self.name + "_bar") 65 | return True 66 | except AttributeError: 67 | return False 68 | 69 | def _make_params(self): 70 | w = getattr(self.module, self.name) 71 | 72 | height = w.data.shape[0] 73 | width = w.view(height, -1).data.shape[1] 74 | 75 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 76 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 77 | u.data = l2normalize(u.data) 78 | v.data = l2normalize(v.data) 79 | w_bar = Parameter(w.data) 80 | 81 | del self.module._parameters[self.name] 82 | 83 | self.module.register_parameter(self.name + "_u", u) 84 | self.module.register_parameter(self.name + "_v", v) 85 | self.module.register_parameter(self.name + "_bar", w_bar) 86 | 87 | 88 | def forward(self, *args): 89 | self._update_u_v() 90 | return self.module.forward(*args) 91 | 92 | def set_spectral_norm(module, activate): 93 | if activate: 94 | return SpectralNorm(module) 95 | else: 96 | return module -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/models/__init__.py -------------------------------------------------------------------------------- /models/discriminators/ConvDiscriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.SpectralNorm import set_spectral_norm 6 | 7 | class ConvDiscriminator(nn.Module): 8 | def __init__(self, x_dim, y_dim, input_channels=1, filters=32, spectral_norm=True): 9 | super(ConvDiscriminator, self).__init__() 10 | 11 | num_layers = int(np.log2(min(x_dim, y_dim)) - 3) 12 | feature_width = x_dim // (2 ** (num_layers+1)) 13 | feature_height = y_dim // (2 ** (num_layers+1)) 14 | 15 | assert(np.mod(num_layers, 1) == 0) 16 | num_layers = int(num_layers) 17 | 18 | conv_layers = list() 19 | conv_layers.append(set_spectral_norm(nn.Conv2d(input_channels, filters, 4, 2, 1), spectral_norm)) 20 | conv_layers.append(nn.LeakyReLU()) 21 | for i in range(num_layers): 22 | conv_layers.append(set_spectral_norm(nn.Conv2d(filters * (2 ** i), filters * (2 ** (i + 1)), 4, 2, 1), spectral_norm)) 23 | conv_layers.append(nn.LeakyReLU()) 24 | 25 | self.fc = 
set_spectral_norm(nn.Linear(feature_width * feature_height * filters * (2 ** num_layers), 1, bias=False), spectral_norm) 26 | self.conv_layers = nn.ModuleList(conv_layers) 27 | 28 | def conv_size(self, orig_size, filter_size, padding, stride): 29 | return np.floor((orig_size + 2*padding - filter_size).astype(float) / stride.astype(float)).astype(int) + 1 30 | 31 | def forward(self, input): 32 | x = input 33 | for layer in self.conv_layers: 34 | x = layer(x) 35 | x = self.fc(x.view(x.shape[0], -1)) 36 | 37 | return torch.squeeze(x, 1) -------------------------------------------------------------------------------- /models/discriminators/FCDiscriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.SpectralNorm import set_spectral_norm 5 | 6 | class FCDiscriminator(nn.Module): 7 | def __init__(self, input_dim, spectral_norm=True, preprocess_func=None): 8 | ''' 9 | Fully connected discriminator network 10 | :param input_dim: Number of inputs 11 | :param spectral_norm: Whether to use spectral normalisation 12 | :param preprocess_func: Function that preprocesses the input before feeding to the network 13 | ''' 14 | super(FCDiscriminator, self).__init__() 15 | self.preprocess_func = preprocess_func 16 | 17 | self.fc1 = set_spectral_norm(nn.Linear(input_dim, 128), spectral_norm) 18 | self.fc2 = set_spectral_norm(nn.Linear(128, 128), spectral_norm) 19 | self.fc3 = set_spectral_norm(nn.Linear(128, 1), spectral_norm) 20 | self.activation = nn.LeakyReLU() 21 | self.output_activation = nn.Sigmoid() 22 | 23 | def forward(self, input): 24 | if self.preprocess_func != None: 25 | input = self.preprocess_func(input) 26 | 27 | x = self.fc1(input) 28 | x = self.activation(x) 29 | x = self.fc2(x) 30 | x = self.activation(x) 31 | x = self.fc3(x) 32 | return torch.squeeze(x, 1) -------------------------------------------------------------------------------- /models/discriminators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/models/discriminators/__init__.py -------------------------------------------------------------------------------- /models/generators/ConvGenerator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | 4 | class ConvGenerator(nn.Module): 5 | def __init__(self, opt, ngf, out_width, nc=3): 6 | super(ConvGenerator, self).__init__() 7 | self.conv_size = 4 8 | 9 | num_layers = np.log2(out_width) - 3 10 | assert(np.mod(num_layers, 1) == 0) 11 | num_layers = int(num_layers) 12 | 13 | conv_list = list() 14 | 15 | # Compute channel numbers in each layer 16 | channel_list = [ngf * (2 ** (i-1)) for i in range(num_layers+1, 0, -1)] 17 | 18 | # First layer 19 | conv_list.append(nn.ConvTranspose2d(opt.nz, channel_list[0], 4, 1, 0, bias=True)) 20 | conv_list.append(nn.ReLU(True)) 21 | 22 | for i in range(0, num_layers): 23 | conv_list.append(nn.ConvTranspose2d(channel_list[i], channel_list[i+1], self.conv_size, 2, 1, bias=True)) 24 | conv_list.append(nn.ReLU(True)) 25 | 26 | # Last layer 27 | conv_list.append(nn.ConvTranspose2d(ngf, nc, self.conv_size, 2, 1, bias=True)) 28 | if opt.generator_activation == "sigmoid": 29 | conv_list.append(nn.Sigmoid()) 30 | elif opt.generator_activation == "relu": 31 | conv_list.append(nn.ReLU()) 32 | else: 33 | print("WARNING: Using ConvGenerator without 
output activation") 34 | 35 | self.main = nn.Sequential(*conv_list) 36 | 37 | def forward(self, input): 38 | assert (len(input) == 1) 39 | noise = input[0].unsqueeze(2).unsqueeze(2) 40 | output = self.main(noise) 41 | return output 42 | -------------------------------------------------------------------------------- /models/generators/FCGenerator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FCGenerator(nn.Module): 5 | def __init__(self, opt, output_dim): 6 | super(FCGenerator, self).__init__() 7 | self.opt = opt 8 | self.fc1 = nn.Linear(opt.nz, 128) 9 | self.fc2 = nn.Linear(128, 128) 10 | self.fc3 = nn.Linear(128, output_dim) 11 | self.activation = nn.ReLU() 12 | 13 | if opt.generator_activation == "sigmoid": 14 | self.output_activation = nn.Sigmoid() 15 | elif opt.generator_activation == "relu": 16 | self.output_activation = nn.ReLU() 17 | else: 18 | print("Using generator without output activation") 19 | self.output_activation = None 20 | 21 | def forward(self, input): 22 | assert(len(input) == 1) 23 | 24 | x = self.fc1(input[0]) 25 | x = self.activation(x) 26 | x = self.fc2(x) 27 | x = self.activation(x) 28 | x = self.fc3(x) 29 | if self.output_activation != None: 30 | x = self.output_activation(x) 31 | return x -------------------------------------------------------------------------------- /models/generators/Unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import Utils 6 | 7 | class Unet(nn.Module): 8 | def __init__(self, opt, ngf, nc_in=3, nc_out=3): 9 | 10 | super(Unet, self).__init__() 11 | self.inc = inconv(nc_in, ngf) 12 | self.down1 = down(ngf, ngf*2) 13 | self.down2 = down(ngf*2, ngf*4) 14 | self.down3 = down(ngf*4, ngf*8) 15 | self.down4 = down(ngf*8, ngf*8) 16 | 17 | fmap_shape = (opt.input_width//16)*(opt.input_height//16) 18 | noise_channels = 2 19 | self.fc_noise = nn.Linear(opt.nz, noise_channels*fmap_shape) 20 | 21 | self.up1 = up(ngf*16 + noise_channels, ngf*4) 22 | self.up2 = up(ngf*8, ngf*2) 23 | self.up3 = up(ngf*4, ngf) 24 | self.up4 = up(ngf*2, ngf) 25 | self.outc = outconv(ngf, nc_out) 26 | 27 | if opt.generator_activation == "sigmoid": 28 | self.output_activation = nn.Sigmoid() 29 | elif opt.generator_activation == "relu": 30 | self.output_activation = nn.ReLU() 31 | else: 32 | print("WARNING: Using Unet without output activation") 33 | self.output_activation = None 34 | 35 | if hasattr(opt, "generator_mask"): 36 | assert(opt.generator_mask == 0 or nc_out == 1) 37 | assert(opt.generator_mask == 0 or opt.generator_activation == "sigmoid") 38 | self.mask = opt.generator_mask 39 | else: 40 | self.mask = False 41 | 42 | def forward(self, input): 43 | cond = input[0] 44 | noise = input[1] 45 | 46 | # Downward 47 | x1 = self.inc(cond) 48 | x2 = self.down1(x1) 49 | x3 = self.down2(x2) 50 | x4 = self.down3(x3) 51 | x5 = self.down4(x4) 52 | 53 | # FC to append noise channels to feature map 54 | noise_fmap_shape = [x5.shape[0], -1, x5.shape[2], x5.shape[3]] 55 | noise_fmap = self.fc_noise(noise).view(noise_fmap_shape) 56 | x5 = torch.cat([x5, noise_fmap], dim=1) 57 | 58 | # Upward 59 | x = self.up1(x5, x4) 60 | x = self.up2(x, x3) 61 | x = self.up3(x, x2) 62 | x = self.up4(x, x1) 63 | x = self.outc(x) 64 | 65 | if self.output_activation != None: 66 | x = self.output_activation(x) 67 | 68 | if self.mask: 69 | # Denormalise audio mixture input 70 | 
original_cond = Utils.denormalise_spectrogram_torch(cond) 71 | 72 | # Compute acc and voice based on mask and unnormalised mixture spectrogram, then renormalise 73 | acc = Utils.normalise_spectrogram_torch(original_cond * x) 74 | voice = Utils.normalise_spectrogram_torch(original_cond * (1.0 - x)) 75 | 76 | x = torch.cat([acc, voice], dim=1) 77 | 78 | return torch.cat([cond, x], dim=1) 79 | 80 | class double_conv(nn.Module): 81 | '''(conv => BN => ReLU) * 2''' 82 | 83 | def __init__(self, in_ch, out_ch): 84 | super(double_conv, self).__init__() 85 | self.conv = nn.Sequential( 86 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 87 | nn.BatchNorm2d(out_ch), 88 | nn.ReLU(inplace=True), 89 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 90 | nn.BatchNorm2d(out_ch), 91 | nn.ReLU(inplace=True) 92 | ) 93 | 94 | def forward(self, x): 95 | x = self.conv(x) 96 | return x 97 | 98 | class inconv(nn.Module): 99 | def __init__(self, in_ch, out_ch): 100 | super(inconv, self).__init__() 101 | self.conv = double_conv(in_ch, out_ch) 102 | 103 | def forward(self, x): 104 | x = self.conv(x) 105 | return x 106 | 107 | 108 | class down(nn.Module): 109 | def __init__(self, in_ch, out_ch): 110 | super(down, self).__init__() 111 | self.mpconv = nn.Sequential( 112 | nn.MaxPool2d(2), 113 | double_conv(in_ch, out_ch) 114 | ) 115 | 116 | def forward(self, x): 117 | x = self.mpconv(x) 118 | return x 119 | 120 | 121 | class up(nn.Module): 122 | def __init__(self, in_ch, out_ch, bilinear=True): 123 | super(up, self).__init__() 124 | 125 | # would be a nice idea if the upsampling could be learned too, 126 | # but my machine do not have enough memory to handle all those weights 127 | if bilinear: 128 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 129 | else: 130 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 131 | 132 | self.conv = double_conv(in_ch, out_ch) 133 | 134 | def forward(self, x1, x2): 135 | x1 = self.up(x1) 136 | 137 | # input is CHW 138 | diffY = x2.size()[2] - x1.size()[2] 139 | diffX = x2.size()[3] - x1.size()[3] 140 | 141 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 142 | diffY // 2, diffY - diffY // 2)) 143 | 144 | # for padding issues, see 145 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 146 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 147 | 148 | x = torch.cat([x2, x1], dim=1) 149 | x = self.conv(x) 150 | return x 151 | 152 | 153 | class outconv(nn.Module): 154 | def __init__(self, in_ch, out_ch): 155 | super(outconv, self).__init__() 156 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 157 | 158 | def forward(self, x): 159 | x = self.conv(x) 160 | return x -------------------------------------------------------------------------------- /models/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/models/generators/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio 2 | seaborn 3 | pandas 4 | matplotlib 5 | soundfile 6 | scipy 7 | museval 8 | musdb 9 | librosa 10 | numpy 11 | tensorboard 12 | torchvision==0.5.0 13 | torch==1.4.0 14 | tqdm -------------------------------------------------------------------------------- /training/AdversarialTraining.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | from training.DiscriminatorTraining import * 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | def train(cfg, G, G_input, G_opt, D_marginal_setups, D_dep_pairs, device, logdir): 9 | print("START TRAINING! Writing logs to " + logdir) 10 | writer = SummaryWriter(logdir) 11 | 12 | # Create expression for overall discriminator output given a complete fake sample 13 | # Marginal disc sum output 14 | marginal_sum = lambda y: sum(setup.D(setup.crop_fake(y)) for setup in D_marginal_setups) 15 | if cfg.factorGAN == 1: 16 | # Dep disc sum output 17 | dep_sum = lambda y: sum(disc_pair.get_comb_disc()(y) for disc_pair in D_dep_pairs) 18 | jointD = lambda y: marginal_sum(y) + dep_sum(y) 19 | else: 20 | jointD = marginal_sum 21 | 22 | # START NORMAL TRAINING 23 | for epoch in range(cfg.epochs): 24 | for i in tqdm(range(cfg.epoch_iter)): 25 | total_it = epoch * cfg.epoch_iter + i 26 | # If dependency GAN active, train marginal discriminators here from both extra data and main data 27 | for j in range(cfg.disc_iter): # No. of disc iterations 28 | # Train marginal discriminators 29 | for D_setup in D_marginal_setups: 30 | errD, correct, _, _ = get_marginal_disc_output(D_setup, device, backward=True, zero_gradients=True) 31 | if j==cfg.disc_iter-1: writer.add_scalar(D_setup.name + "_acc", correct, total_it) 32 | D_setup.optim.step() 33 | 34 | if cfg.factorGAN == 1: 35 | # Additionally train dependency discriminators 36 | for D_dep_pair in D_dep_pairs: 37 | # Train REAL dependency discriminator 38 | if cfg.use_real_dep_disc == 1: 39 | # Training step for real dep disc 40 | errD, correct, _, _ = get_dep_disc_output(D_dep_pair.real_disc, device, backward=True,zero_gradients=True) 41 | D_dep_pair.real_disc.optim.step() 42 | 43 | # Logging for last discriminator update 44 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.real_disc.name + "_acc", correct, total_it) 45 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.real_disc.name + "_errD", errD, total_it) 46 | 47 | # Train FAKE dependency discriminator. Use combined output of real and fake dependency discs for regularisation purposes => Fake dep. disc needs to ensure its gradients stay close to the real dep. ones 48 | errD, correct, _, _ = get_dep_disc_output(D_dep_pair.fake_disc, device, backward=True, zero_gradients=True) 49 | D_dep_pair.fake_disc.optim.step() 50 | 51 | # Logging for last discriminator update 52 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.fake_disc.name + "_acc", correct, total_it) 53 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.fake_disc.name + "_errD", errD, total_it) 54 | 55 | ############################ 56 | # (2) Update G network: 57 | ########################### 58 | 59 | G.zero_grad() 60 | 61 | # Get fake samples from generator 62 | gen_input = G_input.__next__() 63 | gen_input = [item.to(device) for item in gen_input] 64 | fake_sample = G(gen_input) 65 | 66 | #TODO Produce log outputs for all tasks here? 
67 | 68 | # Get setup information from first marginal discriminator (which is the only one in normal GAN training) 69 | real_label = D_marginal_setups[0].real_label 70 | criterion = D_marginal_setups[0].criterion 71 | 72 | label = torch.full((cfg.batchSize,), real_label, device=device) # fake labels are real for generator cost 73 | disc_output = jointD(fake_sample) 74 | writer.add_scalar("probG", torch.mean(torch.nn.Sigmoid()(disc_output)), total_it) 75 | if cfg.objective == "JSD": 76 | errG = criterion()(disc_output, label) # Normal JSD 77 | elif cfg.objective == "KL": 78 | errG = -torch.mean(disc_output) # KL[q|p] 79 | else: 80 | raise NotImplementedError 81 | 82 | writer.add_scalar("errG", errG, total_it) 83 | 84 | errG.backward() 85 | G_opt.step() 86 | 87 | print("EPOCH FINISHED") 88 | 89 | model_output_path = os.path.join(cfg.experiment_path, "G_" + str(epoch)) 90 | print("Saving generator at " + model_output_path) 91 | torch.save(G.state_dict(), model_output_path) 92 | 93 | # FINISHED 94 | print("TRAINING FINISHED") -------------------------------------------------------------------------------- /training/DiscriminatorTraining.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | 4 | class DiscriminatorSetup(object): 5 | def __init__(self, name, D, optim, real_data, fake_data, 6 | crop_real=lambda x:x, crop_fake=lambda x:x, 7 | criterion=torch.nn.BCEWithLogitsLoss, real_label=1, fake_label=0): 8 | ''' 9 | Discriminator model including optimiser, input sources, how inputs are cropped, and how the discriminator is trained 10 | :param name: Name of discriminator 11 | :param D: Discriminator model 12 | :param optim: Optimiser for D 13 | :param real_data: Real data iterator 14 | :param fake_data: Fake data iterator 15 | :param crop_real: Function that crops real inputs 16 | :param crop_fake: Function that crops fake inputs 17 | :param criterion: Discriminator training loss to use 18 | :param real_label: Real label for training loss 19 | :param fake_label: Fake label for training loss 20 | ''' 21 | self.name = name 22 | self.D = D 23 | self.optim = optim 24 | self.real_data = real_data 25 | self.fake_data = fake_data 26 | 27 | self.real_label = real_label 28 | self.fake_label = fake_label 29 | 30 | self.crop_real = crop_real 31 | self.crop_fake = crop_fake 32 | 33 | self.criterion = criterion 34 | 35 | class DependencyDiscriminatorSetup(object): 36 | def __init__(self, name, D, optim, data, shuffle_batch_func, crop_func=lambda x:x, 37 | criterion=torch.nn.BCEWithLogitsLoss, real_label=1, fake_label=0): 38 | ''' 39 | Dependency discriminator model including name, optimiser, input data source, how batches are shuffled, and training criterion 40 | :param name: Name of discriminator 41 | :param D: Discriminator model 42 | :param optim: Optimiser 43 | :param data: Joint data source ("real" data) with dependencies 44 | :param shuffle_batch_func: Function that shuffles a batch taken from data to get an independent version 45 | :param crop_func: Function that crops the input before feeding it to the discriminator 46 | :param criterion: Training loss 47 | :param real_label: Real label 48 | :param fake_label: Fake label (shuffled/independent batch) 49 | ''' 50 | self.name = name 51 | self.D = D 52 | self.optim = optim 53 | self.data = data 54 | self.real_label = real_label 55 | self.fake_label = fake_label 56 | self.crop_func = crop_func 57 | self.shuffle_batch_func = shuffle_batch_func 58 | self.criterion = criterion 59 | 60 | class DependencyDiscriminatorPair(object): 61 | def __init__(self, real_disc, fake_disc:DependencyDiscriminatorSetup): 62 | ''' 63 | Binds two dependency discriminators (for p and q) together, to compute a combined output 64 | :param real_disc: p-dependency discriminator. Can be None in case none is used 65 | :param fake_disc: q-dependency discriminator 66 | ''' 67 | assert(real_disc == None or isinstance(real_disc, DependencyDiscriminatorSetup)) 68 | self.real_disc = real_disc 69 | self.fake_disc = fake_disc 70 | 71 | def get_comb_disc(self): 72 | ''' 73 | Computes combined output d_P(x) - d_Q(x) that represents the dependency-based part of the generator loss 74 | :return: Function mapping a batch x to d_P(x) - d_Q(x), or to -d_Q(x) if no p-dependency discriminator is used 75 | ''' 76 | if self.real_disc != None: 77 | return lambda x : self.real_disc.D(self.real_disc.crop_func(x)) - self.fake_disc.D(self.fake_disc.crop_func(x)) 78 | else: 79 | return lambda x : - self.fake_disc.D(self.fake_disc.crop_func(x)) 80 | 81 | def get_dep_disc_output(disc_setup:DependencyDiscriminatorSetup, device, backward=False, zero_gradients=False): 82 | ''' 83 | Computes output of a dependency discriminator (and optionally gradients) for a new input batch, by using the batch 84 | itself as real, and a shuffled version as fake input 85 | :param disc_setup: Dependency discriminator object 86 | :param device: Device to use 87 | :param backward: Whether to compute gradients 88 | :param zero_gradients: Whether to zero gradients at the beginning 89 | :return: see get_disc_output_batch 90 | ''' 91 | real_sample = disc_setup.data.__next__() 92 | real_sample = disc_setup.crop_func(real_sample) 93 | 94 | fake_sample = disc_setup.shuffle_batch_func(real_sample) # Get fake data by simply shuffling current batch 95 | 96 | return get_disc_output_batch(disc_setup.D, real_sample, fake_sample, disc_setup.real_label, disc_setup.fake_label, disc_setup.criterion, device, backward, zero_gradients) 97 | 98 | def get_marginal_disc_output(disc_setup:DiscriminatorSetup, device, backward=False, zero_gradients=False): 99 | ''' 100 | Computes output of a marginal discriminator 101 | :param disc_setup: Marginal discriminator object 102 | :param device: Device to use 103 | :param backward: Whether to compute gradients 104 | :param zero_gradients: Whether to zero gradients at the beginning 105 | :return: see get_disc_output_batch 106 | ''' 107 | real_sample = disc_setup.real_data.__next__().to(device) 108 | real_sample = disc_setup.crop_real(real_sample) 109 | 110 | fake_sample = disc_setup.fake_data.__next__().to(device) 111 | fake_sample = disc_setup.crop_fake(fake_sample) 112 | 113 | return get_disc_output_batch(disc_setup.D, real_sample, fake_sample, disc_setup.real_label, disc_setup.fake_label, disc_setup.criterion, device, backward, zero_gradients) 114 | 115 | def get_disc_output_batch(D, real_sample, fake_sample, real_label, fake_label, criterion, device, backward, zero_gradients): 116 | ''' 117 | Compute loss, output and optionally gradients for a discriminator model with a given training loss 118 | :param D: Discriminator model 119 | :param real_sample: Batch of real samples 120 | :param fake_sample: Batch of fake samples 121 | :param real_label: Target label for real batch 122 | :param fake_label: Target label for fake batch 123 | :param criterion: Training loss 124 | :param device: Device to use 125 | :param backward: Whether to compute gradients 126 | :param zero_gradients: Whether to zero gradients at the beginning 127 | :return: Average of real and fake training loss, discriminator accuracy, discriminator outputs for real and fake 
128 | ''' 129 | # Never backpropagate through disc input in this function 130 | # Transfer inputs to correct device 131 | real_sample = real_sample.detach().to(device) 132 | fake_sample = fake_sample.detach().to(device) 133 | 134 | if zero_gradients: 135 | D.zero_grad() 136 | 137 | # Get real sample output 138 | real_batch_size = real_sample.size()[0] 139 | 140 | label = torch.full((real_batch_size,), real_label, device=device) 141 | real_output = D(real_sample) 142 | 143 | errD_real = criterion()(real_output, label) 144 | if backward: 145 | errD_real.backward() 146 | 147 | # Get fake sample output 148 | fake_batch_size = fake_sample.size()[0] 149 | label = torch.full((fake_batch_size,), fake_label, device=device) 150 | fake_output = D(fake_sample) 151 | 152 | errD_fake = criterion()(fake_output, label) 153 | if backward: 154 | errD_fake.backward() 155 | 156 | # Accuracy 157 | correct = 0.5 * (real_output > 0.0).sum().item() / real_batch_size + 0.5 * (fake_output < 0.0).sum().item() / fake_batch_size 158 | return 0.5*errD_real + 0.5*errD_fake, correct, real_output, fake_output -------------------------------------------------------------------------------- /training/MNIST.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import os 6 | from torchvision import datasets, transforms 7 | 8 | from models.MNISTClassifier import MNISTModel 9 | 10 | 11 | def train(model, device, train_loader, optimizer, epoch): 12 | ''' 13 | Train MNIST model for one epoch 14 | :param model: MNIST model 15 | :param device: Device to use 16 | :param train_loader: Training dataset loader 17 | :param optimizer: Optimiser to use 18 | :param epoch: Current epoch index 19 | ''' 20 | model.train() 21 | for batch_idx, (data, target) in enumerate(train_loader): 22 | data, target = data.to(device), target.to(device) 23 | optimizer.zero_grad() 24 | output = model(data) 25 | loss = F.nll_loss(output, target) 26 | loss.backward() 27 | optimizer.step() 28 | if batch_idx % 10 == 0: 29 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 30 | epoch, batch_idx * len(data), len(train_loader.dataset), 31 | 100. * batch_idx / len(train_loader), loss.item())) 32 | 33 | def test(model, device, test_loader): 34 | ''' 35 | Test MNIST classifier, prints out results into standard output 36 | :param model: Classifier model 37 | :param device: Device to use 38 | :param test_loader: Test dataset loader 39 | ''' 40 | model.eval() 41 | test_loss = 0 42 | correct = 0 43 | with torch.no_grad(): 44 | for data, target in test_loader: 45 | data, target = data.to(device), target.to(device) 46 | output = model(data) 47 | test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss 48 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 49 | correct += pred.eq(target.view_as(pred)).sum().item() 50 | 51 | test_loss /= len(test_loader.dataset) 52 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 53 | test_loss, correct, len(test_loader.dataset), 54 | 100. 
* correct / len(test_loader.dataset))) 55 | 56 | def main(opt): 57 | use_cuda = opt.cuda and torch.cuda.is_available() 58 | 59 | torch.manual_seed(opt.seed) 60 | 61 | device = torch.device("cuda" if use_cuda else "cpu") 62 | 63 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 64 | train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True, 65 | transform=transforms.Compose([transforms.ToTensor()])), 66 | batch_size=opt.batchSize, shuffle=True, **kwargs) 67 | test_loader = torch.utils.data.DataLoader( 68 | datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor()])), 69 | batch_size=opt.batchSize, shuffle=True, **kwargs) 70 | 71 | 72 | model = MNISTModel().to(device) 73 | optimizer = optim.Adam(model.parameters(), lr=1e-4) 74 | 75 | # If a trained classifier was already saved before, just load it instead of training again 76 | MODEL_NAME = "mnist_classifier_model" 77 | if os.path.exists(MODEL_NAME): 78 | print("Found pre-trained MNIST classifier, loading from " + MODEL_NAME) 79 | model.load_state_dict(torch.load(MODEL_NAME)) 80 | return model 81 | 82 | # Train model for a certain number of epochs 83 | NUM_EPOCHS = 4 84 | print("TRAINING MNIST CLASSIFIER") 85 | for epoch in range(NUM_EPOCHS): 86 | train(model, device, train_loader, optimizer, epoch) 87 | test(model, device, test_loader) 88 | torch.save(model.state_dict(), MODEL_NAME) 89 | return model 90 | 91 | if __name__ == '__main__': 92 | from training.TrainingOptions import get_parser 93 | main(get_parser().parse_args()) -------------------------------------------------------------------------------- /training/TrainingOptions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--experiment_name', type=str, default=str(np.random.randint(0, 100000)), help='experiment name') 7 | parser.add_argument('--out_path', type=str, default="out", help="Output path") 8 | 9 | parser.add_argument('--eval', action='store_true', help='Perform evaluation instead of training') 10 | parser.add_argument('--eval_model', type=str, default='G', help='Name of generator checkpoint file to load for evaluation') 11 | 12 | parser.add_argument('--epochs', type=int, default=40, help='number of epochs to train for') 13 | parser.add_argument('--epoch_iter', type=int, default=5000, help="Number of generator updates per epoch") 14 | 15 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: 1e-4)') 16 | parser.add_argument('--L2', type=float, default=0.0, help='L2 regularisation for discriminators (except the joint p dependency discriminator)') 17 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam (default: 0.5)') 18 | 19 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 20 | parser.add_argument('--seed', type=int, default=1337, help='manual seed') 21 | 22 | # Generator settings 23 | parser.add_argument('--nz', type=int, default=50, help='size of the latent z vector') 24 | parser.add_argument('--factorGAN', type=int, default=0, help="Set to 1 to activate FactorGAN instead of a normal GAN") 25 | 26 | parser.add_argument('--disc_iter', type=int, default=2, help="Number of discriminator iterations per generator update") 27 | parser.add_argument('--objective', type=str, default="JSD", help="JSD or KL as generator objective") 28 | 29 | # Dependency settings 30 | parser.add_argument('--lipschitz_q', type=int, default=1, help="Spectral norm regularisation for fake dependency discriminators") 31 | parser.add_argument('--lipschitz_p', type=int, default=1, help="Spectral norm regularisation for real dependency discriminators") 32 | parser.add_argument('--use_real_dep_disc', type=int, default=1, help="1 to use the p dependency discriminator on real data as usual, 0 to omit it and set its output to zero, assuming the real data dimensions are independent") 33 | 34 | # Data loading settings 35 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=1) 36 | parser.add_argument('--batchSize', type=int, default=25, help='input batch size') 37 | 38 | return parser -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/training/__init__.py --------------------------------------------------------------------------------
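The following sketches are illustrative usage notes for the modules above; they are not files of the repository. First, a minimal shape check for the Unet generator in models/generators/Unet.py: the `opt` namespace here is a hypothetical stand-in for the fields the experiment scripts normally provide, and the spatial size must be divisible by 16 so the latent noise can be reshaped onto the bottleneck feature map.

```python
# Hedged sketch: shape check for models/generators/Unet.py with a hypothetical opt namespace.
import argparse
import torch

from models.generators.Unet import Unet

opt = argparse.Namespace(input_width=64, input_height=64, nz=50,
                         generator_activation="sigmoid")  # no generator_mask attribute -> plain image output

G = Unet(opt, ngf=16, nc_in=3, nc_out=3)
cond = torch.rand(4, 3, 64, 64)     # conditioning input, e.g. a batch of input images
noise = torch.rand(4, 50) * 2 - 1   # latent noise, mapped to two extra channels at the bottleneck
out = G([cond, noise])
print(out.shape)  # torch.Size([4, 6, 64, 64]): conditioning channels concatenated with the generated ones
```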
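The dependency discriminators in training/DiscriminatorTraining.py are trained by contrasting a joint batch (real) against a within-batch shuffled version of itself (fake), which approximates the product of the marginals. Below is a minimal sketch on synthetic 4-dimensional data; the toy data, the shuffling function and the small discriminator network are assumptions for illustration only, and the spectral-norm regularisation applied in the experiment scripts is omitted.

```python
# Hedged sketch: toy p-/q-dependency discriminators on synthetic 4-D data.
import torch
import torch.nn as nn

from training.DiscriminatorTraining import (DependencyDiscriminatorSetup,
                                            DependencyDiscriminatorPair,
                                            get_dep_disc_output)

class ToyDepDisc(nn.Module):
    '''Maps a (batch, 4) sample to (batch,) logits, as expected by get_disc_output_batch.'''
    def __init__(self):
        super(ToyDepDisc, self).__init__()
        self.net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))
    def forward(self, x):
        return self.net(x).squeeze(1)

def joint_batches(batch_size=25):
    while True:  # the two 2-D halves share a latent factor, so they are dependent
        z = torch.randn(batch_size, 2)
        yield torch.cat([z, z + 0.1 * torch.randn(batch_size, 2)], dim=1)

def independent_batches(batch_size=25):
    while True:  # stand-in for generator samples fed to the q-dependency discriminator
        yield torch.randn(batch_size, 4)

def shuffle_halves(batch):
    # Permute the second half within the batch, breaking the dependency between the halves
    perm = torch.randperm(batch.shape[0])
    return torch.cat([batch[:, :2], batch[perm, 2:]], dim=1)

device = torch.device("cpu")
D_p, D_q = ToyDepDisc(), ToyDepDisc()
p_setup = DependencyDiscriminatorSetup("DP_toy", D_p, torch.optim.Adam(D_p.parameters(), lr=1e-4),
                                       joint_batches(), shuffle_halves)
q_setup = DependencyDiscriminatorSetup("DQ_toy", D_q, torch.optim.Adam(D_q.parameters(), lr=1e-4),
                                       independent_batches(), shuffle_halves)
pair = DependencyDiscriminatorPair(p_setup, q_setup)

for _ in range(200):  # train the p-dependency discriminator: joint batch = real, shuffled batch = fake
    errD, acc, _, _ = get_dep_disc_output(p_setup, device, backward=True, zero_gradients=True)
    p_setup.optim.step()

# Combined dependency term d_P(x) - d_Q(x), as added to the generator loss in AdversarialTraining.train
dep_term = pair.get_comb_disc()(next(joint_batches()))
print(dep_term.shape)  # torch.Size([25])
```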
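How the training pieces fit together: a minimal wiring of get_parser, FCGenerator, one marginal DiscriminatorSetup and AdversarialTraining.train as a plain GAN (factorGAN=0) on 2-D toy data. The toy data samplers, the small discriminator, and the manually set generator_activation and experiment_path fields (which the per-task experiment scripts normally provide) are assumptions for illustration, not the repository's own experiment setup.

```python
# Hedged sketch: toy end-to-end run of training/AdversarialTraining.py as a plain GAN.
import os
import torch
import torch.nn as nn

from models.generators.FCGenerator import FCGenerator
from training.AdversarialTraining import train
from training.DiscriminatorTraining import DiscriminatorSetup
from training.TrainingOptions import get_parser

opt = get_parser().parse_args([])              # default hyper-parameters
opt.generator_activation = "none"              # assumed: the experiment scripts normally add this flag
opt.factorGAN = 0                              # plain GAN: no dependency discriminators
opt.epochs, opt.epoch_iter, opt.disc_iter = 1, 50, 1
opt.experiment_path = os.path.join(opt.out_path, "toy_experiment")  # normally set by the scripts' set_paths()
os.makedirs(opt.experiment_path, exist_ok=True)

device = torch.device("cpu")
G = FCGenerator(opt, output_dim=2).to(device)
G_opt = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

class ToyDisc(nn.Module):
    '''(batch, 2) -> (batch,) logits, matching the label shapes used in DiscriminatorTraining.'''
    def __init__(self):
        super(ToyDisc, self).__init__()
        self.net = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))
    def forward(self, x):
        return self.net(x).squeeze(1)

D = ToyDisc().to(device)
D_opt = torch.optim.Adam(D.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

def real_batches():
    while True:  # "real" data: a shifted 2-D Gaussian
        yield torch.randn(opt.batchSize, 2) * 0.5 + torch.tensor([2.0, -1.0])

def noise_batches():
    while True:  # generator inputs: a list with a single noise tensor, as FCGenerator expects
        yield [torch.randn(opt.batchSize, opt.nz)]

def fake_batches():
    while True:  # detached generator samples for the discriminator updates
        with torch.no_grad():
            yield G([torch.randn(opt.batchSize, opt.nz)])

D_setup = DiscriminatorSetup("D_toy", D, D_opt, real_batches(), fake_batches())
train(opt, G, noise_batches(), G_opt, [D_setup], [], device,
      logdir=os.path.join(opt.experiment_path, "logs"))
```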
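Finally, the MNIST classifier in training/MNIST.py can be obtained programmatically; main(opt) only reads opt.cuda, opt.seed and opt.batchSize, so the defaults from training/TrainingOptions.py suffice. A minimal sketch (MNIST is downloaded to ../data relative to the working directory on first use):

```python
# Hedged sketch: train or load the MNIST classifier with the default options.
from training import MNIST
from training.TrainingOptions import get_parser

opt = get_parser().parse_args([])   # append "--cuda" to the argument list when a GPU is available
model = MNIST.main(opt)             # trains for 4 epochs, or loads "mnist_classifier_model" if it exists
```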