├── AudioSeparation.py
├── Image2Image.py
├── Image2ImageGrid.py
├── ImagePairs.py
├── ImagePairsGrid.py
├── LICENSE
├── PairedMNIST.py
├── README.md
├── Utils.py
├── analysis
│   ├── __init__.py
│   ├── imagepairs
│   │   ├── ImagePairs Results.csv
│   │   ├── ImagePairs_cityscapes_1000_joint_GAN_big_gens.png
│   │   ├── ImagePairs_cityscapes_1000_joint_GAN_gens.png
│   │   ├── ImagePairs_cityscapes_1000_joint_factorGAN_gens.png
│   │   ├── ImagePairs_cityscapes_100_joint_GAN_big_gens.png
│   │   ├── ImagePairs_cityscapes_100_joint_GAN_gens.png
│   │   ├── ImagePairs_cityscapes_100_joint_factorGAN_gens.png
│   │   ├── ImagePairs_cityscapes_All_joint_GAN_big_gens.png
│   │   ├── ImagePairs_cityscapes_All_joint_GAN_gens.png
│   │   ├── ImagePairs_cityscapes_All_joint_factorGAN_gens.png
│   │   ├── ImagePairs_edges2shoes_1000_joint_GAN_big_gens.png
│   │   ├── ImagePairs_edges2shoes_1000_joint_GAN_gens.png
│   │   ├── ImagePairs_edges2shoes_1000_joint_factorGAN_gens.png
│   │   ├── ImagePairs_edges2shoes_100_joint_GAN_big_gens.png
│   │   ├── ImagePairs_edges2shoes_100_joint_GAN_gens.png
│   │   ├── ImagePairs_edges2shoes_100_joint_GAN_gens_8.png
│   │   ├── ImagePairs_edges2shoes_100_joint_factorGAN_gens.png
│   │   ├── ImagePairs_edges2shoes_100_joint_factorGAN_gens_8.png
│   │   ├── ImagePairs_edges2shoes_All_joint_GAN_big_gens.png
│   │   ├── ImagePairs_edges2shoes_All_joint_GAN_gens.png
│   │   ├── ImagePairs_edges2shoes_All_joint_factorGAN_gens.png
│   │   ├── PlotExamples.py
│   │   ├── PlotLS.py
│   │   ├── __init__.py
│   │   └── imagepairs_LS.pdf
│   ├── img2img_cityscapes
│   │   ├── Image2Image_Cityscapes.csv
│   │   ├── PlotExamples.py
│   │   ├── PlotL2Acc.py
│   │   ├── __init__.py
│   │   ├── cityscapes.pdf
│   │   ├── cityscapes_1000_joint_GAN_gens.png
│   │   ├── cityscapes_1000_joint_factorGAN_gens.png
│   │   ├── cityscapes_100_joint_GAN_gens.png
│   │   ├── cityscapes_100_joint_GAN_gens_small.png
│   │   ├── cityscapes_100_joint_factorGAN_gens.png
│   │   ├── cityscapes_100_joint_factorGAN_gens_small.png
│   │   ├── cityscapes_all_joint_GAN_gens.png
│   │   └── cityscapes_all_joint_factorGAN_gens.png
│   ├── mnist
│   │   ├── PairedMNIST Results Diff.csv
│   │   ├── PairedMNIST Results FID.csv
│   │   ├── PlotDep.py
│   │   ├── PlotExamples.py
│   │   ├── PlotFID.py
│   │   ├── __init__.py
│   │   ├── mnist_dep.pdf
│   │   ├── mnist_examples.pdf
│   │   └── mnist_fid.pdf
│   └── source_separation
│       ├── PlotSDR.py
│       ├── __init__.py
│       ├── sdr.csv
│       └── sdr.pdf
├── datasets
│   ├── AudioSeparationDataset.py
│   ├── CropDataset.py
│   ├── DoubleMNISTDataset.py
│   ├── GeneratorInputDataset.py
│   ├── InfiniteDataSampler.py
│   ├── TransformDataSampler.py
│   ├── __init__.py
│   └── image2image
│       ├── __init__.py
│       ├── aligned_dataset.py
│       ├── base_dataset.py
│       ├── download_image2image.sh
│       └── image_folder.py
├── eval
│   ├── Cityscapes.py
│   ├── FID.py
│   ├── LS.py
│   ├── SourceSeparation.py
│   ├── Visualisation.py
│   └── __init__.py
├── factorgan_conditional.png
├── factorgan_unconditional.png
├── models
│   ├── MNISTClassifier.py
│   ├── SpectralNorm.py
│   ├── __init__.py
│   ├── discriminators
│   │   ├── ConvDiscriminator.py
│   │   ├── FCDiscriminator.py
│   │   └── __init__.py
│   └── generators
│       ├── ConvGenerator.py
│       ├── FCGenerator.py
│       ├── Unet.py
│       └── __init__.py
├── requirements.txt
└── training
    ├── AdversarialTraining.py
    ├── DiscriminatorTraining.py
    ├── MNIST.py
    ├── TrainingOptions.py
    └── __init__.py
/AudioSeparation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import os
4 | from torch.utils.data import DataLoader
5 |
6 | import training.TrainingOptions
7 | import training.AdversarialTraining
8 | import Utils
9 | from datasets.GeneratorInputDataset import GeneratorInputDataset
10 | from datasets.InfiniteDataSampler import InfiniteDataSampler
11 | from datasets.TransformDataSampler import TransformDataSampler
12 | from datasets.AudioSeparationDataset import MUSDBDataset
13 | from eval.SourceSeparation import produce_musdb_source_estimates
14 | from models.generators.Unet import Unet
15 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair
16 | from models.discriminators.ConvDiscriminator import ConvDiscriminator
17 |
18 | def set_paths(opt):
19 | # Set up paths and create folders
20 | opt.experiment_path = os.path.join(opt.out_path, "AudioSeparation", opt.experiment_name)
21 | opt.gen_path = os.path.join(opt.experiment_path, "gen")
22 | opt.log_path = os.path.join(opt.experiment_path, "logs")
23 | opt.estimates_path = os.path.join(opt.experiment_path, "source_estimates")
24 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path, opt.estimates_path])
25 |
26 | def train(opt):
27 | Utils.set_seeds(opt)
28 | device = Utils.get_device(opt.cuda)
29 | set_paths(opt)
30 |
31 | if opt.num_joint_songs > 100:
32 | print("ERROR: Cannot train with " + str(opt.num_joint_songs) + " samples, dataset has only size of 100")
33 | return
34 |
35 | # Partition into paired and unpaired songs
36 | idx = [i for i in range(100)]
37 | random.shuffle(idx)
38 |
39 | # Joint samples
40 | dataset_train = MUSDBDataset(opt, idx[:opt.num_joint_songs], "paired")
41 | train_joint = InfiniteDataSampler(DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))
42 |
43 | if opt.factorGAN == 1:
44 | # For marginals, take full dataset
45 | mix_dataset = MUSDBDataset(opt, idx, "mix")
46 |
47 | acc_dataset = MUSDBDataset(opt, idx, "accompaniment")
48 | acc_loader = InfiniteDataSampler(DataLoader(acc_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))
49 |
50 | vocal_dataset = MUSDBDataset(opt, idx, "vocals")
51 | vocal_loader = InfiniteDataSampler(DataLoader(vocal_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))
52 | else: # For normal GAN, take only few joint songs
53 | mix_dataset = MUSDBDataset(opt, idx[:opt.num_joint_songs], "mix")
54 |
55 | # SETUP GENERATOR MODEL
56 | G = Unet(opt, opt.generator_channels, 1, 1).to(device) # 1 input channel (mixture), 1 output channel (mask)
57 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
58 | G_opt = Utils.create_optim(G.parameters(), opt)
59 |
60 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
61 | G_input_data = DataLoader(GeneratorInputDataset(mix_dataset, G_noise), num_workers=int(opt.workers),
62 | batch_size=opt.batchSize, shuffle=True, drop_last=True)
63 | G_inputs = InfiniteDataSampler(G_input_data)
64 | G_filled_outputs = TransformDataSampler(InfiniteDataSampler(G_inputs), G, device)
65 |
66 | # SETUP DISCRIMINATOR(S)
67 | crop_mix = lambda x: x[:, 1:, :, :] # Only keep sources, not mixture for dep discs
68 | if opt.factorGAN == 1:
69 | # Setup marginal disc networks
70 | D_voc = ConvDiscriminator(opt.input_height, opt.input_width, 1, opt.disc_channels).to(device)
71 | D_acc = ConvDiscriminator(opt.input_height, opt.input_width, 1, opt.disc_channels).to(device)
72 |
73 | D_acc_setup = DiscriminatorSetup("D_acc", D_acc, Utils.create_optim(D_acc.parameters(), opt), acc_loader,
74 | G_filled_outputs, crop_fake=lambda x : x[:,1:2,:,:])
75 |
76 | D_voc_setup = DiscriminatorSetup("D_voc", D_voc, Utils.create_optim(D_voc.parameters(), opt), vocal_loader,
77 | G_filled_outputs, crop_fake=lambda x : x[:,2:3,:,:])
78 | # Marginal discriminator
79 | D_setups = [D_acc_setup, D_voc_setup]
80 |
81 | # If our dep discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
82 | if opt.use_real_dep_disc == 1:
83 | DP = ConvDiscriminator(opt.input_height, opt.input_width, 2, opt.disc_channels, spectral_norm=(opt.lipschitz_p == 1)).to(device)
84 | else:
85 | DP = lambda x : 0
86 |
87 | DQ = ConvDiscriminator(opt.input_height, opt.input_width, 2, opt.disc_channels).to(device)
88 |
89 | # Dependency discriminators
90 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, 1) # Randomly mixes different sources together (e.g. accompaniment from one song with vocals from another)
91 |
92 | if opt.use_real_dep_disc:
93 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt),
94 | train_joint, shuffle_batch_func, crop_func=crop_mix)
95 | else:
96 | DP_setup = None
97 |
98 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ, Utils.create_optim(DQ.parameters(), opt),
99 | G_filled_outputs, shuffle_batch_func, crop_func=crop_mix)
100 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
101 | else:
102 | D = ConvDiscriminator(opt.input_height, opt.input_width, 2, opt.disc_channels).to(device)
103 |
104 | D_setup = DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt),
105 | train_joint, G_filled_outputs, crop_real=crop_mix, crop_fake=crop_mix)
106 | D_setups = [D_setup]
107 | D_dep_setups = []
108 |
109 | # RUN TRAINING
110 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path)
111 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))
112 |
113 | def eval(opt):
114 | Utils.set_seeds(opt)
115 | device = Utils.get_device(opt.cuda)
116 | set_paths(opt)
117 |
118 | # GENERATOR
119 | # SETUP GENERATOR MODEL
120 | G = Unet(opt, opt.generator_channels, 1, 1).to(device) # 1 input channel (mixture), 1 output channel (mask)
121 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
122 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
123 | G.eval()
124 |
125 | #EVALUATE BY PRODUCING SOURCE ESTIMATE AUDIO AND SDR METRICS
126 | produce_musdb_source_estimates(opt, G, G_noise, opt.estimates_path, subsets="test")
127 |
128 |
129 | def get_opt():
130 | # COLLECT ALL CMD ARGUMENTS
131 | parser = training.TrainingOptions.get_parser()
132 |
133 | parser.add_argument('--musdb_path', type=str, help="Path to MUSDB dataset")
134 | parser.add_argument('--preprocessed_dataset_path', type=str, help="Path to where the preprocessed dataset should be saved")
135 |
136 | parser.add_argument('--num_joint_songs', type=int, default=100,
137 | help="Number of songs from which joint observations are available for training normal gan/dependency discriminators")
138 | parser.add_argument('--hop_size', type=int, default=256,
139 | help="Hop size of FFT")
140 | parser.add_argument('--fft_size', type=int, default=512,
141 | help="Size of FFT")
142 | parser.add_argument('--sample_rate', type=int, default=22050,
143 | help="Resample input audio to this sample rate")
144 | parser.add_argument('--generator_channels', type=int, default=32,
145 | help="Number of intial feature channels used in G. 32 was used in the paper")
146 | parser.add_argument('--disc_channels', type=int, default=32,
147 | help="Number of intial feature channels used in each discriminator")
148 |
149 | opt = parser.parse_args()
150 | print(opt)
151 |
152 | opt.input_height = opt.fft_size // 2 # No. of freq bins for model
153 | opt.input_width = opt.input_height // 2 # No of time frames
154 |
155 | print("Activating generator mask and sigmoid non-linearity for mask")
156 | opt.generator_mask = 1 # Use a mask for Unet output
157 | opt.generator_activation = "sigmoid" # Use sigmoid output for mask
158 |
159 | return opt
160 |
161 | if __name__ == "__main__":
162 | opt = get_opt()
163 |
164 | if not opt.eval:
165 | train(opt)
166 | eval(opt)
--------------------------------------------------------------------------------
/Image2Image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.tensorboard import SummaryWriter
3 |
4 | import os
5 | from torch.utils.data import DataLoader, Subset
6 |
7 | import training.TrainingOptions
8 | import training.AdversarialTraining
9 | import Utils
10 | from datasets.GeneratorInputDataset import GeneratorInputDataset
11 | from datasets.InfiniteDataSampler import InfiniteDataSampler
12 | from datasets.TransformDataSampler import TransformDataSampler
13 | from datasets.image2image import get_aligned_dataset
14 | from eval.Cityscapes import get_L2, get_pixel_acc
15 | from eval.Visualisation import generate_images
16 | from models.generators.Unet import Unet
17 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair
18 | from models.discriminators.ConvDiscriminator import ConvDiscriminator
19 | from datasets.CropDataset import CropDataset
20 |
21 | def set_paths(opt):
22 | # Set up paths and create folders
23 | opt.experiment_path = os.path.join(opt.out_path, "Image2Image_" + str(opt.dataset), opt.experiment_name)
24 | opt.gen_path = os.path.join(opt.experiment_path, "gen")
25 | opt.log_path = os.path.join(opt.experiment_path, "logs")
26 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path])
27 |
28 | def train(opt):
29 | Utils.set_seeds(opt)
30 | device = Utils.get_device(opt.cuda)
31 | set_paths(opt)
32 |
33 | # DATA
34 | dataset = get_aligned_dataset(opt, "train")
35 | nc = dataset.A_nc + dataset.B_nc
36 |
37 | if opt.num_joint_samples > len(dataset):
38 | print("WARNING: Cannot train with " + str(opt.num_joint_samples) + " samples, dataset has only size of " + str(len(dataset))+ ". Using full dataset!")
39 | opt.num_joint_samples = len(dataset)
40 |
41 | # Joint samples
42 | dataset_train = Subset(dataset, range(opt.num_joint_samples))
43 | train_joint = InfiniteDataSampler(
44 | DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))
45 |
46 | if opt.factorGAN == 1:
47 | # For marginals, take full dataset and crop
48 | a_dataset = CropDataset(dataset, lambda x : x[0:dataset.A_nc, :, :])
49 | b_dataset = CropDataset(dataset, lambda x : x[dataset.A_nc:, :, :])
50 | train_b = InfiniteDataSampler(DataLoader(b_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))
51 |
52 | generator_input_data = a_dataset
53 | else:
54 | generator_input_data = CropDataset(dataset_train, lambda x : x[0:dataset.A_nc, :, :])
55 |
56 | # SETUP GENERATOR MODEL
57 | G = Unet(opt, opt.generator_channels, dataset.A_nc, dataset.B_nc).to(device)
58 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
59 | G_opt = Utils.create_optim(G.parameters(), opt)
60 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
61 | G_input_data = DataLoader(GeneratorInputDataset(generator_input_data, G_noise), num_workers=int(opt.workers),
62 | batch_size=opt.batchSize, shuffle=True, drop_last=True)
63 | G_inputs = InfiniteDataSampler(G_input_data)
64 | G_filled_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device)
65 |
66 | # SETUP DISCRIMINATOR(S)
67 | if opt.factorGAN == 1:
68 | # Setup disc networks
69 | D2 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.B_nc).to(device)
70 | # If our dep discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
71 | if opt.use_real_dep_disc == 1:
72 | DP = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, spectral_norm=(opt.lipschitz_p == 1)).to(device)
73 | else:
74 | DP = lambda x : 0
75 |
76 | DQ = ConvDiscriminator(opt.loadSize, opt.loadSize, nc).to(device)
77 |
78 | # Prepare discriminators for training method
79 | # Marginal discriminator
80 | D_setups = [DiscriminatorSetup("D2", D2, Utils.create_optim(D2.parameters(), opt),
81 | train_b, G_filled_outputs, crop_fake=lambda x: x[:, dataset.A_nc:, :, :])]
82 | # Dependency discriminators
83 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, [dataset.A_nc])
84 | if opt.use_real_dep_disc:
85 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt),
86 | train_joint, shuffle_batch_func)
87 | else:
88 | DP_setup = None
89 |
90 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ, Utils.create_optim(DQ.parameters(), opt),
91 | G_filled_outputs, shuffle_batch_func)
92 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
93 | else:
94 | D = ConvDiscriminator(opt.loadSize, opt.loadSize, nc).to(device)
95 | print(sum(p.numel() for p in D.parameters()))
96 |
97 | D_setup = DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt),
98 | train_joint, G_filled_outputs)
99 | D_setups = [D_setup]
100 | D_dep_setups = []
101 |
102 | # RUN TRAINING
103 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path)
104 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))
105 |
106 | def eval(opt):
107 | Utils.set_seeds(opt)
108 | device = Utils.get_device(opt.cuda)
109 | set_paths(opt)
110 |
111 | # DATASET
112 | dataset = get_aligned_dataset(opt, "val")
113 | input_dataset = CropDataset(dataset, lambda x: x[0:dataset.A_nc, :, :])
114 |
115 | # GENERATOR
116 | G = Unet(opt, opt.generator_channels, dataset.A_nc, dataset.B_nc).to(device)
117 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
118 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
119 | G.eval()
120 |
121 | # EVALUATE: Generate some images using test set and noise as conditional input
122 | G_input_data = DataLoader(GeneratorInputDataset(input_dataset, G_noise), num_workers=int(opt.workers),
123 | batch_size=opt.batchSize, shuffle=False)
124 | G_inputs = InfiniteDataSampler(G_input_data)
125 |
126 | generate_images(G, G_inputs, opt.gen_path, 100, device, lambda x : Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc))
127 |
128 | # EVALUATE for Cityscapes
129 | if opt.dataset == "cityscapes":
130 | writer = SummaryWriter(opt.log_path)
131 | val_input_data = DataLoader(dataset, num_workers=int(opt.workers),batch_size=opt.batchSize)
132 |
133 | pixel_error = get_pixel_acc(opt, device, G, val_input_data, G_noise)
134 | print("VALIDATION PERFORMANCE Pixel: " + str(pixel_error))
135 | writer.add_scalar("val_pix", pixel_error)
136 |
137 | L2_error = get_L2(opt, device, G, val_input_data, G_noise)
138 | print("VALIDATION PERFORMANCE L2: " + str(L2_error))
139 | writer.add_scalar("val_L2", L2_error)
140 |
141 |
142 | def get_opt():
143 | # COLLECT ALL CMD ARGUMENTS
144 | parser = training.TrainingOptions.get_parser()
145 |
146 | parser.add_argument('--dataset', type=str, default="cityscapes",
147 | help="Dataset to train on - can be cityscapes or edges2shoes (but other img2img datasets can be integrated easily")
148 | parser.add_argument('--num_joint_samples', type=int, default=1000,
149 | help="Number of joint observations available for training normal gan/dependency discriminators")
150 | parser.add_argument('--loadSize', type=int, default=128,
151 | help="Dimensions (no. of pixels) the dataset images are resampled to")
152 | parser.add_argument('--generator_channels', type=int, default=32,
153 | help="Number of intial feature channels used in G. 64 was used in the paper")
154 |
155 | opt = parser.parse_args()
156 | print(opt)
157 |
158 | # Set generator to sigmoid output
159 | opt.generator_activation = "sigmoid"
160 |
161 | # Square images => loadSize determines width and height
162 | opt.input_width = opt.loadSize
163 | opt.input_height = opt.loadSize
164 |
165 | return opt
166 |
167 | if __name__ == "__main__":
168 | opt = get_opt()
169 |
170 | if not opt.eval:
171 | train(opt)
172 | eval(opt)
--------------------------------------------------------------------------------
/Image2ImageGrid.py:
--------------------------------------------------------------------------------
1 | # This script will train image segmentation models in different dataset configurations as in the paper
2 | import Image2Image
3 |
4 | opt = Image2Image.get_opt()
5 |
6 | for dataset_name in ["cityscapes", "edges2shoes"]: # Iterate over datasets, more datasets could be added here like maps
7 | opt.dataset = dataset_name
8 | for num_joint_samples in [100, 1000, 10000]: # Try different numbers of paired samples
9 | # Apply settings
10 | print(str(num_joint_samples) + " joint samples")
11 | opt.num_joint_samples = num_joint_samples
12 |
13 | print("Training GAN")
14 | opt.experiment_name = str(num_joint_samples) + "_joint_GAN"
15 | opt.factorGAN = 0
16 | Image2Image.train(opt)
17 |
18 | print("Training factorGAN")
19 | opt.experiment_name = str(num_joint_samples) + "_joint_factorGAN"
20 | opt.factorGAN = 1
21 | Image2Image.train(opt)
--------------------------------------------------------------------------------
/ImagePairs.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import torch
3 | import os
4 | from torch.utils.data import DataLoader, Subset
5 |
6 | import training.TrainingOptions
7 | import training.AdversarialTraining
8 | import Utils
9 | from datasets.GeneratorInputDataset import GeneratorInputDataset
10 | from datasets.InfiniteDataSampler import InfiniteDataSampler
11 | from datasets.TransformDataSampler import TransformDataSampler
12 | from datasets.image2image import get_aligned_dataset
13 | from eval import LS
14 | from eval.Visualisation import generate_images
15 | from models.generators.ConvGenerator import ConvGenerator
16 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair
17 | from models.discriminators.ConvDiscriminator import ConvDiscriminator
18 | from datasets.CropDataset import CropDataset
19 |
20 | def set_paths(opt):
21 | # Set up paths and create folders
22 | opt.experiment_path = os.path.join(opt.out_path, "ImagePairs", opt.dataset, opt.experiment_name)
23 | opt.gen_path = os.path.join(opt.experiment_path, "gen")
24 | opt.log_path = os.path.join(opt.experiment_path, "logs")
25 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path])
26 |
27 | def train(opt):
28 | Utils.set_seeds(opt)
29 | device = Utils.get_device(opt.cuda)
30 | set_paths(opt)
31 |
32 | # DATA
33 | dataset = get_aligned_dataset(opt, "train")
34 | nc = dataset.A_nc + dataset.B_nc
35 |
36 | # Warning if desired number of joint samples is larger than dataset, in that case, use whole dataset as paired
37 | if opt.num_joint_samples > len(dataset):
38 | print("WARNING: Cannot train with " + str(opt.num_joint_samples) + " samples, dataset has only size of " + str(len(dataset))+ ". Using full dataset!")
39 | opt.num_joint_samples = len(dataset)
40 |
41 | # Joint samples
42 | dataset_train = Subset(dataset, range(opt.num_joint_samples))
43 | train_joint = InfiniteDataSampler(
44 | DataLoader(dataset_train, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True))
45 |
46 | if opt.factorGAN == 1:
47 | # For marginals, take full dataset and crop
48 | train_a = InfiniteDataSampler(DataLoader(CropDataset(dataset, lambda x : x[0:dataset.A_nc, :, :]),
49 | num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))
50 | train_b = InfiniteDataSampler(DataLoader(CropDataset(dataset, lambda x : x[dataset.A_nc:, :, :]),
51 | num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))
52 |
53 | # SETUP GENERATOR MODEL
54 | G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device)
55 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
56 | G_opt = Utils.create_optim(G.parameters(), opt)
57 |
58 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
59 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers),
60 | batch_size=opt.batchSize, shuffle=True)
61 | G_inputs = InfiniteDataSampler(G_input_data)
62 | G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device)
63 |
64 | # SETUP DISCRIMINATOR(S)
65 | if opt.factorGAN == 1:
66 | # Setup disc networks
67 | D1 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.A_nc, opt.disc_channels).to(device)
68 | D2 = ConvDiscriminator(opt.loadSize, opt.loadSize, dataset.B_nc, opt.disc_channels).to(device)
69 | # If our dep discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
70 | if opt.use_real_dep_disc == 1:
71 | DP = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels, spectral_norm=(opt.lipschitz_p == 1)).to(device)
72 | else:
73 | DP = lambda x : 0
74 |
75 | DQ = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels).to(device)
76 | print(sum(p.numel() for p in D1.parameters()))
77 |
78 | # Prepare discriminators for training method
79 | # Marginal discriminators
80 | D1_setup = DiscriminatorSetup("D1", D1, Utils.create_optim(D1.parameters(), opt),
81 | train_a, G_outputs, crop_fake=lambda x : x[:, 0:dataset.A_nc, :, :])
82 | D2_setup = DiscriminatorSetup("D2", D2, Utils.create_optim(D2.parameters(), opt),
83 | train_b, G_outputs, crop_fake=lambda x : x[:, dataset.A_nc:, :, :])
84 | D_setups = [D1_setup, D2_setup]
85 |
86 | # Dependency discriminators
87 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, [dataset.A_nc])
88 | if opt.use_real_dep_disc:
89 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt),
90 | train_joint, shuffle_batch_func)
91 | else:
92 | DP_setup = None
93 |
94 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ,Utils.create_optim(DQ.parameters(), opt),
95 | G_outputs, shuffle_batch_func)
96 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
97 | else:
98 | D = ConvDiscriminator(opt.loadSize, opt.loadSize, nc, opt.disc_channels).to(device)
99 | print(sum(p.numel() for p in D.parameters()))
100 | D_setups = [DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), train_joint, G_outputs)]
101 | D_dep_setups = []
102 |
103 | # RUN TRAINING
104 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path)
105 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))
106 |
107 | def eval(opt):
108 | device = Utils.get_device(opt.cuda)
109 | set_paths(opt)
110 |
111 | # Get test dataset
112 | dataset = get_aligned_dataset(opt, "val")
113 | nc = dataset.A_nc + dataset.B_nc
114 |
115 | # SETUP GENERATOR MODEL
116 | G = ConvGenerator(opt, opt.generator_channels, opt.loadSize, nc).to(device)
117 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
118 |
119 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
120 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers),
121 | batch_size=opt.batchSize, shuffle=True)
122 | G_inputs = InfiniteDataSampler(G_input_data)
123 | G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device)
124 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
125 | G.eval()
126 |
127 | # EVALUATE
128 | # GENERATE EXAMPLES
129 | generate_images(G, G_inputs, opt.gen_path, 1000, device, lambda x: Utils.create_image_pair(x, dataset.A_nc, dataset.B_nc))
130 |
131 | # COMPUTE LS DISTANCE
132 | # Partition into test train and test test
133 | test_train_samples = int(0.8 * float(len(dataset)))
134 | test_test_samples = len(dataset) - test_train_samples
135 | print("VALIDATION SAMPLES: " + str(test_train_samples))
136 | print("TEST SAMPLES: " + str(test_test_samples))
137 | real_test_train_loader = DataLoader(Subset(dataset, range(test_train_samples)), num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True, drop_last=True)
138 | real_test_test_loader = DataLoader(Subset(dataset, range(test_train_samples, len(dataset))), num_workers=int(opt.workers), batch_size=opt.batchSize)
139 |
140 | # Initialise classifier
141 | classifier_factory = lambda : ConvDiscriminator(opt.loadSize, opt.loadSize, nc, filters=opt.ls_channels, spectral_norm=False).to(device)
142 | # Compute metric
143 | losses = LS.compute_ls_metric(classifier_factory, real_test_train_loader, real_test_test_loader, G_outputs, opt.ls_runs, device)
144 |
145 | # WRITE RESULTS INTO CSV FOR LATER ANALYSIS
146 | file_existed = os.path.exists(os.path.join(opt.experiment_path, "LS.csv"))
147 | with open(os.path.join(opt.experiment_path, "LS.csv"), "a") as csv_file:
148 | writer = csv.writer(csv_file)
149 | model = "factorGAN" if opt.factorGAN else "gan"
150 | if not file_existed:
151 | writer.writerow(["LS", "Model", "Samples", "Dataset", "Samples_Validation","Samples_Test"])
152 | for val in losses:
153 | writer.writerow([val, model, opt.num_joint_samples, opt.dataset, test_train_samples, test_test_samples])
154 |
155 | def get_opt():
156 | # COLLECT ALL CMD ARGUMENTS
157 | parser = training.TrainingOptions.get_parser()
158 |
159 | parser.add_argument('--dataset', type=str, default="edges2shoes",
160 | help="Dataset to train on - can be cityscapes or edges2shoes (but other img2img datasets can be integrated easily")
161 | parser.add_argument('--num_joint_samples', type=int, default=1000,
162 | help="Number of joint observations available for training normal gan/dependency discriminators")
163 | parser.add_argument('--loadSize', type=int, default=64,
164 | help="Dimensions (no. of pixels) the dataset images are resampled to")
165 | parser.add_argument('--generator_channels', type=int, default=64,
166 | help="Number of intial feature channels used in G. 64 was used in the paper")
167 | parser.add_argument('--disc_channels', type=int, default=32,
168 | help="Number of intial feature channels used in each discriminator")
169 |
170 |
171 | # LS distance eval settings
172 | parser.add_argument('--ls_runs', type=int, default=10,
173 | help="Number of LS Discriminator training runs for evaluation")
174 | parser.add_argument('--ls_channels', type=int, default=16,
175 | help="Number of initial feature channels used for LS discriminator. 16 in the paper")
176 |
177 |
178 | opt = parser.parse_args()
179 | print(opt)
180 |
181 | # Set generator to sigmoid output
182 | opt.generator_activation = "sigmoid"
183 |
184 | return opt
185 |
186 | if __name__ == "__main__":
187 | opt = get_opt()
188 |
189 | if not opt.eval:
190 | train(opt)
191 | eval(opt)
--------------------------------------------------------------------------------
/ImagePairsGrid.py:
--------------------------------------------------------------------------------
1 | # Run ImagePairs experiment for multiple datasets and number of paired samples, using GAN and FactorGAN
2 |
3 | import ImagePairs
4 | opt = ImagePairs.get_opt()
5 |
6 | for dataset_name in ["cityscapes", "edges2shoes"]:
7 | opt.dataset = dataset_name
8 | for num_joint_samples in [100, 1000, 10000]:
9 |
10 | # Apply settings
11 | print(str(num_joint_samples) + " joint samples")
12 | opt.num_joint_samples = num_joint_samples
13 |
14 | print("Training GAN")
15 | opt.experiment_name = str(num_joint_samples) + "_joint_GAN"
16 | opt.factorGAN = 0
17 | ImagePairs.train(opt)
18 |
19 | print("Training factorGAN")
20 | opt.experiment_name = str(num_joint_samples) + "_joint_factorGAN"
21 | opt.factorGAN = 1
22 | ImagePairs.train(opt)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Daniel Stoller
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PairedMNIST.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.tensorboard import SummaryWriter
3 | from torchvision import datasets
4 | import os
5 | from torch.utils.data import DataLoader
6 | import training.TrainingOptions
7 | import training.AdversarialTraining
8 | import Utils
9 | from datasets.GeneratorInputDataset import GeneratorInputDataset
10 | from datasets.InfiniteDataSampler import InfiniteDataSampler
11 | from datasets.TransformDataSampler import TransformDataSampler
12 | from datasets.DoubleMNISTDataset import DoubleMNISTDataset
13 | from eval import Visualisation, FID
14 | from training import MNIST
15 | from training.DiscriminatorTraining import DiscriminatorSetup, DependencyDiscriminatorSetup, DependencyDiscriminatorPair
16 | from models.discriminators.FCDiscriminator import FCDiscriminator
17 | from models.generators.FCGenerator import FCGenerator
18 | from datasets.CropDataset import CropDataset
19 | import numpy as np
20 |
21 | def set_paths(opt):
22 | # PATHS
23 | opt.experiment_path = os.path.join(opt.out_path, "PairedMNIST", opt.experiment_name)
24 | opt.gen_path = os.path.join(opt.experiment_path, "gen")
25 | opt.log_path = os.path.join(opt.experiment_path, "logs")
26 | Utils.make_dirs([opt.experiment_path, opt.gen_path, opt.log_path])
27 |
28 | def predict_digits_batch(classifier, two_digit_input):
29 | '''
30 | Takes MNIST classifier and paired-MNIST sample and gives digit label probabilities for both
31 | :param classifier: MNIST classifier model
32 | :param two_digit_input: Paired MNIST sample
33 | :return: 20-dim. vector containing 2*10 digit label probabilities for upper and lower digit
34 | '''
35 | if len(two_digit_input.shape) == 2: # B, X
36 | two_digit_input = two_digit_input.view(-1, 1, 56, 28)
37 | elif len(two_digit_input.shape) == 3: # B, H, W
38 | two_digit_input = two_digit_input.unsqueeze(1)
39 |
40 | upper_digit = two_digit_input[:, :, :28, :]
41 | lower_digit = two_digit_input[:, :, 28:, :]
42 |
43 | probs = torch.cat([classifier(upper_digit), classifier(lower_digit)], dim=1)
44 | return probs
45 |
46 | def get_class_prob_matrix(G, G_inputs, classifier, num_samples, device):
47 | '''
48 | Build matrix of digit label combination frequencies (10x10)
49 | :param G: Generator model
50 | :param G_inputs: Input data sampler for generator (noise)
51 | :param classifier: MNIST classifier model
52 | :param num_samples: Number of samples to draw from generator to estimate c_Q
53 | :param device: Device to use
54 | :return: Normalised frequency of digit combination occurrences (10x10 matrix)
55 | '''
56 | it = 0
57 | joint_class_probs = np.zeros([10, 10])
58 | while(True):
59 | input_batch = next(G_inputs)
60 | # Get generator samples
61 | input_batch = [item.to(device) for item in input_batch]
62 | gen_batch = G(input_batch)
63 | # Feed through classifier
64 | digit_preds = predict_digits_batch(classifier, gen_batch)
65 | for pred in digit_preds:
66 | upper_pred = torch.argmax(pred[:10])
67 | lower_pred = torch.argmax(pred[10:])
68 | joint_class_probs[upper_pred, lower_pred] += 1
69 |
70 | it += 1
71 | if it >= num_samples:
72 | return joint_class_probs / np.sum(joint_class_probs)
73 |
74 | def train(opt):
75 | print("Using " + str(opt.num_joint_samples) + " joint samples!")
76 | Utils.set_seeds(opt)
77 | device = Utils.get_device(opt.cuda)
78 |
79 | # DATA
80 | MNIST_dim = 784
81 | dataset = datasets.MNIST('datasets', train=True, download=True)
82 |
83 | # Create partitions of stacked MNIST
84 | dataset_joint = DoubleMNISTDataset(dataset, range(opt.num_joint_samples),same_digit_prob=opt.mnist_same_digit_prob)
85 | train_joint = InfiniteDataSampler(DataLoader(dataset_joint, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))
86 | if opt.factorGAN == 1:
87 | # For marginals, take full dataset and crop it
88 | full_dataset = DoubleMNISTDataset(dataset, None, same_digit_prob=opt.mnist_same_digit_prob)
89 | train_x1 = InfiniteDataSampler(DataLoader(CropDataset(full_dataset, lambda x : x[:MNIST_dim]), num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))
90 | train_x2 = InfiniteDataSampler(DataLoader(CropDataset(full_dataset, lambda x : x[MNIST_dim:]), num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True))
91 |
92 | # SETUP GENERATOR MODEL
93 | G = FCGenerator(opt, 2*MNIST_dim).to(device)
94 | G.train()
95 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
96 | G_opt = Utils.create_optim(G.parameters(), opt)
97 |
98 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
99 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers),
100 | batch_size=opt.batchSize, shuffle=True)
101 | G_inputs = InfiniteDataSampler(G_input_data)
102 | G_outputs = TransformDataSampler(InfiniteDataSampler(G_input_data), G, device)
103 |
104 | # SETUP DISCRIMINATOR(S)
105 | if opt.factorGAN == 1:
106 | # Setup disc networks
107 | D1 = FCDiscriminator(MNIST_dim).to(device)
108 | D2 = FCDiscriminator(MNIST_dim).to(device)
109 | # If our dep discriminators are only defined on classifier probabilities, integrate classification into the discriminator network as a first step
110 | if opt.use_real_dep_disc == 1:
111 | DP = FCDiscriminator(2 * MNIST_dim,spectral_norm=(opt.lipschitz_p == 1)).to(device)
112 | else:
113 | DP = lambda x : 0
114 |
115 | DQ = FCDiscriminator(2 * MNIST_dim).to(device)
116 |
117 | # Prepare discriminators for training method
118 | # Marginal discriminators
119 | D1_setup = DiscriminatorSetup("D1", D1, Utils.create_optim(D1.parameters(), opt),
120 | train_x1, G_outputs, crop_fake=lambda x: x[:, :MNIST_dim])
121 | D2_setup = DiscriminatorSetup("D2", D2, Utils.create_optim(D2.parameters(), opt),
122 | train_x2, G_outputs, crop_fake=lambda x: x[:, MNIST_dim:])
123 | D_setups = [D1_setup, D2_setup]
124 |
125 | # Dependency discriminators
126 | shuffle_batch_func = lambda x: Utils.shuffle_batch_dims(x, marginal_index=MNIST_dim)
127 |
128 | if opt.use_real_dep_disc:
129 | DP_setup = DependencyDiscriminatorSetup("DP", DP, Utils.create_optim(DP.parameters(), opt), train_joint, shuffle_batch_func)
130 | else:
131 | DP_setup = None
132 | DQ_setup = DependencyDiscriminatorSetup("DQ", DQ, Utils.create_optim(DQ.parameters(), opt), G_outputs, shuffle_batch_func)
133 | D_dep_setups = [DependencyDiscriminatorPair(DP_setup, DQ_setup)]
134 | else:
135 | D = FCDiscriminator(2*MNIST_dim).to(device)
136 | D_setups = [DiscriminatorSetup("D", D, Utils.create_optim(D.parameters(), opt), train_joint, G_outputs)]
137 | D_dep_setups = []
138 |
139 | # RUN TRAINING
140 | training.AdversarialTraining.train(opt, G, G_inputs, G_opt, D_setups, D_dep_setups, device, opt.log_path)
141 | torch.save(G.state_dict(), os.path.join(opt.experiment_path, "G"))  # Save into the experiment folder so eval() can load the model again
142 |
143 | def eval(opt):
144 | print("EVALUATING MNIST MODEL...")
145 | MNIST_dim = 784
146 | device = Utils.get_device(opt.cuda)
147 |
148 | # Train and save a digit classification model, needed for factorGAN variants and evaluation
149 | classifier = MNIST.main(opt)
150 | classifier.to(device)
151 | classifier.eval()
152 |
153 | # SETUP GENERATOR MODEL
154 | G = FCGenerator(opt, 2 * MNIST_dim).to(device)
155 | G_noise = torch.distributions.uniform.Uniform(torch.Tensor([-1] * opt.nz), torch.Tensor([1] * opt.nz))
156 | # Prepare data sources that are a combination of real data and generator network, or purely from the generator network
157 | G_input_data = DataLoader(GeneratorInputDataset(None, G_noise), num_workers=int(opt.workers),
158 | batch_size=opt.batchSize, shuffle=True)
159 | G_inputs = InfiniteDataSampler(G_input_data)
160 |
161 | G.load_state_dict(torch.load(os.path.join(opt.experiment_path, opt.eval_model)))
162 | G.eval()
163 |
164 | # EVALUATE: Save images to eyeball them + FID for marginals + Class probability correlations
165 | writer = SummaryWriter(opt.log_path)
166 |
167 | test_mnist = datasets.MNIST('datasets', train=False, download=True)
168 | test_dataset = DoubleMNISTDataset(test_mnist, None, same_digit_prob=opt.mnist_same_digit_prob)
169 | test_dataset_loader = DataLoader(test_dataset, num_workers=int(opt.workers), batch_size=opt.batchSize, shuffle=True)
170 | transform_func = lambda x: x.view(-1, 1, 56, 28)
171 | Visualisation.generate_images(G, G_inputs, opt.gen_path, len(test_dataset), device, transform_func)
172 |
173 | crop_upper = lambda x: x[:, :, :28, :]
174 | crop_lower = lambda x: x[:, :, 28:, :]
175 | fid_upper = FID.evaluate_MNIST(opt, classifier, test_dataset_loader, opt.gen_path, device,crop_real=crop_upper,crop_fake=crop_upper)
176 | fid_lower = FID.evaluate_MNIST(opt, classifier, test_dataset_loader, opt.gen_path, device,crop_real=crop_lower,crop_fake=crop_lower)
177 | print("FID Upper Digit: " + str(fid_upper))
178 | print("FID Lower Digit: " + str(fid_lower))
179 | writer.add_scalar("FID_lower", fid_lower)
180 | writer.add_scalar("FID_upper", fid_upper)
181 |
182 | # ESTIMATE QUALITY OF DEPENDENCY MODELLING
183 | # Ideally cp(...) = cq(...) for all inputs on the test set if the dependencies are perfectly modelled, so compute the average absolute difference between cp and cq as a metric
184 | # Get joint distribution of real class indices in the data
185 | test_dataset = DoubleMNISTDataset(test_mnist, None,
186 | same_digit_prob=opt.mnist_same_digit_prob, deterministic=True, return_labels=True)
187 | test_it = DataLoader(test_dataset)
188 | real_class_probs = np.zeros((10, 10))
189 | for sample in test_it:
190 | _, d1, d2 = sample
191 | real_class_probs[d1, d2] += 1
192 | real_class_probs /= np.sum(real_class_probs)
193 |
194 | # Compute marginal distribution of real class indices from joint one
195 | real_class_probs_upper = np.sum(real_class_probs, axis=1) # a
196 | real_class_probs_lower = np.sum(real_class_probs, axis=0) # b
197 | real_class_probs_marginal = real_class_probs_upper * np.reshape(real_class_probs_lower, [-1, 1])
198 |
199 | # Get joint distribution of class indices on generated data (using classifier predictions)
200 | fake_class_probs = get_class_prob_matrix(G, G_inputs, classifier, len(test_dataset), device)
201 | # Compute marginal distribution of class indices on generated data
202 | fake_class_probs_upper = np.sum(fake_class_probs, axis=1)
203 | fake_class_probs_lower = np.sum(fake_class_probs, axis=0)
204 | fake_class_probs_marginal = fake_class_probs_upper * np.reshape(fake_class_probs_lower, [-1, 1])
205 |
206 | # Compute average of |cp(...) - cq(...)|
207 | cp = np.divide(real_class_probs, real_class_probs_marginal + 0.001)
208 | cq = np.divide(fake_class_probs, fake_class_probs_marginal + 0.001)
209 |
210 | diff_metric = np.mean(np.abs(cp - cq))
211 |
212 | print("Dependency cp/cq diff metric: " + str(diff_metric))
213 | writer.add_scalar("Diff-Dep", diff_metric)
214 |
215 | return fid_upper, fid_lower
216 |
217 | def get_opt():
218 | # COLLECT ALL CMD ARGUMENTS
219 | parser = training.TrainingOptions.get_parser()
220 |
221 | parser.add_argument('--mnist_same_digit_prob', type=float, default=0.4,
222 | help="Probability of same digits occuring together. 0.1 means indpendently put together, 1.0 means always same digits, 0.0 never same digits")
223 | parser.add_argument('--num_joint_samples', type=int, default=50,
224 | help="Number of joint observations available for training normal gan/dependency discriminators")
225 |
226 | opt = parser.parse_args()
227 | # Set generator to sigmoid output
228 | opt.generator_activation = "sigmoid"
229 | print(opt)
230 | return opt
231 |
232 | if __name__ == "__main__":
233 | opt = get_opt()
234 |
235 | set_paths(opt)
236 |
237 | if not opt.eval:
238 | train(opt)
239 | eval(opt)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | This repository implements "FactorGAN" as described in the paper
4 |
5 | "Training Generative Adversarial Networks from Incomplete Observations using Factorised Discriminators"
6 |
7 | ## FactorGAN - Quick introduction
8 |
9 | Consider training a GAN to solve some generation task, where a sample can be naturally divided into multiple parts.
10 | As one example, we will use images of shoes along with a drawing of their edges (see diagram below).
11 |
12 | With a normal GAN for this task, the generator just outputs a pair of images, which the discriminator then analyses as a whole.
13 | For training the discriminator, you use real shoe-edge image pairs from a "paired" dataset.
14 |
15 | But what if you only have a few of these paired samples, but many more individual shoe OR edge images?
16 | You cannot use these for GAN training to improve the quality of your shoes and edges further, as the discriminator needs shoe-edge image pairs.
17 |
18 | That is where the FactorGAN comes in. As shown below, it makes use of all available data (shoe-edge pairs, shoes alone, edges alone):
19 |
20 |
21 |
22 | To achieve this, FactorGAN uses four discriminators:
23 | * a discriminator to judge the generator's shoe quality
24 | * another to do the same for the edges
25 | * two "dependency discriminators" to ensure the generator outputs edge maps that fit to their respective shoe images:
26 | * The "real dependency" discriminator tries to distinguish real paired examples from real ones where each shoe was randomly assigned to an edge map (by "shuffling" the real batch), thereby having to learn which edge maps go along which shoes.
27 | * The "fake dependency" discriminator does the same for generator samples.
28 | The real dependency discriminator is the only component that needs paired samples for training, while the other components can make use of the extra available shoes and edge images.
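
To make this "shuffling" concrete, here is a minimal toy example. The batch and its channel layout below are made up for illustration; ``Utils.shuffle_batch_dims`` is the helper the training scripts actually use (it is shown further below in ``Utils.py``):

```
import torch
import Utils

# Hypothetical toy batch: 8 samples, each a concatenation of a 3-channel "shoe"
# image and a 1-channel "edge" image along the channel dimension
# (channels 0-2 = shoe, channel 3 = edges).
batch = torch.rand(8, 4, 64, 64)

# Shuffle the edge channel independently across the batch: every shoe is now
# paired with the edge map of a randomly chosen sample. The dependency
# discriminators are trained to tell such shuffled batches apart from
# correctly paired ones.
shuffled = Utils.shuffle_batch_dims(batch, [3])
```
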
29 |
30 | Training works by alternating between a) updating all discriminators (individually) and b) updating the generator.
31 | Amazingly, we can update the generator just like in a normal GAN, by simply adding the unnormalised discriminator outputs and using the result for the generator loss.
32 | This combined output can be proven to approximate the same probability for real and fake inputs as estimated by a normal GAN discriminator.
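
The following is a rough sketch of that combination step. The helper functions are hypothetical and simplified, not the actual code in ``training/AdversarialTraining.py``; the subtraction of the fake-dependency output follows the paper's density-ratio factorisation:

```
import torch
import torch.nn.functional as F

def combined_logit(d_shoe, d_edge, d_dep_real, d_dep_fake):
    # Add the unnormalised (pre-sigmoid) outputs of the marginal discriminators
    # and the difference between the two dependency discriminator outputs.
    # The sum approximates the logit a single joint discriminator would produce.
    return d_shoe + d_edge + (d_dep_real - d_dep_fake)

def generator_loss(fake_logits):
    # Standard generator objective applied to the combined logit:
    # push the combined discriminator output towards "real".
    return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
```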
33 |
34 | In our experiments, the FactorGAN provides very good output quality even with just very few paired samples, as the shoe and edge discriminators trained on the additional unpaired samples help the generator to output realistic shoes and edge maps.
35 |
36 | This principle can also be used for conditional generation (aka prediction tasks).
37 | Let's take image segmentation as an example:
38 |
39 |
40 |
41 | The generator now acts as a segmentation model, predicting the segmentation from a given city scene.
42 | In contrast to a normal conditional GAN, whose discriminator requires the scene along with its segmentation as "paired" input, here we use
43 | * a discriminator acting only on real and fake segmentations, trainable with individual scenes and segmentation maps, ensuring the predicted segmentation is "realistic" on its own, irrespective of the scene it was predicted from
44 | * a fake dependency discriminator that distinguishes (real scene, fake segmentation) pairs from their shuffled variant, to learn how the generator output corresponds to the input
45 | * a real dependency discriminator that distinguishes (real scene, real segmentation) pairs from their shuffled variants.
46 |
47 | We perform segmentation experiments on the Cityscapes dataset, treating the samples as unpaired (like the CycleGAN).
48 | But we find that adding as few as 25 paired samples yields substantially higher segmentation accuracy than the CycleGAN - suggesting that the FactorGAN fills a gap between fully unsupervised and fully supervised methods by making efficient use of both paired and unpaired samples.
49 |
50 | # Requirements
51 |
52 | * Python 3.6
53 | * Pip for installing Python packages
54 | * [libsndfile](http://www.mega-nerd.com/libsndfile/) library installed
55 | * [wget](https://www.gnu.org/software/wget/) installed for downloading the datasets automatically
56 | * GPU is optional, but strongly recommended to avoid long computation times
57 |
58 | # Installation
59 |
60 | Install the required packages as listed in ``requirements.txt``.
61 | To ensure existing packages do not interfere with the installation, it is best to create a virtual environment with ``virtualenv`` first and then install the packages separately into that environment.
62 | Easy installation can be performed using
63 | ```
64 | pip install -r requirements.txt
65 | ```
66 |
67 | ## Dataset download
68 |
69 | ### Cityscapes and Edges2Shoes
70 |
71 | For experiments involving Cityscapes or Edges2Shoes data, you need to download these datasets first.
72 | To do this, change to the ``datasets/image2image`` subfolder in your commandline using
73 |
74 | ```
75 | cd datasets/image2image
76 | ```
77 |
78 | and then simply execute
79 |
80 | ```
81 | ./download_image2image.sh cityscapes
82 | ```
83 |
84 | or
85 |
86 | ```
87 | ./download_image2image.sh edges2shoes
88 | ```
89 |
90 | ### MUSDB18 (audio separation)
91 |
92 | For audio source separation experiments, you will need to download the [MUSDB18 dataset](https://sigsep.github.io/datasets/musdb.html) from [Zenodo](https://zenodo.org/record/1117372) manually, since it requires requesting access, and extract it to a folder of your choice.
93 |
94 | When running the training script, you can point to the MUSDB dataset folder by giving its path as a command-line parameter.
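
For example, a vocal separation training run could be started along these lines (the paths are placeholders for your local folders; see ``AudioSeparation.py`` for all available parameters):

```
python AudioSeparation.py --cuda --factorGAN 1 --musdb_path "/path/to/MUSDB18" --preprocessed_dataset_path "/path/to/musdb_preprocessed" --experiment_name "musdb_factorGAN"
```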
95 |
96 | # Running experiments
97 |
98 | To run the experiments, execute the script corresponding to the particular application, from the root directory of the repository:
99 | * ```PairedMNIST.py```: Paired MNIST experiments
100 | * ```ImagePairs.py```: Generation of image pairs (Cityscapes, Edges2Shoes)
101 | * ```Image2Image.py```: Used for image segmentation (Cityscapes)
102 | * ```AudioSeparation.py```: For vocal separation
103 |
104 | Each experiment in the paper can be replicated by specifying the experimental parameters via the commandline.
105 |
106 | Firstly, there is a set of parameters shared between all experiments, which are described in ```training/TrainingOptions.py```.
107 | The most important ones are:
108 | * ```--cuda```: Activate GPU training flag
109 | * ```--experiment_name```: Provide a string to name the experiment, which will be used to name the output folder
110 | * ```--out_path```: Provide the output folder where results and logs are saved. All output will usually be in ```out_path/TASKNAME/experiment_name```.
111 | * ```--eval```: Append this flag if you only want to perform model evaluation for an already trained model. CAUTION: The ``experiment_name`` path as well as the network parameters have to be set to the same values used during training to ensure this works correctly.
112 | * ```--factorGAN```: Provide a 0 to use the normal GAN, 1 for FactorGAN
113 | * ```--use_real_dep_disc```: Provide a 0 to not use a p-dependency discriminator, 1 for the full FactorGAN
114 |
115 | Every experiment also has specific extra commandline parameters which are explained in the code file for each experiment.
116 |
117 | ## Examples
118 |
119 | Train a GAN on PairedMNIST with 1000 joint samples, using GPU, and save results in ```out/PairedMNIST/1000_samples_GAN```:
120 | ```
121 | python PairedMNIST.py --cuda --factorGAN 0 --num_joint_samples 1000 --experiment_name "1000_samples_GAN"
122 | ```
123 |
124 | Train a FactorGAN to generate scene-segmentation image pairs with 100 joint samples on the Cityscapes dataset, using GPU, and save results in ```out/ImagePairs/cityscapes/100_samples_factorGAN```:
125 | ```
126 | python ImagePairs.py --cuda --dataset "cityscapes" --num_joint_samples 100 --factorGAN 1 --experiment_name "100_samples_factorGAN"
127 | ```
128 |
129 | Train a FactorGAN for image segmentation on the Cityscapes dataset with 25 joint samples, using GPU, and save results in ```out/Image2Image_cityscapes/25_samples_factorGAN```:
130 | ```
131 | python Image2Image.py --cuda --dataset "cityscapes" --num_joint_samples 25 --factorGAN 1 --experiment_name "25_samples_factorGAN"
132 | ```
133 |
134 | # Analysing and plotting results
135 |
136 | The ```analysis``` subfolder contains
137 | * some of the results obtained during our experiments, such as performance metrics
138 | * all scripts that were used to produce the figures
139 | * the figures used in the paper
140 |
141 | Some of these scripts require the full output of an experiment, so the experiment needs to be run first; the path to the resulting output folder can then be inserted into the script. These outputs are not included in the repository directly, as they can be quite large.
142 |
--------------------------------------------------------------------------------
/Utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import torch
6 | import random
7 | import math
8 |
9 | def create_image_pair(x, ch1, ch2):
10 | '''
11 | Concatenates two images horizontally (the two images are stored in x along the channel dimension, using 3 or 1 channels for each image)
12 | :param x: Pair of images (along channel dimension)
13 | :param ch1: Number of channels for image 1
14 | :param ch2: Number of channels for image 2
15 | :return: Horizontally stacked image pair
16 | '''
17 | assert(ch1 == 3 or ch1 == 1)
18 | assert(ch2 == 3 or ch2 == 1)
19 | assert(x.shape[1] == ch1 + ch2)
20 |
21 | repeat_left = 3 if ch1 == 1 else 1
22 | repeat_right = 3 if ch2 == 1 else 1
23 | return torch.cat([x[:, :ch1, :, :].repeat(1, repeat_left, 1, 1), x[:, ch1:, :, :].repeat(1,repeat_right,1,1)], dim=3)
24 |
25 | def is_square(integer):
26 | '''
27 | Check if number is a square of another number
28 | :param integer: Number to be checked
29 | :return: Whether number is square of another number
30 | '''
31 | root = math.sqrt(integer)
32 | if int(root + 0.5) ** 2 == integer:
33 | return True
34 | else:
35 | return False
36 |
37 | def shuffle_batch_dims(batch, marginal_index, dim=1):
38 | '''
39 | Shuffles groups of dimensions of a batch of samples so that groups are drawn independently of each other
40 | :param batch: Input batch to be shuffled
41 | :param marginal_index: If list: List of indices that denote the boundaries of the groups to be shuffled, excluding 0 and batch.shape[1].
42 | If int: Each group has this many dimensions, batch.shape[1] must be divisible by this number. If None: Input batch needs to have groups as dimensions: [Num_samples, Group1_dim, ... GroupN_dim]
43 | :return: Shuffled batch
44 | '''
45 |
46 | if isinstance(batch, torch.Tensor):
47 | out = batch.clone()
48 | else:
49 | out = batch.copy()
50 |
51 | if isinstance(marginal_index, int):
52 | assert (batch.shape[dim] % marginal_index == 0)
53 | marginal_index = [(x+1)*marginal_index for x in range(int(batch.shape[1] / marginal_index) - 1)]
54 | if isinstance(marginal_index, list):
55 | groups = marginal_index + [batch.shape[dim]]
56 | for group_idx in range(len(groups)-1): # Shuffle each group, except the first one
57 | dim_start = groups[group_idx]
58 | dim_end = groups[group_idx+1]
59 | ordering = np.random.permutation(batch.shape[0])
60 | if dim == 1:
61 | out[:,dim_start:dim_end] = batch[ordering, dim_start:dim_end]
62 | elif dim == 2:
63 | out[:, :, dim_start:dim_end] = batch[ordering, :, dim_start:dim_end]
64 | elif dim == 3:
65 | out[:, :, :, dim_start:dim_end] = batch[ordering, :, :, dim_start:dim_end]
66 | else:
67 | raise NotImplementedError
68 | else:
69 | raise NotImplementedError
70 |
71 | return out
72 |
73 | def load(path, sr=22050, mono=True, offset=0.0, duration=None, dtype=np.float32):
74 | # ALWAYS output (n_frames, n_channels) audio
75 | y, orig_sr = librosa.load(path, sr, mono, offset, duration, dtype)
76 | if len(y.shape) == 1:
77 | y = np.expand_dims(y, axis=0)
78 | return y.T, orig_sr
79 |
80 | def shuffle_batch_image_quadrants(batch):
81 | '''
82 | Given an input batch of square images, shuffle the four quadrants independently across examples
83 | :param batch: Input batch of square images
84 | :return: Shuffled square images
85 | '''
86 | input_shape = batch.shape
87 | if len(batch.shape) == 2: # [batch, dim] shape means we have to reshape
88 | # Check if data can be shaped into square image, if it is not already
89 | dim = int(batch.shape[1])
90 | root = int(math.sqrt(dim) + 0.5)
91 | assert(root ** 2 == dim)
92 | elif len(batch.shape) > 2:
93 | # Check if last two dimensions are the same size N, and reshape to [-1, C, N, N]
94 | assert(batch.shape[-2] == batch.shape[-1])
95 | root = batch.shape[-1]
96 | else:
97 | raise ValueError("Input batch must have at least two dimensions")
98 | assert(root % 2 == 0) # Image should be splittable in half
99 | q = root // 2 # Length/width of each quadrant
100 |
101 | # Change to [B, C, N, N] shape
102 | if isinstance(batch, torch.Tensor):
103 | batch_reshape = batch.view((batch.shape[0], -1, root, root))
104 | out = batch_reshape.clone()
105 | else:
106 | batch_reshape = np.reshape(batch, (batch.shape[0], -1, root, root))
107 | out = batch_reshape.copy()
108 |
109 | # Shuffle the four quadrants of the square image around
110 | for row in range(2):
111 | for col in range(2):
112 | if row == 0 and col == 0: continue # Do not need to shuffle first quadrant, if we shuffle all the others across the batch
113 |
114 | ordering = np.random.permutation(batch.shape[0])
115 | out[:, :, row*q:(row+1)*q, col*q:(col+1)*q] = batch_reshape[ordering, :, row*q:(row+1)*q, col*q:(col+1)*q]
116 |
117 | # Reshape to the shape of the original input
118 | if isinstance(batch, torch.Tensor):
119 | out = out.view(input_shape)
120 | else:
121 | out = np.reshape(out, input_shape)
122 |
123 | return out
124 |
125 | def compute_spectrogram(audio, fft_size, hop_size):
126 | '''
127 | Compute magnitude spectrogram for audio signal
128 | :param audio: Audio input signal
129 | :param fft_size: FFT Window size (samples)
130 | :param hop_size: Hop size (samples) for STFT
131 | :return: Magnitude spectrogram
132 | '''
133 | stft = librosa.core.stft(audio, n_fft=fft_size, hop_length=hop_size)
134 | mag, ph = librosa.core.magphase(stft)
135 |
136 | return normalise_spectrogram(mag), ph
137 |
138 | def normalise_spectrogram(mag, cut_last_freq=True):
139 | '''
140 | Normalise audio spectrogram with log-normalisation
141 | :param mag: Magnitude spectrogram to be normalised
142 | :param cut_last_freq: Whether to cut highest frequency bin to reach power of 2 in number of bins
143 | :return: Normalised spectrogram
144 | '''
145 | if cut_last_freq:
146 | # Throw away the last frequency bin so that the number of bins is a power of 2
147 | mag = mag[:-1,:]
148 |
149 | # Normalise with log1p
150 | out = np.log1p(mag)
151 | return out
152 |
153 | def normalise_spectrogram_torch(mag):
154 | return torch.log1p(mag)
155 |
156 | def denormalise_spectrogram(mag, pad_freq=True):
157 | '''
158 | Reverses normalisation performed in "normalise_spectrogram" function
159 | :param mag: Normalised magnitudes
160 | :param pad_freq: Whether to append a frequency bin as highest frequency with 0 as energy
161 | :return: Reconstructed spectrogram
162 | '''
163 | out = np.expm1(mag)
164 |
165 | if pad_freq:
166 | out = np.pad(out, [(0,1), (0, 0)], mode="constant")
167 |
168 | return out
169 |
170 | def denormalise_spectrogram_torch(mag):
171 | return torch.expm1(mag)
172 |
173 | def spectrogramToAudioFile(magnitude, fftWindowSize, hopSize, phaseIterations=10, phase=None, length=None):
174 | '''
175 | Computes an audio signal from the given magnitude spectrogram, and optionally an initial phase.
176 | Griffin-Lim is executed to recover/refine the phase from the magnitude spectrogram.
177 | :param magnitude: Magnitudes to be converted to audio
178 | :param fftWindowSize: Size of FFT window used to create magnitudes
179 | :param hopSize: Hop size in frames used to create magnitudes
180 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase
181 | :param phase: If given, starts ISTFT with this particular phase matrix
182 | :param length: If given, audio signal is clipped/padded to this number of frames
183 | :return: Reconstructed audio signal
184 | '''
185 | if phase is not None:
186 | if phaseIterations > 0:
187 | # Refine audio given initial phase with a number of iterations
188 | return reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations, phase, length)
189 | # reconstructing the new complex matrix
190 | stftMatrix = magnitude * np.exp(phase * 1j) # magnitude * e^(j*phase)
191 | audio = librosa.istft(stftMatrix, hop_length=hopSize, length=length)
192 | else:
193 | audio = reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations)
194 | return audio
195 |
196 | def reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations=10, initPhase=None, length=None):
197 | '''
198 | Griffin-Lim algorithm for reconstructing the phase for a given magnitude spectrogram, optionally with a given
199 | initial phase.
200 | :param magnitude: Magnitudes to be converted to audio
201 | :param fftWindowSize: Size of FFT window used to create magnitudes
202 | :param hopSize: Hop size in frames used to create magnitudes
203 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase
204 | :param initPhase: If given, starts reconstruction with this particular phase matrix
205 | :param length: If given, audio signal is clipped/padded to this number of frames
206 | :return: Reconstructed audio signal
207 | '''
208 | for i in range(phaseIterations):
209 | if i == 0:
210 | if initPhase is None:
211 | reconstruction = np.random.random_sample(magnitude.shape) + 1j * (2 * np.pi * np.random.random_sample(magnitude.shape) - np.pi)
212 | else:
213 | reconstruction = np.exp(initPhase * 1j) # e^(j*phase), so that angle => phase
214 | else:
215 | reconstruction = librosa.stft(audio, n_fft=fftWindowSize, hop_length=hopSize)
216 | spectrum = magnitude * np.exp(1j * np.angle(reconstruction))
217 | if i == phaseIterations - 1:
218 | audio = librosa.istft(spectrum, hop_length=hopSize, length=length)
219 | else:
220 | audio = librosa.istft(spectrum, hop_length=hopSize)
221 | return audio
222 |
223 | def make_dirs(dirs):
224 | if isinstance(dirs, str):
225 | dirs = [dirs]
226 | assert(isinstance(dirs, list))
227 | for dir in dirs:
228 | if not os.path.exists(dir):
229 | os.makedirs(dir)
230 |
231 | def create_optim(parameters, opt):
232 | return torch.optim.Adam(parameters, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.L2)
233 |
234 | def get_device(cuda):
235 | if torch.cuda.is_available() and not cuda:
236 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
237 | device = torch.device("cuda:0" if cuda else "cpu")
238 | return device
239 |
240 | def set_seeds(opt):
241 | '''
242 | Set Python, numpy and torch random seeds to a fixed number
243 | :param opt: Option object containing a .seed member value
244 | '''
245 | if opt.seed is None:
246 | opt.seed = random.randint(1, 10000)
247 | print("Random Seed: ", opt.seed)
248 | random.seed(opt.seed)
249 | torch.manual_seed(opt.seed)
250 | np.random.seed(opt.seed)
--------------------------------------------------------------------------------
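A minimal usage sketch for Utils.shuffle_batch_dims (the toy shapes below are assumptions for illustration only, not values used in the experiments):

    # Sketch: break dependencies between groups of dimensions across a batch.
    import torch
    import Utils

    batch = torch.randn(8, 6)  # 8 samples, three groups of 2 dimensions each
    shuffled = Utils.shuffle_batch_dims(batch, marginal_index=2)
    # Dimensions 0:2 keep their original sample order, while dimensions 2:4 and 4:6
    # are each permuted independently across the batch.
    assert shuffled.shape == batch.shape
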
/analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/__init__.py
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs Results.csv:
--------------------------------------------------------------------------------
1 | LS,Model,Paired samples,Dataset,Samples_Validation,Samples_Test
2 | 0.19918266125023365,GAN,100,edges2shoes,160,40
3 | 0.197277020663023,GAN,100,edges2shoes,160,40
4 | 0.18398787640035152,GAN,100,edges2shoes,160,40
5 | 0.20709695480763912,GAN,100,edges2shoes,160,40
6 | 0.1931029949337244,GAN,100,edges2shoes,160,40
7 | 0.1648500096052885,GAN,100,edges2shoes,160,40
8 | 0.2281373254954815,GAN,100,edges2shoes,160,40
9 | 0.1764399316161871,GAN,100,edges2shoes,160,40
10 | 0.19546093232929707,GAN,100,edges2shoes,160,40
11 | 0.24289108626544476,GAN,100,edges2shoes,160,40
12 | 0.4614328145980835,GAN,1000,edges2shoes,160,40
13 | 0.4914138615131378,GAN,1000,edges2shoes,160,40
14 | 0.4868180975317955,GAN,1000,edges2shoes,160,40
15 | 0.4894701614975929,GAN,1000,edges2shoes,160,40
16 | 0.5159566402435303,GAN,1000,edges2shoes,160,40
17 | 0.4466138854622841,GAN,1000,edges2shoes,160,40
18 | 0.496683269739151,GAN,1000,edges2shoes,160,40
19 | 0.4755656234920025,GAN,1000,edges2shoes,160,40
20 | 0.5014562234282494,GAN,1000,edges2shoes,160,40
21 | 0.4484795890748501,GAN,1000,edges2shoes,160,40
22 | 0.4587823525071144,GAN,49825,edges2shoes,160,40
23 | 0.49065449088811874,GAN,49825,edges2shoes,160,40
24 | 0.47518329322338104,GAN,49825,edges2shoes,160,40
25 | 0.5073818042874336,GAN,49825,edges2shoes,160,40
26 | 0.4914328381419182,GAN,49825,edges2shoes,160,40
27 | 0.4772345498204231,GAN,49825,edges2shoes,160,40
28 | 0.46360666304826736,GAN,49825,edges2shoes,160,40
29 | 0.45902401581406593,GAN,49825,edges2shoes,160,40
30 | 0.47111673653125763,GAN,49825,edges2shoes,160,40
31 | 0.48572762310504913,GAN,49825,edges2shoes,160,40
32 | 0.49867895245552063,FactorGAN,100,edges2shoes,160,40
33 | 0.48741042986512184,FactorGAN,100,edges2shoes,160,40
34 | 0.4798458069562912,FactorGAN,100,edges2shoes,160,40
35 | 0.4852375388145447,FactorGAN,100,edges2shoes,160,40
36 | 0.5021589547395706,FactorGAN,100,edges2shoes,160,40
37 | 0.5132415220141411,FactorGAN,100,edges2shoes,160,40
38 | 0.5196625292301178,FactorGAN,100,edges2shoes,160,40
39 | 0.4916936084628105,FactorGAN,100,edges2shoes,160,40
40 | 0.4872872121632099,FactorGAN,100,edges2shoes,160,40
41 | 0.5147524625062943,FactorGAN,100,edges2shoes,160,40
42 | 0.5059081017971039,FactorGAN,1000,edges2shoes,160,40
43 | 0.5109081342816353,FactorGAN,1000,edges2shoes,160,40
44 | 0.47647980600595474,FactorGAN,1000,edges2shoes,160,40
45 | 0.5170110575854778,FactorGAN,1000,edges2shoes,160,40
46 | 0.49985283613204956,FactorGAN,1000,edges2shoes,160,40
47 | 0.50772500410676,FactorGAN,1000,edges2shoes,160,40
48 | 0.5248940289020538,FactorGAN,1000,edges2shoes,160,40
49 | 0.525466576218605,FactorGAN,1000,edges2shoes,160,40
50 | 0.5174819976091385,FactorGAN,1000,edges2shoes,160,40
51 | 0.53387051820755,FactorGAN,1000,edges2shoes,160,40
52 | 0.5184591487050056,FactorGAN,49825,edges2shoes,160,40
53 | 0.5095947757363319,FactorGAN,49825,edges2shoes,160,40
54 | 0.49991364777088165,FactorGAN,49825,edges2shoes,160,40
55 | 0.5085523054003716,FactorGAN,49825,edges2shoes,160,40
56 | 0.4887695945799351,FactorGAN,49825,edges2shoes,160,40
57 | 0.523505188524723,FactorGAN,49825,edges2shoes,160,40
58 | 0.4910292327404022,FactorGAN,49825,edges2shoes,160,40
59 | 0.5000504329800606,FactorGAN,49825,edges2shoes,160,40
60 | 0.48224231973290443,FactorGAN,49825,edges2shoes,160,40
61 | 0.49068504571914673,FactorGAN,49825,edges2shoes,160,40
62 | 0.06610106024891138,GAN,100,cityscapes,400,100
63 | 0.0486236484721303,GAN,100,cityscapes,400,100
64 | 0.06520971702411771,GAN,100,cityscapes,400,100
65 | 0.043228451162576675,GAN,100,cityscapes,400,100
66 | 0.06359246838837862,GAN,100,cityscapes,400,100
67 | 0.03645658027380705,GAN,100,cityscapes,400,100
68 | 0.04610483441501856,GAN,100,cityscapes,400,100
69 | 0.05799236707389355,GAN,100,cityscapes,400,100
70 | 0.049013398587703705,GAN,100,cityscapes,400,100
71 | 0.06162051111459732,GAN,100,cityscapes,400,100
72 | 0.041525376960635185,GAN,1000,cityscapes,400,100
73 | 0.07696561142802238,GAN,1000,cityscapes,400,100
74 | 0.05571409035474062,GAN,1000,cityscapes,400,100
75 | 0.06277178600430489,GAN,1000,cityscapes,400,100
76 | 0.03282815124839544,GAN,1000,cityscapes,400,100
77 | 0.05692122969776392,GAN,1000,cityscapes,400,100
78 | 0.05593732185661793,GAN,1000,cityscapes,400,100
79 | 0.03852173686027527,GAN,1000,cityscapes,400,100
80 | 0.03605717979371548,GAN,1000,cityscapes,400,100
81 | 0.04344829358160496,GAN,1000,cityscapes,400,100
82 | 0.0877392329275608,GAN,2975,cityscapes,400,100
83 | 0.08541159704327583,GAN,2975,cityscapes,400,100
84 | 0.08468282781541348,GAN,2975,cityscapes,400,100
85 | 0.06158710457384586,GAN,2975,cityscapes,400,100
86 | 0.10693855583667755,GAN,2975,cityscapes,400,100
87 | 0.08687366731464863,GAN,2975,cityscapes,400,100
88 | 0.10995074361562729,GAN,2975,cityscapes,400,100
89 | 0.09155567362904549,GAN,2975,cityscapes,400,100
90 | 0.08998952526599169,GAN,2975,cityscapes,400,100
91 | 0.07705152779817581,GAN,2975,cityscapes,400,100
92 | 0.15152879618108273,FactorGAN,100,cityscapes,400,100
93 | 0.1284797377884388,FactorGAN,100,cityscapes,400,100
94 | 0.13163253664970398,FactorGAN,100,cityscapes,400,100
95 | 0.18368030712008476,FactorGAN,100,cityscapes,400,100
96 | 0.12400678172707558,FactorGAN,100,cityscapes,400,100
97 | 0.1800699234008789,FactorGAN,100,cityscapes,400,100
98 | 0.14700616523623466,FactorGAN,100,cityscapes,400,100
99 | 0.2166460081934929,FactorGAN,100,cityscapes,400,100
100 | 0.1466676127165556,FactorGAN,100,cityscapes,400,100
101 | 0.13666629791259766,FactorGAN,100,cityscapes,400,100
102 | 0.24428238719701767,FactorGAN,1000,cityscapes,400,100
103 | 0.220618337392807,FactorGAN,1000,cityscapes,400,100
104 | 0.20961547270417213,FactorGAN,1000,cityscapes,400,100
105 | 0.20297425985336304,FactorGAN,1000,cityscapes,400,100
106 | 0.24880168214440346,FactorGAN,1000,cityscapes,400,100
107 | 0.20047229900956154,FactorGAN,1000,cityscapes,400,100
108 | 0.19510486349463463,FactorGAN,1000,cityscapes,400,100
109 | 0.213859923183918,FactorGAN,1000,cityscapes,400,100
110 | 0.1974785439670086,FactorGAN,1000,cityscapes,400,100
111 | 0.28764569014310837,FactorGAN,1000,cityscapes,400,100
112 | 0.28121019154787064,FactorGAN,2975,cityscapes,400,100
113 | 0.24645909294486046,FactorGAN,2975,cityscapes,400,100
114 | 0.22664135321974754,FactorGAN,2975,cityscapes,400,100
115 | 0.24045272171497345,FactorGAN,2975,cityscapes,400,100
116 | 0.26552461832761765,FactorGAN,2975,cityscapes,400,100
117 | 0.2997002340853214,FactorGAN,2975,cityscapes,400,100
118 | 0.21766618266701698,FactorGAN,2975,cityscapes,400,100
119 | 0.28905053064227104,FactorGAN,2975,cityscapes,400,100
120 | 0.21604740619659424,FactorGAN,2975,cityscapes,400,100
121 | 0.19727909564971924,FactorGAN,2975,cityscapes,400,100
122 | 0.09539508633315563,GAN (big),2975,cityscapes,400,100
123 | 0.18587369099259377,GAN (big),2975,cityscapes,400,100
124 | 0.1633717156946659,GAN (big),2975,cityscapes,400,100
125 | 0.16083178855478764,GAN (big),2975,cityscapes,400,100
126 | 0.16951332241296768,GAN (big),2975,cityscapes,400,100
127 | 0.1349962204694748,GAN (big),2975,cityscapes,400,100
128 | 0.15365711972117424,GAN (big),2975,cityscapes,400,100
129 | 0.1875893995165825,GAN (big),2975,cityscapes,400,100
130 | 0.12070108018815517,GAN (big),2975,cityscapes,400,100
131 | 0.1576228328049183,GAN (big),2975,cityscapes,400,100
132 | 0.12126125209033489,GAN (big),1000,cityscapes,400,100
133 | 0.1332952231168747,GAN (big),1000,cityscapes,400,100
134 | 0.09272367879748344,GAN (big),1000,cityscapes,400,100
135 | 0.18972866609692574,GAN (big),1000,cityscapes,400,100
136 | 0.1441386379301548,GAN (big),1000,cityscapes,400,100
137 | 0.14152801036834717,GAN (big),1000,cityscapes,400,100
138 | 0.05023533571511507,GAN (big),1000,cityscapes,400,100
139 | 0.19023562967777252,GAN (big),1000,cityscapes,400,100
140 | 0.1459544152021408,GAN (big),1000,cityscapes,400,100
141 | 0.1345924697816372,GAN (big),1000,cityscapes,400,100
142 | 0.04006412858143449,GAN (big),100,cityscapes,400,100
143 | 0.03009739425033331,GAN (big),100,cityscapes,400,100
144 | 0.04413636680692434,GAN (big),100,cityscapes,400,100
145 | 0.021442324155941606,GAN (big),100,cityscapes,400,100
146 | 0.03707707207649946,GAN (big),100,cityscapes,400,100
147 | 0.036953238770365715,GAN (big),100,cityscapes,400,100
148 | 0.04017030727118254,GAN (big),100,cityscapes,400,100
149 | 0.03742054011672735,GAN (big),100,cityscapes,400,100
150 | 0.0391128808259964,GAN (big),100,cityscapes,400,100
151 | 0.04093080945312977,GAN (big),100,cityscapes,400,100
152 | 0.4937850534915924,GAN (big),49825,edges2shoes,160,40
153 | 0.4875243380665779,GAN (big),49825,edges2shoes,160,40
154 | 0.4851335920393467,GAN (big),49825,edges2shoes,160,40
155 | 0.5300062000751495,GAN (big),49825,edges2shoes,160,40
156 | 0.5167667642235756,GAN (big),49825,edges2shoes,160,40
157 | 0.48705537244677544,GAN (big),49825,edges2shoes,160,40
158 | 0.5067377276718616,GAN (big),49825,edges2shoes,160,40
159 | 0.4927330017089844,GAN (big),49825,edges2shoes,160,40
160 | 0.5186118818819523,GAN (big),49825,edges2shoes,160,40
161 | 0.49117034673690796,GAN (big),49825,edges2shoes,160,40
162 | 0.44656917452812195,GAN (big),1000,edges2shoes,160,40
163 | 0.43908343836665154,GAN (big),1000,edges2shoes,160,40
164 | 0.47615355253219604,GAN (big),1000,edges2shoes,160,40
165 | 0.4346071891486645,GAN (big),1000,edges2shoes,160,40
166 | 0.44736265018582344,GAN (big),1000,edges2shoes,160,40
167 | 0.43023814633488655,GAN (big),1000,edges2shoes,160,40
168 | 0.49190129339694977,GAN (big),1000,edges2shoes,160,40
169 | 0.4424554482102394,GAN (big),1000,edges2shoes,160,40
170 | 0.4625251926481724,GAN (big),1000,edges2shoes,160,40
171 | 0.39753028750419617,GAN (big),1000,edges2shoes,160,40
172 | 0.11263203108683228,GAN (big),100,edges2shoes,160,40
173 | 0.10219825804233551,GAN (big),100,edges2shoes,160,40
174 | 0.11926222685724497,GAN (big),100,edges2shoes,160,40
175 | 0.15174391120672226,GAN (big),100,edges2shoes,160,40
176 | 0.12145963963121176,GAN (big),100,edges2shoes,160,40
177 | 0.13108434528112411,GAN (big),100,edges2shoes,160,40
178 | 0.1380934789776802,GAN (big),100,edges2shoes,160,40
179 | 0.16618123278021812,GAN (big),100,edges2shoes,160,40
180 | 0.10596461221575737,GAN (big),100,edges2shoes,160,40
181 | 0.10804288182407618,GAN (big),100,edges2shoes,160,40
182 |
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_big_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_big_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_1000_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_big_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_big_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_100_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_100_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_100_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_big_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_big_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_All_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_cityscapes_All_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_cityscapes_All_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_big_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_big_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_1000_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_big_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_big_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_GAN_gens_8.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_100_joint_factorGAN_gens_8.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_big_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_big_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/ImagePairs_edges2shoes_All_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/imagepairs/PlotExamples.py:
--------------------------------------------------------------------------------
1 | '''
2 | Script to plot ImagePair experiment generator examples.
3 | Point RESULTS_PATH to the folder where the experiment logs are saved (the one that contains an "ImagePairs" folder).
4 | The script will automatically plot generator examples for each experiment that is found in the folder.
5 | The number of images to show for each model can be set by NUM_EXAMPLES.
6 | NUM_COLS sets the number of columns (along horizontal axis) to use for plotting.
7 | '''
8 |
9 | import glob
10 |
11 | import imageio
12 | import numpy as np
13 | import torch
14 | import torchvision
15 | import os
16 |
17 | RESULTS_PATH = "/mnt/windaten/Results/factorGAN/"
18 | DATASET_FOLDERS = ["cityscapes", "edges2shoes"]
19 | OUT_PATH = ""
20 | NUM_EXAMPLES = 16
21 | NUM_COLS = 4
22 |
23 | for dataset in DATASET_FOLDERS:
24 | for experiment_path in glob.glob(os.path.join(RESULTS_PATH, "ImagePairs", dataset, "*")):
25 | model = os.path.basename(experiment_path)
26 | gan_paths = [os.path.join(RESULTS_PATH, "ImagePairs", dataset, model, "gen", "gen_" + str(i) + ".png") for i in range(NUM_EXAMPLES)]
27 | gan_imgs = list()
28 |
29 | for file in gan_paths:
30 | gan_imgs.append(imageio.imread(file))
31 | gan_imgs = torch.from_numpy(np.transpose(np.stack(gan_imgs), [0, 3, 1, 2]))
32 | gan_imgs = torchvision.utils.make_grid(gan_imgs, nrow=NUM_COLS, padding=10, pad_value=255.0).permute(1, 2, 0)
33 |
34 | imageio.imwrite(os.path.join(OUT_PATH, "ImagePairs_" + dataset + "_" + model + "_gens.png"), gan_imgs)
--------------------------------------------------------------------------------
/analysis/imagepairs/PlotLS.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 |
5 | sns.set(font_scale=1.1, rc={'text.usetex' : True})
6 |
7 | df = pd.read_csv("ImagePairs Results.csv", delimiter=",")
8 | df.loc[df["Paired samples"] > 1000, "Paired samples"] = "All"
9 |
10 | g = sns.catplot(x="Paired samples", y="LS", hue="Model", col="Dataset", data=df, kind="bar", ci=95, height=3, aspect=2)#, hue_order=["FactorGAN", "GAN", "GAN (big)"])
11 |
12 | g.axes[0][0].set_title("Edges2Shoes")
13 | g.axes[0][1].set_title("Cityscapes")
14 |
15 | #plt.tight_layout()
16 | plt.savefig("imagepairs_LS.pdf")
--------------------------------------------------------------------------------
/analysis/imagepairs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/__init__.py
--------------------------------------------------------------------------------
/analysis/imagepairs/imagepairs_LS.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/imagepairs/imagepairs_LS.pdf
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/Image2Image_Cityscapes.csv:
--------------------------------------------------------------------------------
1 | Paired samples,Perf,Model,Metric
2 | 25,0.0409,GAN,MSE
3 | 100,0.03447,GAN,MSE
4 | 500,0.02419,GAN,MSE
5 | 1000,0.01954,GAN,MSE
6 | 3000,0.02029,GAN,MSE
7 | 25,0.02665,FactorGAN,MSE
8 | 100,0.02688,FactorGAN,MSE
9 | 500,0.02116,FactorGAN,MSE
10 | 1000,0.02038,FactorGAN,MSE
11 | 3000,0.0196,FactorGAN,MSE
12 | 25,53.98,GAN,Accuracy
13 | 100,64.41,GAN,Accuracy
14 | 500,73.26,GAN,Accuracy
15 | 1000,79.24,GAN,Accuracy
16 | 3000,78.18,GAN,Accuracy
17 | 25,71.66,FactorGAN,Accuracy
18 | 100,71.98,FactorGAN,Accuracy
19 | 500,77.41,FactorGAN,Accuracy
20 | 1000,77.95,FactorGAN,Accuracy
21 | 3000,78.66,FactorGAN,Accuracy
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/PlotExamples.py:
--------------------------------------------------------------------------------
1 | '''
2 | Visualise Cityscapes image segmentation predictions. Requires the experiment output folders to exist with the correct names first!
3 | '''
4 | import glob
5 |
6 | import imageio
7 | import numpy as np
8 | import torch
9 | import torchvision
10 | import os
11 |
12 | NUM_ROWS = 3
13 | NUM_COLUMNS = 2
14 |
15 | BASEDIR = "../../out/Image2Image_cityscapes/"
16 | EXPERIMENTS = [os.path.basename(p) for p in glob.glob(os.path.join(BASEDIR, "*"))]
17 |
18 | for experiment in EXPERIMENTS:
19 | gan_paths = [os.path.join(BASEDIR, experiment, "gen", "gen_" + str(i) + ".png") for i in range(NUM_ROWS*NUM_COLUMNS)]
20 | gan_imgs = list()
21 |
22 | for file in gan_paths:
23 | gan_imgs.append(imageio.imread(file))
24 | gan_imgs = torch.from_numpy(np.transpose(np.stack(gan_imgs), [0, 3, 1, 2]))
25 | gan_imgs = torchvision.utils.make_grid(gan_imgs, nrow=2, padding=10, pad_value=255.0).permute(1, 2, 0)
26 |
27 | imageio.imwrite("cityscapes_" + experiment + "_gens.png", gan_imgs)
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/PlotL2Acc.py:
--------------------------------------------------------------------------------
1 | '''
2 | Plot mean squared error (L2) and segmentation accuracy of image segmentation models (GAN vs FactorGAN)
3 | '''
4 |
5 | import seaborn as sns
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | sns.set(font_scale=1.1, rc={'text.usetex' : True})
10 |
11 | df = pd.read_csv("Image2Image_Cityscapes.csv", delimiter=",")
12 | g = sns.catplot(x="Paired samples", y="Perf", hue="Model", col="Metric", data=df, kind="bar", sharey=False, height=3, aspect=2)
13 | g.axes[0][0].set_ylabel("Mean squared error")
14 | g.axes[0][0].set_title("")
15 | g.axes[0][1].set_title("")
16 | g.axes[0][1].set_ylabel("Accuracy (\%)")
17 |
18 | g.fig.subplots_adjust(top=0.95, wspace=0.15)
19 |
20 | plt.savefig("cityscapes.pdf")
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/__init__.py
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes.pdf
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_1000_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_1000_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_1000_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_1000_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_GAN_gens_small.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_100_joint_factorGAN_gens_small.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_all_joint_GAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_all_joint_GAN_gens.png
--------------------------------------------------------------------------------
/analysis/img2img_cityscapes/cityscapes_all_joint_factorGAN_gens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/img2img_cityscapes/cityscapes_all_joint_factorGAN_gens.png
--------------------------------------------------------------------------------
/analysis/mnist/PairedMNIST Results Diff.csv:
--------------------------------------------------------------------------------
1 | Paired samples;Lambda;Diff;Model
2 | 10;0.1;4.142;GAN
3 | 20;0.1;2.296;GAN
4 | 50;0.1;1.03;GAN
5 | 100;0.1;0.746;GAN
6 | 1000;0.1;0.2392;GAN
7 | 5000;0.1;0.2024;GAN
8 | 10000;0.1;0.2072;GAN
9 | 20000;0.1;0.1986;GAN
10 | 60000;0.1;0.2293;GAN
11 | 10;0.9;0.6944;GAN
12 | 20;0.9;0.3906;GAN
13 | 50;0.9;0.4542;GAN
14 | 100;0.9;0.3664;GAN
15 | 1000;0.9;0.6986;GAN
16 | 5000;0.9;0.6833;GAN
17 | 10000;0.9;0.6688;GAN
18 | 20000;0.9;0.723;GAN
19 | 60000;0.9;0.7067;GAN
20 | 10;0.1;0.4388;FactorGAN
21 | 20;0.1;0.381;FactorGAN
22 | 50;0.1;0.306;FactorGAN
23 | 100;0.1;0.3545;FactorGAN
24 | 1000;0.1;0.2586;FactorGAN
25 | 5000;0.1;0.2501;FactorGAN
26 | 10000;0.1;0.2225;FactorGAN
27 | 20000;0.1;0.2162;FactorGAN
28 | 60000;0.1;0.2025;FactorGAN
29 | 10;0.9;1.249;FactorGAN
30 | 20;0.9;1.174;FactorGAN
31 | 50;0.9;1.224;FactorGAN
32 | 100;0.9;1.208;FactorGAN
33 | 1000;0.9;0.815;FactorGAN
34 | 5000;0.9;0.513;FactorGAN
35 | 10000;0.9;0.4168;FactorGAN
36 | 20000;0.9;0.3421;FactorGAN
37 | 60000;0.9;0.3271;FactorGAN
38 | 10;0.1;0.267;FactorGAN-no-cp
39 | 20;0.1;0.267;FactorGAN-no-cp
40 | 50;0.1;0.267;FactorGAN-no-cp
41 | 100;0.1;0.267;FactorGAN-no-cp
42 | 1000;0.1;0.267;FactorGAN-no-cp
43 | 5000;0.1;0.267;FactorGAN-no-cp
44 | 10000;0.1;0.267;FactorGAN-no-cp
45 | 20000;0.1;0.267;FactorGAN-no-cp
46 | 60000;0.1;0.267;FactorGAN-no-cp
47 | 10;0.9;1.488;FactorGAN-no-cp
48 | 20;0.9;1.488;FactorGAN-no-cp
49 | 50;0.9;1.488;FactorGAN-no-cp
50 | 100;0.9;1.488;FactorGAN-no-cp
51 | 1000;0.9;1.488;FactorGAN-no-cp
52 | 5000;0.9;1.488;FactorGAN-no-cp
53 | 10000;0.9;1.488;FactorGAN-no-cp
54 | 20000;0.9;1.488;FactorGAN-no-cp
55 | 60000;0.9;1.488;FactorGAN-no-cp
56 |
--------------------------------------------------------------------------------
/analysis/mnist/PairedMNIST Results FID.csv:
--------------------------------------------------------------------------------
1 | Paired samples;Lambda;FID;Model
2 | 10;0.1;154.45;GAN
3 | 20;0.1;83.97;GAN
4 | 50;0.1;37.24;GAN
5 | 100;0.1;22.58;GAN
6 | 1000;0.1;4.137;GAN
7 | 5000;0.1;3.0715;GAN
8 | 10000;0.1;2.752;GAN
9 | 20000;0.1;2.953;GAN
10 | 60000;0.1;2.615;GAN
11 | 10;0.9;144.05;GAN
12 | 20;0.9;62.66;GAN
13 | 50;0.9;34.265;GAN
14 | 100;0.9;18.52;GAN
15 | 1000;0.9;5.9085;GAN
16 | 5000;0.9;2.8295;GAN
17 | 10000;0.9;2.9965;GAN
18 | 20000;0.9;2.48;GAN
19 | 60000;0.9;2.129;GAN
20 | 10;0.1;7.7045;FactorGAN
21 | 20;0.1;6.793;FactorGAN
22 | 50;0.1;3.2665;FactorGAN
23 | 100;0.1;4.341;FactorGAN
24 | 1000;0.1;3.466;FactorGAN
25 | 5000;0.1;2.5155;FactorGAN
26 | 10000;0.1;2.747;FactorGAN
27 | 20000;0.1;2.814;FactorGAN
28 | 60000;0.1;2.861;FactorGAN
29 | 10;0.9;10.815;FactorGAN
30 | 20;0.9;5.4855;FactorGAN
31 | 50;0.9;4.057;FactorGAN
32 | 100;0.9;4.4495;FactorGAN
33 | 1000;0.9;3.0935;FactorGAN
34 | 5000;0.9;1.382;FactorGAN
35 | 10000;0.9;1.5345;FactorGAN
36 | 20000;0.9;1.3675;FactorGAN
37 | 60000;0.9;1.142;FactorGAN
38 | 10;0.1;2.827;FactorGAN-no-cp
39 | 20;0.1;2.827;FactorGAN-no-cp
40 | 50;0.1;2.827;FactorGAN-no-cp
41 | 100;0.1;2.827;FactorGAN-no-cp
42 | 1000;0.1;2.827;FactorGAN-no-cp
43 | 5000;0.1;2.827;FactorGAN-no-cp
44 | 10000;0.1;2.827;FactorGAN-no-cp
45 | 20000;0.1;2.827;FactorGAN-no-cp
46 | 60000;0.1;2.827;FactorGAN-no-cp
47 | 10;0.9;2.598;FactorGAN-no-cp
48 | 20;0.9;2.598;FactorGAN-no-cp
49 | 50;0.9;2.598;FactorGAN-no-cp
50 | 100;0.9;2.598;FactorGAN-no-cp
51 | 1000;0.9;2.598;FactorGAN-no-cp
52 | 5000;0.9;2.598;FactorGAN-no-cp
53 | 10000;0.9;2.598;FactorGAN-no-cp
54 | 20000;0.9;2.598;FactorGAN-no-cp
55 | 60000;0.9;2.598;FactorGAN-no-cp
56 |
--------------------------------------------------------------------------------
/analysis/mnist/PlotDep.py:
--------------------------------------------------------------------------------
1 | '''
2 | Plot d_dep dependency metric for PairedMNIST experiment
3 | '''
4 |
5 | import seaborn as sns
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | sns.set(font_scale=1.1, rc={'text.usetex' : True})
10 |
11 | fig, ax = plt.subplots(figsize=(6,3.5))
12 |
13 | df = pd.read_csv("PairedMNIST Results Diff.csv", delimiter=";")
14 |
15 | # Get no-cp performances
16 | low_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.1)]["Diff"].values[0]
17 | high_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.9)]["Diff"].values[0]
18 |
19 | # Filter facgan-no-cp
20 | df = df[df["Model"] != "FactorGAN-no-cp"]
21 |
22 | df["Model"] = df["Model"] + ", $\lambda$=" + df["Lambda"].apply(lambda x: str(x))
23 |
24 | ax.axhline(y=low_lambda_nodep, c='black', linestyle='--', label="FactorGAN-no-cp, $\lambda=0.1$", alpha=0.8)
25 | ax.axhline(y=high_lambda_nodep, c='gray', linestyle='--', label="FactorGAN-no-cp, $\lambda=0.9$", alpha=0.8)
26 |
27 | sns.barplot(x="Paired samples", y="Diff", hue="Model", data=df, ax=ax)
28 |
29 | ax.set_ylabel("Dependency metric $d_{dep}$")
30 |
31 | # Sort legend
32 | handles, labels = ax.get_legend_handles_labels()
33 | handles = handles[2:] + handles[0:2]
34 | labels = labels[2:] + labels[0:2]
35 | ax.legend(handles, labels)
36 |
37 | #plt.legend()
38 | fig.tight_layout()
39 | fig.savefig("mnist_dep.pdf")
--------------------------------------------------------------------------------
/analysis/mnist/PlotExamples.py:
--------------------------------------------------------------------------------
1 | '''
2 | Plot PairedMNIST generated examples
3 | '''
4 |
5 | import os
6 |
7 | import imageio
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import torch
11 | import torchvision
12 |
13 | fig, axes = plt.subplots(2, 3)
14 | root_path = "../../out/PairedMNIST"
15 |
16 | for idx, samples in enumerate([100, 5000, 20000]):
17 | for model_idx, model in enumerate(["GAN", "factorGAN"]):
18 | gan_paths = [os.path.join(root_path, str(samples) + "_joint_0.9_samedigit_" + model, "gen", "gen_" + str(i) + ".png") for i in range(28)]
19 |
20 | gan_imgs = list()
21 | for file in gan_paths:
22 | gan_imgs.append(imageio.imread(file))
23 | gan_imgs = torch.from_numpy(np.transpose(np.stack(gan_imgs), [0, 3, 1, 2]))
24 | gan_imgs = torchvision.utils.make_grid(gan_imgs, nrow=7, padding=5, pad_value=255.0).permute(1, 2, 0)
25 |
26 | axes[model_idx][idx].imshow(gan_imgs)
27 | axes[model_idx][idx].axis("off")
28 |
29 | plt.subplots_adjust(wspace=0.1, hspace=.1)
30 |
31 | axes[0][0].text(-0.3,0.5, "GAN", size=12, ha="center", transform=axes[0][0].transAxes)
32 | axes[1][0].text(-0.3,0.5, "factorGAN", size=12, ha="center", transform=axes[1][0].transAxes)
33 |
34 | axes[1][0].text(0.5,-0.1, "100", size=12, ha="center", transform=axes[1][0].transAxes)
35 | axes[1][1].text(0.5,-0.1, "5000", size=12, ha="center", transform=axes[1][1].transAxes)
36 | axes[1][2].text(0.5,-0.1, "20000", size=12, ha="center", transform=axes[1][2].transAxes)
37 |
38 | plt.savefig("mnist_examples.pdf", bbox_inches="tight")
--------------------------------------------------------------------------------
/analysis/mnist/PlotFID.py:
--------------------------------------------------------------------------------
1 | '''
2 | Plot FID values for PairedMNIST experiment
3 | '''
4 |
5 | import seaborn as sns
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 |
9 | sns.set(font_scale=1.1, rc={'text.usetex' : True})
10 |
11 | fig, ax = plt.subplots(figsize=(6,3.5))
12 | # PLOT FID
13 | df = pd.read_csv("PairedMNIST Results FID.csv", delimiter=";")
14 |
15 | # Get no-cp performances
16 | low_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.1)]["FID"].values[0]
17 | high_lambda_nodep = df[(df["Model"] == "FactorGAN-no-cp") & (df["Lambda"] == 0.9)]["FID"].values[0]
18 |
19 | # Filter facgan-no-cp
20 | df = df[df["Model"] != "FactorGAN-no-cp"]
21 |
22 | # Combine model with lambda
23 | df["Model"] = df["Model"] + ", $\lambda$=" + df["Lambda"].apply(lambda x: str(x))
24 |
25 | sns.barplot(x="Paired samples", y="FID", hue="Model", data=df, ax=ax)
26 | ax.set_yscale("log")
27 |
28 | ax.axhline(y=low_lambda_nodep, c='black', linestyle='dashed', label="FactorGAN-no-cp, $\lambda=0.1$", alpha=0.8)
29 | ax.axhline(y=high_lambda_nodep, c='gray', linestyle='dashed', label="FactorGAN-no-cp, $\lambda=0.9$", alpha=0.8)
30 |
31 | # PLOT
32 | # Sort legend
33 | handles, labels = ax.get_legend_handles_labels()
34 | handles = handles[2:] + handles[0:2]
35 | labels = labels[2:] + labels[0:2]
36 | ax.legend(handles, labels) # bbox_to_anchor=(1.04,0.5), loc="center left", borderaxespad=0)
37 |
38 | #ax.get_legend().remove()
39 | fig.tight_layout()
40 | fig.savefig("mnist_fid.pdf")
--------------------------------------------------------------------------------
/analysis/mnist/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/__init__.py
--------------------------------------------------------------------------------
/analysis/mnist/mnist_dep.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/mnist_dep.pdf
--------------------------------------------------------------------------------
/analysis/mnist/mnist_examples.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/mnist_examples.pdf
--------------------------------------------------------------------------------
/analysis/mnist/mnist_fid.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/mnist/mnist_fid.pdf
--------------------------------------------------------------------------------
/analysis/source_separation/PlotSDR.py:
--------------------------------------------------------------------------------
1 | '''
2 | Plot SDR values for source separation experiment
3 | '''
4 |
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | import seaborn as sns
8 |
9 | sns.set(font_scale=1.1, rc={'text.usetex' : True})
10 |
11 | df = pd.read_csv("sdr.csv", delimiter=",")
12 |
13 | g = sns.catplot(x="Songs", y="Mean SDR", hue="Model", col="Source", data=df, kind="bar", height=3, aspect=2, sharey=False)
14 |
15 | g.axes[0][0].set_title("Vocals")
16 | g.axes[0][1].set_title("Accompaniment")
17 |
18 | plt.savefig("sdr.pdf")
--------------------------------------------------------------------------------
/analysis/source_separation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/source_separation/__init__.py
--------------------------------------------------------------------------------
/analysis/source_separation/sdr.csv:
--------------------------------------------------------------------------------
1 | Songs,Model,Median SDR,MAD SDR,Mean SDR,SDR Std,Source
2 | 10,GAN,1.983965,1.45067,0.0410285857142858,10.689719135518489,Vocals
3 | 20,GAN,2.08704,1.42766,0.2823284285714284,9.80545661853073,Vocals
4 | 50,GAN,1.4713150000000002,1.2996700000000003,0.07224964285714285,9.248452231130631,Vocals
5 | 10,FactorGAN,3.00948,1.645375,1.7975211571428573,8.113460729156188,Vocals
6 | 20,FactorGAN,3.156565,1.6046550000000002,1.8749713857142856,8.30691717643957,Vocals
7 | 50,FactorGAN,3.4388,1.559725,2.0191757,8.65596928418218,Vocals
8 | 10,GAN,6.17731,1.540635,6.183414271428571,2.625664373284633,Accompaniment
9 | 20,GAN,5.991875,1.5996899999999998,6.233455328571428,2.8197917995690553,Accompaniment
10 | 50,GAN,5.7209900000000005,1.7076749999999996,5.969218571428572,2.8519183877568577,Accompaniment
11 | 10,FactorGAN,6.679995,1.7517950000000004,6.891838142857144,2.9684880821473953,Accompaniment
12 | 20,FactorGAN,6.799595,1.7324549999999999,6.9449088,2.9332867432936025,Accompaniment
13 | 50,FactorGAN,7.022645,1.7381000000000002,7.1121528,2.8895338557354373,Accompaniment
--------------------------------------------------------------------------------
/analysis/source_separation/sdr.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/analysis/source_separation/sdr.pdf
--------------------------------------------------------------------------------
/datasets/AudioSeparationDataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import os.path
4 | from multiprocessing import Pool
5 |
6 | import musdb
7 | import numpy as np
8 | import soundfile
9 | import torch
10 | from torch.utils.data.dataset import Dataset
11 |
12 | import Utils
13 |
14 | def preprocess_song(item):
15 | idx, song, out_path, sample_rate, input_width, mode, fft_size, hop_size = item
16 |
17 | if not os.path.exists(os.path.join(out_path, str(idx))):
18 | os.makedirs(os.path.join(out_path, str(idx)))
19 |
20 | length = np.inf
21 | if mode == "paired" or mode == "mix":
22 | mix_audio, _ = Utils.load(song["mix"], sr=sample_rate, mono=True)
23 | mix, _ = Utils.compute_spectrogram(np.squeeze(mix_audio, 1), fft_size, hop_size)
24 | length = min(mix.shape[1], length)
25 |
26 | if mode == "paired" or mode == "accompaniment":
27 | accompaniment_audio, _ = Utils.load(song["accompaniment"], sr=sample_rate, mono=True)
28 | accompaniment, _ = Utils.compute_spectrogram(np.squeeze(accompaniment_audio, 1), fft_size, hop_size)
29 | length = min(accompaniment.shape[1], length)
30 |
31 | if mode == "paired" or mode == "vocals":
32 | vocals_audio, _ = Utils.load(song["vocals"], sr=sample_rate, mono=True)
33 | vocals, _ = Utils.compute_spectrogram(np.squeeze(vocals_audio, 1), fft_size, hop_size)
34 | length = min(vocals.shape[1], length)
35 |
36 | sample_num = 0
37 | for start_pos in range(0, length - input_width, input_width // 2):
38 | sample = list()
39 | if mode == "paired" or mode == "mix":
40 | sample.append(mix[:, start_pos:start_pos + input_width])
41 |
42 | if mode == "paired" or mode == "accompaniment":
43 | sample.append(accompaniment[:, start_pos:start_pos + input_width])
44 |
45 | if mode == "paired" or mode == "vocals":
46 | sample.append(vocals[:, start_pos:start_pos + input_width])
47 |
48 | # Write current snippet
49 | sample = np.stack(sample, axis=0)
50 | np.save(os.path.join(out_path, str(idx), str(sample_num) + ".npy"), sample)
51 | sample_num += 1
52 |
53 | class MUSDBDataset(Dataset):
54 | def __init__(self, opt, song_idx, mode):
55 | self.opt = opt
56 | self.mode = mode
57 | # Load MUSDB/convert to wav
58 | dataset = getMUSDB(opt.musdb_path)[0]
59 |
60 | self.out_path = os.path.join(opt.preprocessed_dataset_path, mode)
61 |
62 | if not os.path.exists(self.out_path):
63 | # Preprocess audio into spectrograms and write each sample into a numpy file
64 | p = Pool(10)  # could also use multiprocessing.cpu_count()
65 | p.map(preprocess_song, [(curr_song_idx, song, self.out_path, opt.sample_rate, opt.input_width, mode, opt.fft_size, opt.hop_size) for curr_song_idx, song in enumerate(dataset)])
66 |
67 | # Select songs to use for training
68 | file_list = list()
69 | for idx in song_idx:
70 | npy_files = glob.glob(os.path.join(self.out_path, str(idx), "*.npy"))
71 | file_list.extend(npy_files)
72 |
73 | self.dataset = file_list
74 |
75 | def __getitem__(self, index):
76 | return self.npy_loader(self.dataset[index])
77 |
78 | def __len__(self):
79 | return len(self.dataset)
80 |
81 | def npy_loader(self, path):
82 | sample = torch.from_numpy(np.load(path))
83 | return sample
84 |
85 | def getMUSDB(database_path):
86 | mus = musdb.DB(root_dir=database_path, is_wav=False)
87 |
88 | subsets = list()
89 |
90 | for subset in ["train", "test"]:
91 | tracks = mus.load_mus_tracks(subset)
92 | samples = list()
93 |
94 | # Go through tracks
95 | for track in tracks:
96 | # Skip track if mixture is already written, assuming this track is done already
97 | track_path = track.path[:-4]
98 | mix_path = track_path + "_mix.wav"
99 | acc_path = track_path + "_accompaniment.wav"
100 | if os.path.exists(mix_path):
101 | print("WARNING: Skipping track " + mix_path + " since it exists already")
102 |
103 | # Add paths and then skip
104 | paths = {"mix" : mix_path, "accompaniment" : acc_path}
105 | paths.update({key : track_path + "_" + key + ".wav" for key in ["bass", "drums", "other", "vocals"]})
106 |
107 | samples.append(paths)
108 |
109 | continue
110 |
111 | rate = track.rate
112 |
113 | # Go through each instrument
114 | paths = dict()
115 | stem_audio = dict()
116 | for stem in ["bass", "drums", "other", "vocals"]:
117 | path = track_path + "_" + stem + ".wav"
118 | audio = track.targets[stem].audio
119 | soundfile.write(path, audio, rate, "PCM_16")
120 | stem_audio[stem] = audio
121 | paths[stem] = path
122 |
123 | # Add other instruments to form accompaniment
124 | acc_audio = np.clip(sum([stem_audio[key] for key in stem_audio.keys() if key != "vocals"]), -1.0, 1.0)
125 | soundfile.write(acc_path, acc_audio, rate, "PCM_16")
126 | paths["accompaniment"] = acc_path
127 |
128 | # Create mixture
129 | mix_audio = track.audio
130 | soundfile.write(mix_path, mix_audio, rate, "PCM_16")
131 | paths["mix"] = mix_path
132 |
133 | diff_signal = np.abs(mix_audio - acc_audio - stem_audio["vocals"])
134 | print("Maximum absolute deviation from source additivity constraint: " + str(np.max(diff_signal)))# Check if acc+vocals=mix
135 | print("Mean absolute deviation from source additivity constraint: " + str(np.mean(diff_signal)))
136 |
137 | samples.append(paths)
138 |
139 | subsets.append(samples)
140 |
141 | return subsets
--------------------------------------------------------------------------------
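A hedged construction sketch for MUSDBDataset. In the experiments these options come from training.TrainingOptions; the paths and values below are placeholders only, and preprocessing runs on first use (it requires the MUSDB data to be available):

    from argparse import Namespace
    from datasets.AudioSeparationDataset import MUSDBDataset

    opt = Namespace(musdb_path="/path/to/MUSDB",               # placeholder path
                    preprocessed_dataset_path="/tmp/musdb_prep",
                    sample_rate=22050, input_width=128,        # assumed values
                    fft_size=512, hop_size=256)

    # mode is one of "paired", "mix", "accompaniment", "vocals";
    # song_idx selects which preprocessed songs snippets are drawn from.
    vocals_train = MUSDBDataset(opt, song_idx=range(25), mode="vocals")
    snippet = vocals_train[0]  # tensor of shape [1, num_freq_bins, input_width]
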
/datasets/CropDataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 |
3 | class CropDataset(Dataset):
4 | def __init__(self, dataset, crop_func):
5 | self.dataset = dataset
6 | self.crop_func = crop_func
7 |
8 | def __getitem__(self, index):
9 | sample = self.dataset[index]
10 | return self.crop_func(sample)
11 |
12 | def __len__(self):
13 | return len(self.dataset)
--------------------------------------------------------------------------------
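A minimal sketch of CropDataset, which simply applies crop_func to every sample of a wrapped dataset (the toy dataset and crop below are assumptions):

    import torch
    from torch.utils.data import TensorDataset
    from datasets.CropDataset import CropDataset

    base = TensorDataset(torch.randn(4, 1, 32, 32))
    # TensorDataset items are tuples, so the crop first unpacks the image tensor
    cropped = CropDataset(base, crop_func=lambda sample: sample[0][:, :16, :16])
    print(cropped[0].shape)  # torch.Size([1, 16, 16])
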
/datasets/DoubleMNISTDataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 | import torch
3 | import numpy as np
4 |
5 | class DoubleMNISTDataset(Dataset):
6 | def __init__(self, mnist_dataset, idx_range, deterministic=True, same_digit_prob=0.9, return_labels=False):
7 | if idx_range is None:
8 | idx_range = range(len(mnist_dataset.data))
9 | self.data = mnist_dataset.data[idx_range].clone().detach().float() / 255.0
10 | self.labels = mnist_dataset.targets[idx_range].clone().detach()
11 |
12 | self.idx_range = range(len(idx_range))
13 |
14 | self.same_digit_prob = same_digit_prob
15 | self.deterministic = deterministic
16 |
17 | self.idx_digits = list()
18 | self.idx_digits_proportion = list()
19 | for digit_num in range(10):
20 | indices = [i for i in self.idx_range if self.labels[i] == digit_num]
21 | self.idx_digits.append(indices)
22 | self.idx_digits_proportion.append(float(len(indices)))
23 | self.idx_digits_proportion /= np.sum(self.idx_digits_proportion)
24 |
25 | self.second_digits = np.full(len(self.idx_range), -1)
26 |
27 | self.return_labels = return_labels
28 |
29 | def __getitem__(self, index):
30 | first_digit, first_digit_label = self.data[index], self.labels[index]
31 | if self.deterministic and self.second_digits[index] != -1:
32 | second_digit_idx = self.second_digits[index]
33 | else:
34 | # Draw same digit with given higher probability
35 | same_digit = (np.random.rand() < self.same_digit_prob)
36 |
37 | if same_digit:
38 | # Draw from pool of same-digit numbers
39 | second_digit_label = first_digit_label
40 | else:
41 | # Draw from pool of different-digit numbers
42 | second_digit_label = np.random.choice(range(10), p=self.idx_digits_proportion)
43 | second_digit_idx = self.idx_digits[second_digit_label.item()][np.random.randint(0, len(self.idx_digits[second_digit_label]))]
44 |
45 | # Finally, save second digit choice if deterministic
46 | if self.deterministic:
47 | self.second_digits[index] = second_digit_idx
48 |
49 | sample = torch.cat([first_digit.view(-1), self.data[second_digit_idx].view(-1)])
50 | if self.return_labels:
51 | return (sample, first_digit_label, self.labels[second_digit_idx])  # label lookup also covers the cached deterministic case
52 | else:
53 | return sample
54 |
55 |
56 | def __len__(self):
57 | return len(self.idx_range)
58 |
59 | def get_digits(self, digit_label):
60 | return sum([self.idx_digits[d] for d in range(10) if d != digit_label], [])
--------------------------------------------------------------------------------
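A hedged usage sketch for DoubleMNISTDataset (download location and split size are assumptions):

    import torchvision
    from datasets.DoubleMNISTDataset import DoubleMNISTDataset

    mnist = torchvision.datasets.MNIST("mnist_data", train=True, download=True)
    # Pair each of the first 1000 digits with a second digit that shows the same
    # class with probability 0.9; samples are flattened to 2*784 = 1568 dimensions.
    paired = DoubleMNISTDataset(mnist, idx_range=range(1000), same_digit_prob=0.9)
    sample = paired[0]  # tensor of shape [1568], values in [0, 1]
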
/datasets/GeneratorInputDataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 | import numpy as np
3 |
4 | class GeneratorInputDataset(Dataset):
5 | def __init__(self, cond_dataset, noise_distribution):
6 | assert(cond_dataset != None or noise_distribution != None)
7 | self.cond_dataset = cond_dataset
8 | self.noise = noise_distribution
9 |
10 | def __getitem__(self, index):
11 | output = list()
12 |
13 | if self.cond_dataset != None:
14 | # Sample input for generator from other dataset
15 | output.append(self.cond_dataset[np.random.randint(0, len(self.cond_dataset))])
16 |
17 | if self.noise != None:
18 | # Sample noise from noise_distribution, if we are using noise in the conditional generator network
19 | output.append(self.noise.sample())
20 |
21 | # Get generator input
22 | return output
23 |
24 | def __len__(self):
25 | if self.cond_dataset != None:
26 | return len(self.cond_dataset)
27 | else:
28 |             return 10000  # no conditioning data: length is arbitrary, only noise is sampled
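
A sketch of how the dataset above can be combined with a DataLoader to produce batched generator inputs (the conditioning tensors and the Gaussian noise distribution are illustrative stand-ins):

import torch
from torch.distributions import Normal
from torch.utils.data import DataLoader
from datasets.GeneratorInputDataset import GeneratorInputDataset

# Hypothetical conditioning inputs and a 50-dimensional noise distribution
cond = [torch.rand(3, 32, 32) for _ in range(8)]
noise = Normal(torch.zeros(50), torch.ones(50))

inputs = DataLoader(GeneratorInputDataset(cond, noise), batch_size=4)
cond_batch, noise_batch = next(iter(inputs))
print(cond_batch.shape, noise_batch.shape)  # [4, 3, 32, 32] [4, 50]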
--------------------------------------------------------------------------------
/datasets/InfiniteDataSampler.py:
--------------------------------------------------------------------------------
1 | class InfiniteDataSampler(object):
2 | def __init__(self, data_loader):
3 | '''
4 | Can be used to infinitely loop over a dataset
5 | :param data_loader: Data loader object for a dataset
6 | '''
7 | self.data_loader = data_loader
8 |
9 | if not hasattr(data_loader, "__next__"):
10 | self.data_iter = iter(data_loader)
11 | else:
12 | self.data_iter = data_loader
13 |
14 | def next(self):
15 | return self.__next__()
16 |
17 | def __next__(self):
18 | try:
19 | data = next(self.data_iter)
20 | except StopIteration:
21 | # StopIteration is thrown if dataset ends
22 | # reinitialize data loader
23 | self.data_iter = iter(self.data_loader)
24 | data = next(self.data_iter)
25 |
26 | return data
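
A small sketch of the sampler above: the underlying loader is exhausted after three batches, but the sampler restarts it transparently (the toy dataset is illustrative):

from torch.utils.data import DataLoader
from datasets.InfiniteDataSampler import InfiniteDataSampler

loader = DataLoader(list(range(6)), batch_size=2)   # only 3 batches per pass
sampler = InfiniteDataSampler(loader)

for _ in range(5):                                  # keeps yielding past the 3rd batch
    print(next(sampler))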
--------------------------------------------------------------------------------
/datasets/TransformDataSampler.py:
--------------------------------------------------------------------------------
1 | class TransformDataSampler(object):
2 | def __init__(self, data_loader, transform, transform_device):
3 | # create dataloader-iterator
4 | self.data_loader = data_loader
5 |
6 | if not hasattr(data_loader, "__next__"):
7 | self.data_iter = iter(data_loader)
8 | else:
9 | self.data_iter = data_loader
10 |
11 | self.transform = transform
12 | self.device = transform_device
13 |
14 | def next(self):
15 | return self.__next__()
16 |
17 | def __next__(self):
18 | data = next(self.data_iter)
19 |
20 | # Put transform input to proper device first
21 | if self.device != None:
22 | if isinstance(data, list):
23 | data = [item.to(self.device) for item in data]
24 | else:
25 | data = data.to(self.device)
26 |
27 | # Transform batch of samples
28 | data = self.transform(data).detach()
29 |
30 | return data
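
A sketch of the sampler above with a stand-in "generator" as the transform; in the training code the transform is typically a generator network that turns sampled inputs into fake samples:

import torch
from torch.utils.data import DataLoader
from datasets.TransformDataSampler import TransformDataSampler

# Stand-in "generator": reshapes 64-dim noise vectors into 1x8x8 fake images
fake_generator = lambda z: z.view(z.shape[0], 1, 8, 8)

noise_loader = DataLoader([torch.randn(64) for _ in range(16)], batch_size=4)
fake_samples = TransformDataSampler(noise_loader, fake_generator, transform_device=None)

print(next(fake_samples).shape)  # torch.Size([4, 1, 8, 8])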
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/image2image/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from datasets.image2image.base_dataset import BaseDataset
3 | import os
4 |
5 | def get_aligned_dataset(opt, subset):
6 | opt.dataset_mode = "aligned"
7 | opt.dataroot = os.path.join("datasets", "image2image", opt.dataset)
8 | opt.phase = subset
9 | opt.direction = "AtoB"
10 | opt.no_flip = True
11 | dataset = create_dataset(opt)
12 | return dataset
13 |
14 | def find_dataset_using_name(dataset_name):
15 |     # Given the option --dataset_mode [datasetname],
16 |     # the file "datasets/image2image/[datasetname]_dataset.py"
17 |     # will be imported.
18 | dataset_filename = "datasets.image2image." + dataset_name + "_dataset"
19 | datasetlib = importlib.import_module(dataset_filename)
20 |
21 | # In the file, the class called DatasetNameDataset() will
22 | # be instantiated. It has to be a subclass of BaseDataset,
23 | # and it is case-insensitive.
24 | dataset = None
25 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
26 | for name, cls in datasetlib.__dict__.items():
27 | if name.lower() == target_dataset_name.lower() \
28 | and issubclass(cls, BaseDataset):
29 | dataset = cls
30 |
31 | if dataset is None:
32 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
33 | exit(0)
34 |
35 | return dataset
36 |
37 | def get_option_setter(dataset_name):
38 | dataset_class = find_dataset_using_name(dataset_name)
39 | return dataset_class.modify_commandline_options
40 |
41 |
42 | def create_dataset(opt):
43 | dataset = find_dataset_using_name(opt.dataset_mode)
44 | instance = dataset()
45 | instance.initialize(opt)
46 |
47 | # Add fixed no. of channel information for A and B
48 | if opt.dataset == "edges2shoes":
49 | dataset.A_nc = 1
50 | dataset.B_nc = 3
51 | else:
52 | dataset.A_nc = 3
53 | dataset.B_nc = 3
54 |
55 | print("dataset [%s] was created" % (instance.name()))
56 | return instance
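
A sketch of how the get_aligned_dataset helper above is typically called. The bare Namespace is illustrative (the real options object comes from training.TrainingOptions), and the data must first be fetched with download_image2image.sh:

from argparse import Namespace
from datasets.image2image import get_aligned_dataset

# Illustrative options; dataset name and loadSize are the only fields used here
opt = Namespace(dataset="cityscapes", loadSize=128)
train_set = get_aligned_dataset(opt, "train")

sample = train_set[0]
print(sample.shape)  # [A_nc + B_nc, loadSize, loadSize], e.g. [6, 128, 128] for cityscapes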
--------------------------------------------------------------------------------
/datasets/image2image/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torchvision.transforms as transforms
4 | import torch
5 | from datasets.image2image.base_dataset import BaseDataset
6 | from datasets.image2image.image_folder import make_dataset
7 | from PIL import Image
8 |
9 |
10 | class AlignedDataset(BaseDataset):
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | return parser
14 |
15 | def initialize(self, opt):
16 | self.opt = opt
17 | self.root = opt.dataroot
18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase)
19 | self.AB_paths = sorted(make_dataset(self.dir_AB))
20 | #assert(opt.resize_or_crop == 'resize_and_crop')
21 |
22 | def __getitem__(self, index):
23 | AB_path = self.AB_paths[index]
24 | AB = Image.open(AB_path).convert('RGB')
25 | w, h = AB.size
26 | #assert(self.opt.loadSize >= self.opt.fineSize)
27 | w2 = int(w / 2)
28 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
29 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
30 | A = transforms.ToTensor()(A)
31 | B = transforms.ToTensor()(B)
32 |
33 | if self.opt.direction == 'BtoA':
34 | input_nc = self.B_nc
35 | output_nc = self.A_nc
36 | else:
37 | input_nc = self.A_nc
38 | output_nc = self.B_nc
39 |
40 | if (not self.opt.no_flip) and random.random() < 0.5:
41 | idx = [i for i in range(A.size(2) - 1, -1, -1)]
42 | idx = torch.LongTensor(idx)
43 | A = A.index_select(2, idx)
44 | B = B.index_select(2, idx)
45 |
46 | if input_nc == 1: # RGB to gray
47 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
48 | A = tmp.unsqueeze(0)
49 |
50 | if output_nc == 1: # RGB to gray
51 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
52 | B = tmp.unsqueeze(0)
53 |
54 | #return {'A': A, 'B': B,
55 | # 'A_paths': AB_path, 'B_paths': AB_path}
56 | #print(time.time() - t)
57 | return torch.cat([A,B])
58 |
59 | def __len__(self):
60 | return len(self.AB_paths)
61 |
62 | def name(self):
63 | return 'AlignedDataset'
64 |
--------------------------------------------------------------------------------
/datasets/image2image/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 |
6 | class BaseDataset(data.Dataset):
7 | def __init__(self):
8 | super(BaseDataset, self).__init__()
9 |
10 | def name(self):
11 | return 'BaseDataset'
12 |
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | return parser
16 |
17 | def initialize(self, opt):
18 | pass
19 |
20 | def __len__(self):
21 | return 0
22 |
23 |
24 | def get_transform(opt):
25 | osize = [opt.loadSize, opt.loadSize]
26 | transform_list = [transforms.Resize(osize, Image.BICUBIC), transforms.ToTensor()]
27 | return transforms.Compose(transform_list)
28 |
29 |
30 | # just modify the width and height to be multiples of 4
31 | def __adjust(img):
32 | ow, oh = img.size
33 |
34 | # the size needs to be a multiple of this number,
35 | # because going through generator network may change img size
36 | # and eventually cause size mismatch error
37 | mult = 4
38 | if ow % mult == 0 and oh % mult == 0:
39 | return img
40 | w = (ow - 1) // mult
41 | w = (w + 1) * mult
42 | h = (oh - 1) // mult
43 | h = (h + 1) * mult
44 |
45 | if ow != w or oh != h:
46 | __print_size_warning(ow, oh, w, h)
47 |
48 | return img.resize((w, h), Image.BICUBIC)
49 |
50 |
51 | def __scale_width(img, target_width):
52 | ow, oh = img.size
53 |
54 | # the size needs to be a multiple of this number,
55 | # because going through generator network may change img size
56 | # and eventually cause size mismatch error
57 | mult = 4
58 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
59 | if (ow == target_width and oh % mult == 0):
60 | return img
61 | w = target_width
62 | target_height = int(target_width * oh / ow)
63 | m = (target_height - 1) // mult
64 | h = (m + 1) * mult
65 |
66 | if target_height != h:
67 | __print_size_warning(target_width, target_height, w, h)
68 |
69 | return img.resize((w, h), Image.BICUBIC)
70 |
71 |
72 | def __print_size_warning(ow, oh, w, h):
73 | if not hasattr(__print_size_warning, 'has_printed'):
74 | print("The image size needs to be a multiple of 4. "
75 | "The loaded image size was (%d, %d), so it was adjusted to "
76 | "(%d, %d). This adjustment will be done to all images "
77 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
78 | __print_size_warning.has_printed = True
79 |
--------------------------------------------------------------------------------
/datasets/image2image/download_image2image.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
5 | exit 1
6 | fi
7 |
8 | echo "Specified [$FILE]"
9 |
10 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
11 | TAR_FILE=./$FILE.tar.gz
12 | TARGET_DIR=./$FILE/
13 | wget -N $URL -O $TAR_FILE
14 | mkdir -p $TARGET_DIR
15 | tar -zxvf $TAR_FILE -C ./
16 | rm $TAR_FILE
--------------------------------------------------------------------------------
/datasets/image2image/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/eval/Cityscapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import namedtuple
3 |
4 | # THIS IS TAKEN FROM https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
5 | #--------------------------------------------------------------------------------
6 | # Definitions
7 | #--------------------------------------------------------------------------------
8 |
9 | # a label and all meta information
10 | Label = namedtuple( 'Label' , [
11 |
12 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
13 | # We use them to uniquely name a class
14 |
15 | 'id' , # An integer ID that is associated with this label.
16 | # The IDs are used to represent the label in ground truth images
17 | # An ID of -1 means that this label does not have an ID and thus
18 | # is ignored when creating ground truth images (e.g. license plate).
19 | # Do not modify these IDs, since exactly these IDs are expected by the
20 | # evaluation server.
21 |
22 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
23 | # ground truth images with train IDs, using the tools provided in the
24 | # 'preparation' folder. However, make sure to validate or submit results
25 | # to our evaluation server using the regular IDs above!
26 | # For trainIds, multiple labels might have the same ID. Then, these labels
27 | # are mapped to the same class in the ground truth images. For the inverse
28 | # mapping, we use the label that is defined first in the list below.
29 | # For example, mapping all void-type classes to the same ID in training,
30 | # might make sense for some approaches.
31 | # Max value is 255!
32 |
33 | 'category' , # The name of the category that this label belongs to
34 |
35 | 'categoryId' , # The ID of this category. Used to create ground truth images
36 | # on category level.
37 |
38 | 'hasInstances', # Whether this label distinguishes between single instances or not
39 |
40 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
41 | # during evaluations or not
42 |
43 | 'color' , # The color of this label
44 | ] )
45 |
46 |
47 | #--------------------------------------------------------------------------------
48 | # A list of all labels
49 | #--------------------------------------------------------------------------------
50 |
51 | # Please adapt the train IDs as appropriate for your approach.
52 | # Note that you might want to ignore labels with ID 255 during training.
53 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
54 | # Make sure to provide your results using the original IDs and not the training IDs.
55 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
56 |
57 | labels = [
58 | # name id trainId category catId hasInstances ignoreInEval color
59 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
60 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
61 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
62 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
63 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
64 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
65 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
66 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
67 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
68 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
69 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
70 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
71 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
72 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
73 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
74 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
75 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
76 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
77 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
78 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
79 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
80 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
81 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
82 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
83 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
84 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
85 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
86 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
87 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
88 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
89 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
90 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
91 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
92 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
93 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
94 | ]
95 | #END OF CITYSCAPES SCRIPTS
96 |
97 | def get_L2(opt, device, G, val_data, G_noise):
98 | '''
99 | Compute average L2 loss for a generator model on validation data
100 | :param opt: Option dict
101 | :param device: Device to use
102 | :param G: Generator model
103 | :param val_data: Validation data
104 | :param G_noise: Noise input to generator
105 | :return: Average L2 loss over the whole validation data
106 | '''
107 | G.eval()
108 | batch_losses = list()
109 | for val_batch in val_data:
110 | # Get next batch and sample generator noise
111 | val_batch = val_batch.to(device)
112 | noise = G_noise.sample_n(opt.batchSize).to(device)
113 |
114 | # Get fake samples from generator
115 | fake_sample = G([val_batch[:,:3,:,:], noise])
116 |
117 | # Evaluate outputs
118 | L2 = ((fake_sample - val_batch)[:,3:,:,:] ** 2).mean()
119 | batch_losses.append(L2.detach().cpu().numpy())
120 | return np.mean(np.array(batch_losses))
121 |
122 | def get_pixel_acc(opt, device, G, val_data, G_noise):
123 | '''
124 | Get pixel-wise classification accuracy for Cityscapes dataset
125 | :param opt: Option dictionary
126 | :param device: Device to use
127 | :param G: Generator model
128 | :param val_data: Validation dataset
129 | :param G_noise: Generator input noise
130 | :return: Average pixel-wise accuracy
131 | '''
132 | all_labels_colors = np.array([l.color for l in labels])
133 |
134 | eval_labels = [l for l in labels if not l.ignoreInEval]
135 | eval_labels_colors = np.array([l.color for l in eval_labels])
136 |
137 | G.eval()
138 | batch_losses = list()
139 | for val_batch in val_data:
140 | label_batch = val_batch[:,3:,:,:].detach().cpu().numpy() * 255.0
141 |
142 | val_batch = val_batch.to(device)
143 | noise = G_noise.sample_n(opt.batchSize).to(device)
144 |
145 | # Get fake samples from generator
146 | fake_sample = G([val_batch[:,:3,:,:], noise]).detach().cpu().numpy()
147 | fake_pred = fake_sample[:,3:,:,:] * 255.0
148 |
149 | # EVALUATE MODEL OUTPUTS
150 | # Assign colour of closest label to the prediction image
151 | dist_imag = list()
152 | for l in eval_labels:
153 | label_imag = np.zeros(fake_pred.shape)
154 | label_imag[:,0,:,:] = l.color[0]
155 | label_imag[:, 1, :, :] = l.color[1]
156 | label_imag[:, 2, :, :] = l.color[2]
157 | dist_imag.append(np.sqrt(np.sum(np.square(fake_pred - label_imag), axis=1)))
158 |
159 | dist_imag = np.array(dist_imag)
160 | pred_idx = np.argmin(dist_imag, axis=0)
161 | pred_sample = np.transpose(eval_labels_colors[pred_idx], [0, 3, 1, 2]).astype(np.float32)
162 |
163 | # Assign colour of closest label to label image (this is necessary due to image compression artifacts)
164 | dist_imag = list()
165 | for l in labels:
166 | label_imag = np.zeros(fake_pred.shape)
167 | label_imag[:, 0, :, :] = l.color[0]
168 | label_imag[:, 1, :, :] = l.color[1]
169 | label_imag[:, 2, :, :] = l.color[2]
170 | dist_imag.append(np.sqrt(np.sum(np.square(label_batch - label_imag), axis=1)))
171 |
172 | dist_imag = np.array(dist_imag)
173 | label_idx = np.argmin(dist_imag, axis=0)
174 | label_batch = np.transpose(all_labels_colors[label_idx], [0, 3, 1, 2]).astype(np.float32)
175 |
176 | # Create a mask to remove those pixels from evaluation with a ground truth label that is not part of the evaluation
177 | valid_mask = np.ones([pred_sample.shape[0], pred_sample.shape[2], pred_sample.shape[3]])
178 | for l in labels: # Go through all ignored labels
179 | if l.ignoreInEval:
180 | # Create an image with just one colour in it equal to the colour of this label
181 | label_imag = np.zeros(fake_pred.shape)
182 | label_imag[:, 0, :, :] = l.color[0]
183 | label_imag[:, 1, :, :] = l.color[1]
184 | label_imag[:, 2, :, :] = l.color[2]
185 |
186 | # Set mask to 0 where the label image has the same colour as the label to ignore
187 | valid_mask[np.sum(np.abs(label_batch - label_imag), axis=1) < 1e-5] = 0.0
188 |
189 | # Compute accuracy
190 | correct_preds = np.sum(np.abs(pred_sample - label_batch), axis=1) < 1e-5
191 | correct = np.sum(np.logical_and(correct_preds, valid_mask != 0.0), axis=(1,2))
192 | total = np.sum(valid_mask != 0.0, axis=(1,2))
193 |
194 | batch_losses.append(np.mean(correct.astype(np.float32) / total.astype(np.float32)))
195 | return np.mean(np.array(batch_losses))
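
The accuracy computation above hinges on mapping every predicted RGB pixel to the nearest label colour; a compact sketch of that step for a single image (the random image stands in for a generator output):

import numpy as np
from eval.Cityscapes import labels

eval_labels = [l for l in labels if not l.ignoreInEval]
colors = np.array([l.color for l in eval_labels], dtype=np.float32)      # (19, 3)

pred = np.random.rand(64, 64, 3).astype(np.float32) * 255.0              # (H, W, 3) in [0, 255]
dists = np.linalg.norm(pred[:, :, None, :] - colors[None, None, :, :], axis=-1)
nearest = dists.argmin(axis=-1)                                          # (H, W) indices into eval_labels
print(eval_labels[nearest[0, 0]].name)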
--------------------------------------------------------------------------------
/eval/FID.py:
--------------------------------------------------------------------------------
1 | '''
2 | Calculates the Frechet Inception Distance (FID) to evaluate GANs
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 | When run as a stand-alone program, it compares the distribution of
7 | images that are stored as PNG/JPEG at a specified location with a
8 | distribution given by summary statistics (in pickle format).
9 | The FID is calculated by assuming that X_1 and X_2 are the activations of
10 | the pool_3 layer of the inception net for generated samples and real world
11 | samples respectively.
12 | See --help to see further details.
13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead of Tensorflow
14 |
15 | Copyright 2018 Institute of Bioinformatics, JKU Linz
16 | Licensed under the Apache License, Version 2.0 (the "License");
17 | you may not use this file except in compliance with the License.
18 | You may obtain a copy of the License at
19 | http://www.apache.org/licenses/LICENSE-2.0
20 | Unless required by applicable law or agreed to in writing, software
21 | distributed under the License is distributed on an "AS IS" BASIS,
22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | See the License for the specific language governing permissions and
24 | limitations under the License.
25 | '''
26 |
27 | import glob
28 | import os
29 |
30 | import torch
31 | import numpy as np
32 | from imageio import imread
33 | from scipy import linalg
34 | from torch.utils.data import TensorDataset, DataLoader
35 |
36 | def get_activations(dataset, model, device, crop=None, verbose=False):
37 | """Calculates the activations of the pool_3 layer for all images.
38 |     Params:
39 |     -- dataset : Iterable (e.g. a DataLoader) yielding batches of images
40 |                  whose values lie between 0 and 1. Each batch may also be
41 |                  a one-element list or tuple wrapping the image tensor.
42 |     -- model   : Instance of the classifier model used to extract the
43 |                  hidden-layer activations.
44 |     -- device  : Device on which the model is evaluated.
45 |     -- crop    : Optional function applied to each batch before the
46 |                  forward pass (e.g. to select one part of a joint sample).
47 |     -- verbose : If set to True, the progress of the calculation is
48 |                  reported.
49 | Returns:
50 | -- A numpy array of dimension (num images, dims) that contains the
51 | activations of the given tensor when feeding inception with the
52 | query tensor.
53 | """
54 | model.eval()
55 |
56 | # Compute activations for real data
57 | act = list()
58 | for batch in dataset:
59 | if isinstance(batch, list) or isinstance(batch, tuple):
60 | assert(len(batch) == 1)
61 | batch = batch[0]
62 |
63 | batch = batch.view((batch.shape[0], 1, -1, 28)).to(device)
64 | # Crop batch samples if desired
65 | if crop != None:
66 | batch = crop(batch)
67 | # Get classifier activation
68 | act.append(model(batch, return_hidden=True).detach().cpu().numpy())
69 | act = np.concatenate(act)
70 | act = np.reshape(act, (-1, act.shape[-1]))
71 |
72 | if verbose:
73 | print(' done')
74 |
75 | return act
76 |
77 |
78 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
79 | """Numpy implementation of the Frechet Distance.
80 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
81 | and X_2 ~ N(mu_2, C_2) is
82 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
83 | Stable version by Dougal J. Sutherland.
84 | Params:
85 | -- mu1 : Numpy array containing the activations of a layer of the
86 | inception net (like returned by the function 'get_predictions')
87 | for generated samples.
88 |     -- mu2   : The sample mean over activations, precalculated on a
89 |                representative data set.
90 |     -- sigma1: The covariance matrix over activations for generated samples.
91 |     -- sigma2: The covariance matrix over activations, precalculated on a
92 |                representative data set.
93 | Returns:
94 | -- : The Frechet Distance.
95 | """
96 |
97 | mu1 = np.atleast_1d(mu1)
98 | mu2 = np.atleast_1d(mu2)
99 |
100 | sigma1 = np.atleast_2d(sigma1)
101 | sigma2 = np.atleast_2d(sigma2)
102 |
103 | assert mu1.shape == mu2.shape, \
104 | 'Training and test mean vectors have different lengths'
105 | assert sigma1.shape == sigma2.shape, \
106 | 'Training and test covariances have different dimensions'
107 |
108 | diff = mu1 - mu2
109 |
110 | # Product might be almost singular
111 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
112 | if not np.isfinite(covmean).all():
113 | msg = ('fid calculation produces singular product; '
114 | 'adding %s to diagonal of cov estimates') % eps
115 | print(msg)
116 | offset = np.eye(sigma1.shape[0]) * eps
117 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
118 |
119 | # Numerical error might give slight imaginary component
120 | if np.iscomplexobj(covmean):
121 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
122 | m = np.max(np.abs(covmean.imag))
123 | raise ValueError('Imaginary component {}'.format(m))
124 | covmean = covmean.real
125 |
126 | tr_covmean = np.trace(covmean)
127 |
128 | return (diff.dot(diff) + np.trace(sigma1) +
129 | np.trace(sigma2) - 2 * tr_covmean)
130 |
131 |
132 | def calculate_activation_statistics(dataset, model, device, crop, verbose=False):
133 | """Calculation of the statistics used by the FID.
134 | Params:
135 |     -- dataset : Iterable (e.g. a DataLoader) yielding batches of images
136 |                  whose values lie between 0 and 1.
137 |     -- model   : Instance of the classifier model used to extract the
138 |                  hidden-layer activations.
139 |     -- device  : Device on which the model is evaluated.
140 |     -- crop    : Optional function applied to each batch before the
141 |                  forward pass (e.g. to select one part of a joint
142 |                  sample).
143 |     -- verbose : If set to True, the progress of the calculation is
144 |                  reported.
145 | Returns:
146 | -- mu : The mean over samples of the activations of the pool_3 layer of
147 | the inception model.
148 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
149 | the inception model.
150 | """
151 | act = get_activations(dataset, model, device, crop, verbose)
152 | mu = np.mean(act, axis=0)
153 | sigma = np.cov(act, rowvar=False)
154 | return mu, sigma
155 |
156 |
157 | def _compute_statistics_of_path(path, model, batch_size, device, crop=None):
158 | if path.endswith('.npz'):
159 | f = np.load(path)
160 | m, s = f['mu'][:], f['sigma'][:]
161 | f.close()
162 | else:
163 | files = glob.glob(os.path.join(path, '*.jpg')) + glob.glob(os.path.join(path, '*.png'))
164 |
165 | imgs = np.array([np.mean(imread(str(fn)).astype(np.float32), axis=2) / 255.0 for fn in files])
166 | imgs = torch.stack([torch.Tensor(i) for i in imgs])
167 |
168 | dataset = TensorDataset(imgs)
169 | dataset = DataLoader(dataset, batch_size)
170 |
171 | m, s = calculate_activation_statistics(dataset, model, device, crop)
172 |
173 | return m, s
174 |
175 | def calculate_fid_given_paths(model, paths, batch_size, cuda):
176 | """Calculates the FID of two paths"""
177 | for p in paths:
178 | if not os.path.exists(p):
179 | raise RuntimeError('Invalid path: %s' % p)
180 |
181 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
182 | cuda)
183 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
184 | cuda)
185 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
186 |
187 | return fid_value
188 |
189 | def evaluate_MNIST(opt, classifier, real_dataset, generated_path, device, crop_real=None, crop_fake=None):
190 | # Check if we have the real classifier activations pre-computed and saved already
191 | if os.path.exists("MNIST_classifier_stats.npz"):
192 | # Load saved activations for real data
193 | real_stats = np.load("MNIST_classifier_stats.npz")
194 | m1, s1 = real_stats['m'], real_stats['s']
195 | else:
196 | real_activations = get_activations(real_dataset, classifier, device, crop_real)
197 |
198 | # Compute statistics of activations and save
199 | m1 = np.mean(real_activations, axis=0)
200 | s1 = np.cov(real_activations, rowvar=False)
201 | np.savez("MNIST_classifier_stats.npz", m=m1, s=s1)
202 |
203 | m2, s2 = _compute_statistics_of_path(generated_path, classifier, opt.batchSize, device, crop_fake)
204 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
205 |
206 | return fid_value
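
A self-contained sketch of the distance computation above, using random activations in place of classifier features for real and generated data:

import numpy as np
from eval.FID import calculate_frechet_distance

real_act = np.random.randn(500, 50)          # stand-in "real" activations
fake_act = np.random.randn(500, 50) + 0.5    # stand-in "generated" activations, shifted

fid = calculate_frechet_distance(real_act.mean(axis=0), np.cov(real_act, rowvar=False),
                                 fake_act.mean(axis=0), np.cov(fake_act, rowvar=False))
print(fid)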
--------------------------------------------------------------------------------
/eval/LS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def lsgan_loss(model, real_input, fake_input, device):
5 | '''
6 | Compute LS-GAN loss function
7 | :param model: LS Discriminator model
8 | :param real_input: Real input data
9 | :param fake_input: Fake input data
10 | :param device: Device to use
11 | :return: LS loss
12 | '''
13 | real_output = model(real_input.to(device))
14 | fake_output = model(fake_input.to(device))
15 |
16 | a = 0.0
17 | b = 1.0
18 |
19 | return ((real_output - b)**2).mean() + ((fake_output - a)**2).mean()
20 |
21 | def train_lsgan(model, real_data_loader, fake_data_loader, device):
22 | '''
23 | Trains LS discriminator model on real and fake data
24 | :param model: LS discriminator model
25 | :param real_data_loader: Real training data
26 | :param fake_data_loader: Fake training data
27 | :param device: Device to use
28 | '''
29 | model.train()
30 | NUM_EPOCHS = 40
31 | LR = 1e-4
32 |
33 | optim = torch.optim.Adam(model.parameters(), lr=LR)
34 |
35 | for epoch in range(NUM_EPOCHS):
36 | print("Epoch " + str(epoch))
37 | for real_batch in real_data_loader:
38 | fake_batch = next(fake_data_loader)
39 |
40 | optim.zero_grad()
41 | loss = lsgan_loss(model, real_batch, fake_batch, device)
42 | print(loss.item())
43 | loss.backward()
44 | optim.step()
45 |
46 | print("Finished training LSGAN Disc")
47 |
48 | def lsgan_test_loss(model, real_data_loader, fake_data_loader, device):
49 | '''
50 | Obtains LS distance for pre-trained LS discriminator on test set
51 | :param model: Pre-trained LS discriminator
52 | :param real_data_loader: Real test data
53 | :param fake_data_loader: Fake test data
54 | :param device: Device to use
55 | :return: LS distance
56 | '''
57 | model.eval()
58 |
59 | batch_losses = list()
60 | batch_weights = list()
61 | for real_batch in real_data_loader:
62 | fake_batch = next(fake_data_loader)
63 |
64 | batch_losses.append(lsgan_loss(model, real_batch, fake_batch, device).item())
65 | batch_weights.append(real_batch.shape[0])
66 |
67 | total_loss = np.sum(np.array(batch_losses) * np.array(batch_weights)) / np.sum(np.array(batch_weights))
68 | print("LS: " + str(total_loss))
69 | return total_loss
70 |
71 | def compute_ls_metric(classifier_factory, real_train, real_test, generator_sampler, repetitions, device):
72 | '''
73 | Computes LS metric for a generator model for evaluation
74 | :param classifier_factory: Function yielding a newly initialized (different each time!) LS classifier when called
75 | :param real_train: Real "test-train" dataset for training the LS discriminator
76 | :param real_test: Real "test-test" dataset for measuring LS distance after training LS discriminator
77 | :param generator_sampler: Generator output dataset
78 | :param repetitions: Number of times the classifier should be trained
79 | :param device: Device to use
80 | :return: List of LS distance metrics obtained for each training run (length "repetitions")
81 | '''
82 | losses = list()
83 | for _ in range(repetitions):
84 | classifier = classifier_factory()
85 | train_lsgan(classifier, real_train, generator_sampler, device)
86 | losses.append(lsgan_test_loss(classifier, real_test, generator_sampler, device))
87 | del classifier
88 | return losses
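
A toy sketch of the LS metric above on low-dimensional data. The random "real" and "fake" vectors and the discriminator size are illustrative; note that train_lsgan runs 40 epochs, so even this toy example prints many loss values:

import torch
from torch.utils.data import DataLoader
from eval.LS import compute_ls_metric
from datasets.InfiniteDataSampler import InfiniteDataSampler
from models.discriminators.FCDiscriminator import FCDiscriminator

device = torch.device("cpu")
dim = 32

real_train = DataLoader([torch.rand(dim) for _ in range(64)], batch_size=8)
real_test = DataLoader([torch.rand(dim) for _ in range(64)], batch_size=8)
fake_samples = InfiniteDataSampler(DataLoader([0.5 * torch.rand(dim) for _ in range(64)], batch_size=8))

ls = compute_ls_metric(lambda: FCDiscriminator(dim).to(device),
                       real_train, real_test, fake_samples,
                       repetitions=1, device=device)
print(ls)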
--------------------------------------------------------------------------------
/eval/SourceSeparation.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import musdb
3 | import museval
4 | import numpy as np
5 | import torch
6 | import glob
7 | import os
8 | import json
9 |
10 | import Utils
11 |
12 |
13 | def produce_musdb_source_estimates(model_config, model, model_noise, output_path, subsets=None):
14 | '''
15 |     Predicts source estimates for MUSDB for a given model and configuration, and evaluates them.
16 | :param model_config: Model configuration of the model to be evaluated
17 |     :param model: Trained separator model used to produce the source estimates
18 | :return:
19 | '''
20 | print("Evaluating trained model on MUSDB and saving source estimate audio to " + str(output_path))
21 | model.eval()
22 |
23 | mus = musdb.DB(root_dir=model_config.musdb_path)
24 | predict_fun = lambda track : predict(track, model_config, model, model_noise, output_path)
25 | assert(mus.test(predict_fun))
26 | mus.run(predict_fun, estimates_dir=output_path, subsets=subsets)
27 |
28 |
29 | def predict(track, model_config, model, model_noise, results_dir=None):
30 | '''
31 |     Function in accordance with the MUSDB evaluation API. Takes a MUSDB track object and computes corresponding source estimates, as well as calling the evaluation script.
32 | Model has to be saved beforehand into a pickle file containing model configuration dictionary and checkpoint path!
33 | :param track: Track object
34 | :param results_dir: Directory where SDR etc. values should be saved
35 | :return: Source estimates dictionary
36 | '''
37 |
38 | # Get noise once, use that for all predictions to keep consistency
39 | noise = model_noise.sample()
40 |
41 | # Determine input and output shapes, if we use U-net as separator
42 | sep_input_shape = [1, 1, model_config.input_height, model_config.input_width] # [N, C, H, W]
43 |
44 | print("Testing...")
45 |
46 | mix_audio, orig_sr, mix_channels = track.audio, track.rate, track.audio.shape[1] # Audio has (n_samples, n_channels) shape
47 | separator_preds = predict_track(model_config, model, noise, mix_audio, orig_sr, sep_input_shape, sep_input_shape)
48 |
49 | # Upsample predicted source audio and convert to stereo. Make sure to resample back to the exact number of samples in the original input (with fractional orig_sr/new_sr this causes issues otherwise)
50 | pred_audio = {name : librosa.resample(separator_preds[name], model_config.sample_rate, orig_sr)[:len(mix_audio)] for name in separator_preds.keys()}
51 |
52 | if mix_channels > 1: # Convert to multichannel if mixture input was multichannel by duplicating mono estimate
53 | pred_audio = {name : np.repeat(np.expand_dims(pred_audio[name], 1), mix_channels, axis=1) for name in pred_audio.keys()}
54 |
55 | # Evaluate using museval, if we are currently evaluating MUSDB
56 | if results_dir is not None:
57 | scores = museval.eval_mus_track(track, pred_audio, output_dir=results_dir, win=15, hop=15.0)
58 |
59 | # print nicely formatted mean scores
60 | print(scores)
61 |
62 | return pred_audio
63 |
64 |
65 | def predict_track(model_config, model, model_noise, mix_audio, mix_sr, sep_input_shape, sep_output_shape):
66 | '''
67 | Outputs source estimates for a given input mixture signal mix_audio [n_frames, n_channels]
68 | It iterates through the track, collecting segment-wise predictions to form the output.
69 | :param model_config: Model configuration dictionary
70 |     :param mix_audio: [n_frames, n_channels] audio signal (numpy array). Can have a higher sampling rate or more channels than the model supports; it will be resampled and downmixed accordingly.
71 | :param mix_sr: Sampling rate of mix_audio
72 | :param sep_input_shape: Input shape of separator
73 |     :param sep_output_shape: Output shape of separator
74 | :return:
75 | '''
76 |     # Load mixture, convert to mono and downsample it
77 | assert(len(mix_audio.shape) == 2)
78 |
79 | # Prepare mixture
80 | mix_audio = np.mean(mix_audio, axis=1)
81 | mix_audio = librosa.resample(mix_audio, mix_sr, model_config.sample_rate)
82 |
83 | mix_audio = librosa.util.fix_length(mix_audio, len(mix_audio) + model_config.fft_size // 2)
84 |
85 | # Convert to spectrogram
86 | mix_mags, mix_ph = Utils.compute_spectrogram(mix_audio, model_config.fft_size, model_config.hop_size)
87 |
88 | # Preallocate source predictions (same shape as input mixture)
89 | source_time_frames = mix_mags.shape[1]
90 | source_preds = {name : np.zeros(mix_mags.shape, np.float32) for name in ["accompaniment", "vocals"]}
91 |
92 | input_time_frames = sep_input_shape[3]
93 | output_time_frames = sep_output_shape[3]
94 |
95 |     # Iterate over mixture magnitudes, fetch network prediction
96 | for source_pos in range(0, source_time_frames, output_time_frames):
97 | # If this output patch would reach over the end of the source spectrogram, set it so we predict the very end of the output, then stop
98 | if source_pos + output_time_frames > source_time_frames:
99 | source_pos = source_time_frames - output_time_frames
100 |
101 | # Prepare mixture excerpt by selecting time interval
102 | mix_part = mix_mags[:, source_pos:source_pos + input_time_frames]
103 | mix_part = np.expand_dims(np.expand_dims(mix_part, axis=0), axis=0)
104 |
105 | device = next(model.parameters()).device
106 | source_parts = model([torch.from_numpy(mix_part).to(device), model_noise.to(device)]).detach().cpu().numpy()
107 |
108 | # Save predictions
109 | source_preds["accompaniment"][:,source_pos:source_pos + output_time_frames] = source_parts[0, 1]
110 | if source_parts[0].shape[0] > 2:
111 | source_preds["vocals"][:, source_pos:source_pos + output_time_frames] = source_parts[0, 2]
112 | else:
113 | source_preds["vocals"][:, source_pos:source_pos + output_time_frames] = source_parts[0, 1] # Copy acc prediction into vocals for acc-only model
114 |
115 | # Convert predictions back to audio signal
116 | for key in source_preds.keys():
117 | mags = Utils.denormalise_spectrogram(source_preds[key])
118 | source_preds[key] = Utils.spectrogramToAudioFile(mags, model_config.fft_size, model_config.hop_size, phase=np.angle(mix_ph))
119 |
120 | return source_preds
121 |
122 | def compute_mean_metrics(json_folder, compute_averages=True, metric="SDR"):
123 | '''
124 | Computes averages or collects evaluation metrics produced from MUSDB evaluation of a separator
125 | (see "produce_musdb_source_estimates" function), namely the mean, standard deviation, median, and median absolute
126 | deviation (MAD). Function is used to produce the results in the paper.
127 | Averaging ignores NaN values arising from parts where a source is silent
128 | :param json_folder: Path to the folder in which a collection of json files was written by the MUSDB evaluation library, one for each song.
129 |     This is the output of the "produce_musdb_source_estimates" function. (By default, this is model_config["estimates_path"] + test or train)
130 | :param compute_averages: Whether to compute the average over all song segments (to get final evaluation measures) or to return the full list of segments
131 | :param metric: Which metric to evaluate (either "SDR", "SIR", "SAR" or "ISR")
132 | :return: IF compute_averages is True, returns a list with length equal to the number of separated sources, with each list element a tuple of (median, MAD, mean, SD).
133 | If it is false, also returns this list, but each element is now a numpy vector containing all segment-wise performance values
134 | '''
135 | files = glob.glob(os.path.join(json_folder, "*.json"))
136 | inst_list = None
137 | print("Found " + str(len(files)) + " JSON files to evaluate...")
138 | for path in files:
139 | #print(path)
140 |         if "test.json" in path:
141 | print("Found test JSON, skipping...")
142 | continue
143 |
144 | with open(path, "r") as f:
145 | js = json.load(f)
146 |
147 | if inst_list is None:
148 | inst_list = [list() for _ in range(len(js["targets"]))]
149 |
150 | for i in range(len(js["targets"])):
151 |             inst_list[i].extend([float(f['metrics'][metric]) for f in js["targets"][i]["frames"]])
152 |
153 | #return np.array(sdr_acc), np.array(sdr_voc)
154 | inst_list = [np.array(perf) for perf in inst_list]
155 |
156 | if compute_averages:
157 | return [(np.nanmedian(perf), np.nanmedian(np.abs(perf - np.nanmedian(perf))), np.nanmean(perf), np.nanstd(perf)) for perf in inst_list]
158 | else:
159 | return inst_list
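
A sketch of how the summary function above is typically used once produce_musdb_source_estimates has written per-song museval JSON files (the folder path is illustrative):

from eval.SourceSeparation import compute_mean_metrics

stats = compute_mean_metrics("source_estimates/test", compute_averages=True, metric="SDR")
for source_idx, (median, mad, mean, std) in enumerate(stats):
    print("Source %d: median SDR %.2f (MAD %.2f), mean %.2f (SD %.2f)" % (source_idx, median, mad, mean, std))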
--------------------------------------------------------------------------------
/eval/Visualisation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision.utils as vutils
3 |
4 | def generate_images(generator, generator_inputs, out_path, num_images, device, transform=None):
5 | # Create subfolder if it doesn't exist
6 | if not os.path.exists(out_path):
7 | os.makedirs(out_path)
8 |
9 | # Generate images and save
10 | idx = 0
11 | while idx < num_images:
12 | gen_input = next(generator_inputs)
13 | gen_input = [item.to(device) for item in gen_input]
14 | sample_batch = generator(gen_input)
15 | if transform is not None:
16 | sample_batch = transform(sample_batch)
17 | for sample in sample_batch:
18 | save_as_image(sample, os.path.join(out_path, "gen_" + str(idx) + ".png"))
19 | idx += 1
20 |
21 | def save_as_image(tensor, path):
22 | vutils.save_image(tensor, path, normalize=True)
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/eval/__init__.py
--------------------------------------------------------------------------------
/factorgan_conditional.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/factorgan_conditional.png
--------------------------------------------------------------------------------
/factorgan_unconditional.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/factorgan_unconditional.png
--------------------------------------------------------------------------------
/models/MNISTClassifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 | class MNISTModel(nn.Module):
5 | def __init__(self):
6 | '''
7 | Simple CNN to use as MNIST classifier
8 | '''
9 | super(MNISTModel, self).__init__()
10 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
11 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
12 | self.conv2_drop = nn.Dropout2d()
13 | self.fc1 = nn.Linear(320, 50)
14 | self.fc2 = nn.Linear(50, 10)
15 |
16 | def forward(self, x, return_hidden=False):
17 | x = F.leaky_relu(F.avg_pool2d(self.conv1(x), 2))
18 | x = F.leaky_relu(F.avg_pool2d(self.conv2(x), 2))
19 | x = x.view(-1, 320)
20 | x = F.leaky_relu(self.fc1(x))
21 | x = F.dropout(x, training=self.training)
22 |
23 | # If return_hidden flag is given, we return the activations from the last layer instead of the final output
24 | if return_hidden:
25 | return x
26 |
27 | x = self.fc2(x)
28 | if self.training:
29 | return F.log_softmax(x, dim=1)
30 | else:
31 | return F.softmax(x, dim=1)
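
A quick sketch of the classifier above on a random batch; the return_hidden path is the one used by the FID-style evaluation in eval/FID.py:

import torch
from models.MNISTClassifier import MNISTModel

model = MNISTModel().eval()
batch = torch.rand(16, 1, 28, 28)                 # MNIST-sized inputs in [0, 1]

probs = model(batch)                              # class probabilities, shape [16, 10]
features = model(batch, return_hidden=True)       # 50-dim hidden activations, shape [16, 50]
print(probs.shape, features.shape)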
--------------------------------------------------------------------------------
/models/SpectralNorm.py:
--------------------------------------------------------------------------------
1 | '''
2 | The following is taken from
3 |
4 | https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py
5 |
6 | licensed under
7 |
8 | MIT License
9 |
10 | Copyright (c) 2017 Christian Cosgrove
11 |
12 | Permission is hereby granted, free of charge, to any person obtaining a copy
13 | of this software and associated documentation files (the "Software"), to deal
14 | in the Software without restriction, including without limitation the rights
15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | copies of the Software, and to permit persons to whom the Software is
17 | furnished to do so, subject to the following conditions:
18 |
19 | The above copyright notice and this permission notice shall be included in all
20 | copies or substantial portions of the Software.
21 |
22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | SOFTWARE.
29 | '''
30 |
31 | import torch
32 | from torch import nn
33 | from torch.nn import Parameter
34 |
35 | def l2normalize(v, eps=1e-12):
36 | return v / (v.norm() + eps)
37 |
38 | class SpectralNorm(nn.Module):
39 | def __init__(self, module, name='weight', power_iterations=1):
40 | super(SpectralNorm, self).__init__()
41 | self.module = module
42 | self.name = name
43 | self.power_iterations = power_iterations
44 | if not self._made_params():
45 | self._make_params()
46 |
47 | def _update_u_v(self):
48 | u = getattr(self.module, self.name + "_u")
49 | v = getattr(self.module, self.name + "_v")
50 | w = getattr(self.module, self.name + "_bar")
51 |
52 | height = w.data.shape[0]
53 | for _ in range(self.power_iterations):
54 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
55 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
56 |
57 | sigma = u.dot(w.view(height, -1).mv(v))
58 | setattr(self.module, self.name, w / sigma.expand_as(w))
59 |
60 | def _made_params(self):
61 | try:
62 | u = getattr(self.module, self.name + "_u")
63 | v = getattr(self.module, self.name + "_v")
64 | w = getattr(self.module, self.name + "_bar")
65 | return True
66 | except AttributeError:
67 | return False
68 |
69 | def _make_params(self):
70 | w = getattr(self.module, self.name)
71 |
72 | height = w.data.shape[0]
73 | width = w.view(height, -1).data.shape[1]
74 |
75 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
76 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
77 | u.data = l2normalize(u.data)
78 | v.data = l2normalize(v.data)
79 | w_bar = Parameter(w.data)
80 |
81 | del self.module._parameters[self.name]
82 |
83 | self.module.register_parameter(self.name + "_u", u)
84 | self.module.register_parameter(self.name + "_v", v)
85 | self.module.register_parameter(self.name + "_bar", w_bar)
86 |
87 |
88 | def forward(self, *args):
89 | self._update_u_v()
90 | return self.module.forward(*args)
91 |
92 | def set_spectral_norm(module, activate):
93 | if activate:
94 | return SpectralNorm(module)
95 | else:
96 | return module
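
A minimal sketch of the set_spectral_norm helper above, wrapping a linear layer so that a power-iteration estimate of its spectral norm rescales the weight on every forward pass:

import torch
from torch import nn
from models.SpectralNorm import set_spectral_norm

layer = set_spectral_norm(nn.Linear(16, 1), activate=True)   # returns a SpectralNorm wrapper
out = layer(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 1])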
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/models/__init__.py
--------------------------------------------------------------------------------
/models/discriminators/ConvDiscriminator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from models.SpectralNorm import set_spectral_norm
6 |
7 | class ConvDiscriminator(nn.Module):
8 | def __init__(self, x_dim, y_dim, input_channels=1, filters=32, spectral_norm=True):
9 | super(ConvDiscriminator, self).__init__()
10 |
11 |         num_layers = np.log2(min(x_dim, y_dim)) - 3
12 |         assert(np.mod(num_layers, 1) == 0)  # input side lengths must be powers of two
13 |         num_layers = int(num_layers)
14 |
15 |         feature_width = x_dim // (2 ** (num_layers+1))
16 |         feature_height = y_dim // (2 ** (num_layers+1))
17 |
18 | conv_layers = list()
19 | conv_layers.append(set_spectral_norm(nn.Conv2d(input_channels, filters, 4, 2, 1), spectral_norm))
20 | conv_layers.append(nn.LeakyReLU())
21 | for i in range(num_layers):
22 | conv_layers.append(set_spectral_norm(nn.Conv2d(filters * (2 ** i), filters * (2 ** (i + 1)), 4, 2, 1), spectral_norm))
23 | conv_layers.append(nn.LeakyReLU())
24 |
25 | self.fc = set_spectral_norm(nn.Linear(feature_width * feature_height * filters * (2 ** num_layers), 1, bias=False), spectral_norm)
26 | self.conv_layers = nn.ModuleList(conv_layers)
27 |
28 | def conv_size(self, orig_size, filter_size, padding, stride):
29 | return np.floor((orig_size + 2*padding - filter_size).astype(float) / stride.astype(float)).astype(int) + 1
30 |
31 | def forward(self, input):
32 | x = input
33 | for layer in self.conv_layers:
34 | x = layer(x)
35 | x = self.fc(x.view(x.shape[0], -1))
36 |
37 | return torch.squeeze(x, 1)
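
A sketch of the discriminator above on 64x64 RGB inputs; with these dimensions it stacks one initial and three further strided convolutions before the final linear layer:

import torch
from models.discriminators.ConvDiscriminator import ConvDiscriminator

disc = ConvDiscriminator(64, 64, input_channels=3)   # spectral normalisation on by default
scores = disc(torch.rand(8, 3, 64, 64))
print(scores.shape)  # torch.Size([8]): one unbounded score per input image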
--------------------------------------------------------------------------------
/models/discriminators/FCDiscriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.SpectralNorm import set_spectral_norm
5 |
6 | class FCDiscriminator(nn.Module):
7 | def __init__(self, input_dim, spectral_norm=True, preprocess_func=None):
8 | '''
9 | Fully connected discriminator network
10 | :param input_dim: Number of inputs
11 | :param spectral_norm: Whether to use spectral normalisation
12 | :param preprocess_func: Function that preprocesses the input before feeding to the network
13 | '''
14 | super(FCDiscriminator, self).__init__()
15 | self.preprocess_func = preprocess_func
16 |
17 | self.fc1 = set_spectral_norm(nn.Linear(input_dim, 128), spectral_norm)
18 | self.fc2 = set_spectral_norm(nn.Linear(128, 128), spectral_norm)
19 | self.fc3 = set_spectral_norm(nn.Linear(128, 1), spectral_norm)
20 | self.activation = nn.LeakyReLU()
21 | self.output_activation = nn.Sigmoid()
22 |
23 | def forward(self, input):
24 | if self.preprocess_func != None:
25 | input = self.preprocess_func(input)
26 |
27 | x = self.fc1(input)
28 | x = self.activation(x)
29 | x = self.fc2(x)
30 | x = self.activation(x)
31 | x = self.fc3(x)
32 | return torch.squeeze(x, 1)
--------------------------------------------------------------------------------
/models/discriminators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/models/discriminators/__init__.py
--------------------------------------------------------------------------------
/models/generators/ConvGenerator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 |
4 | class ConvGenerator(nn.Module):
5 | def __init__(self, opt, ngf, out_width, nc=3):
6 | super(ConvGenerator, self).__init__()
7 | self.conv_size = 4
8 |
9 | num_layers = np.log2(out_width) - 3
10 | assert(np.mod(num_layers, 1) == 0)
11 | num_layers = int(num_layers)
12 |
13 | conv_list = list()
14 |
15 | # Compute channel numbers in each layer
16 | channel_list = [ngf * (2 ** (i-1)) for i in range(num_layers+1, 0, -1)]
17 |
18 | # First layer
19 | conv_list.append(nn.ConvTranspose2d(opt.nz, channel_list[0], 4, 1, 0, bias=True))
20 | conv_list.append(nn.ReLU(True))
21 |
22 | for i in range(0, num_layers):
23 | conv_list.append(nn.ConvTranspose2d(channel_list[i], channel_list[i+1], self.conv_size, 2, 1, bias=True))
24 | conv_list.append(nn.ReLU(True))
25 |
26 | # Last layer
27 | conv_list.append(nn.ConvTranspose2d(ngf, nc, self.conv_size, 2, 1, bias=True))
28 | if opt.generator_activation == "sigmoid":
29 | conv_list.append(nn.Sigmoid())
30 | elif opt.generator_activation == "relu":
31 | conv_list.append(nn.ReLU())
32 | else:
33 | print("WARNING: Using ConvGenerator without output activation")
34 |
35 | self.main = nn.Sequential(*conv_list)
36 |
37 | def forward(self, input):
38 | assert (len(input) == 1)
39 | noise = input[0].unsqueeze(2).unsqueeze(2)
40 | output = self.main(noise)
41 | return output
42 |
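
A sketch of the generator above producing 64x64 RGB images from noise (nz and the output activation shown here are illustrative; in the scripts they typically come from training.TrainingOptions):

import torch
from argparse import Namespace
from models.generators.ConvGenerator import ConvGenerator

opt = Namespace(nz=50, generator_activation="sigmoid")
gen = ConvGenerator(opt, ngf=64, out_width=64, nc=3)

fake = gen([torch.randn(8, opt.nz)])   # input is a one-element list holding the noise batch
print(fake.shape)  # torch.Size([8, 3, 64, 64])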
--------------------------------------------------------------------------------
/models/generators/FCGenerator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class FCGenerator(nn.Module):
5 | def __init__(self, opt, output_dim):
6 | super(FCGenerator, self).__init__()
7 | self.opt = opt
8 | self.fc1 = nn.Linear(opt.nz, 128)
9 | self.fc2 = nn.Linear(128, 128)
10 | self.fc3 = nn.Linear(128, output_dim)
11 | self.activation = nn.ReLU()
12 |
13 | if opt.generator_activation == "sigmoid":
14 | self.output_activation = nn.Sigmoid()
15 | elif opt.generator_activation == "relu":
16 | self.output_activation = nn.ReLU()
17 | else:
18 | print("Using generator without output activation")
19 | self.output_activation = None
20 |
21 | def forward(self, input):
22 | assert(len(input) == 1)
23 |
24 | x = self.fc1(input[0])
25 | x = self.activation(x)
26 | x = self.fc2(x)
27 | x = self.activation(x)
28 | x = self.fc3(x)
29 | if self.output_activation != None:
30 | x = self.output_activation(x)
31 | return x
--------------------------------------------------------------------------------
/models/generators/Unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import Utils
6 |
7 | class Unet(nn.Module):
8 | def __init__(self, opt, ngf, nc_in=3, nc_out=3):
9 |
10 | super(Unet, self).__init__()
11 | self.inc = inconv(nc_in, ngf)
12 | self.down1 = down(ngf, ngf*2)
13 | self.down2 = down(ngf*2, ngf*4)
14 | self.down3 = down(ngf*4, ngf*8)
15 | self.down4 = down(ngf*8, ngf*8)
16 |
17 | fmap_shape = (opt.input_width//16)*(opt.input_height//16)
18 | noise_channels = 2
19 | self.fc_noise = nn.Linear(opt.nz, noise_channels*fmap_shape)
20 |
21 | self.up1 = up(ngf*16 + noise_channels, ngf*4)
22 | self.up2 = up(ngf*8, ngf*2)
23 | self.up3 = up(ngf*4, ngf)
24 | self.up4 = up(ngf*2, ngf)
25 | self.outc = outconv(ngf, nc_out)
26 |
27 | if opt.generator_activation == "sigmoid":
28 | self.output_activation = nn.Sigmoid()
29 | elif opt.generator_activation == "relu":
30 | self.output_activation = nn.ReLU()
31 | else:
32 | print("WARNING: Using Unet without output activation")
33 | self.output_activation = None
34 |
35 | if hasattr(opt, "generator_mask"):
36 | assert(opt.generator_mask == 0 or nc_out == 1)
37 | assert(opt.generator_mask == 0 or opt.generator_activation == "sigmoid")
38 | self.mask = opt.generator_mask
39 | else:
40 | self.mask = False
41 |
42 | def forward(self, input):
43 | cond = input[0]
44 | noise = input[1]
45 |
46 | # Downward
47 | x1 = self.inc(cond)
48 | x2 = self.down1(x1)
49 | x3 = self.down2(x2)
50 | x4 = self.down3(x3)
51 | x5 = self.down4(x4)
52 |
53 | # FC to append noise channels to feature map
54 | noise_fmap_shape = [x5.shape[0], -1, x5.shape[2], x5.shape[3]]
55 | noise_fmap = self.fc_noise(noise).view(noise_fmap_shape)
56 | x5 = torch.cat([x5, noise_fmap], dim=1)
57 |
58 | # Upward
59 | x = self.up1(x5, x4)
60 | x = self.up2(x, x3)
61 | x = self.up3(x, x2)
62 | x = self.up4(x, x1)
63 | x = self.outc(x)
64 |
65 |         if self.output_activation is not None:
66 | x = self.output_activation(x)
67 |
68 | if self.mask:
69 | # Denormalise audio mixture input
70 | original_cond = Utils.denormalise_spectrogram_torch(cond)
71 |
72 |             # Compute accompaniment (acc) and voice from the mask and the unnormalised mixture spectrogram, then renormalise
73 | acc = Utils.normalise_spectrogram_torch(original_cond * x)
74 | voice = Utils.normalise_spectrogram_torch(original_cond * (1.0 - x))
75 |
76 | x = torch.cat([acc, voice], dim=1)
77 |
78 | return torch.cat([cond, x], dim=1)
79 |
80 | class double_conv(nn.Module):
81 | '''(conv => BN => ReLU) * 2'''
82 |
83 | def __init__(self, in_ch, out_ch):
84 | super(double_conv, self).__init__()
85 | self.conv = nn.Sequential(
86 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
87 | nn.BatchNorm2d(out_ch),
88 | nn.ReLU(inplace=True),
89 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
90 | nn.BatchNorm2d(out_ch),
91 | nn.ReLU(inplace=True)
92 | )
93 |
94 | def forward(self, x):
95 | x = self.conv(x)
96 | return x
97 |
98 | class inconv(nn.Module):
99 | def __init__(self, in_ch, out_ch):
100 | super(inconv, self).__init__()
101 | self.conv = double_conv(in_ch, out_ch)
102 |
103 | def forward(self, x):
104 | x = self.conv(x)
105 | return x
106 |
107 |
108 | class down(nn.Module):
109 | def __init__(self, in_ch, out_ch):
110 | super(down, self).__init__()
111 | self.mpconv = nn.Sequential(
112 | nn.MaxPool2d(2),
113 | double_conv(in_ch, out_ch)
114 | )
115 |
116 | def forward(self, x):
117 | x = self.mpconv(x)
118 | return x
119 |
120 |
121 | class up(nn.Module):
122 | def __init__(self, in_ch, out_ch, bilinear=True):
123 | super(up, self).__init__()
124 |
125 | # would be a nice idea if the upsampling could be learned too,
126 |         # but my machine does not have enough memory to handle all those weights
127 | if bilinear:
128 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
129 | else:
130 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
131 |
132 | self.conv = double_conv(in_ch, out_ch)
133 |
134 | def forward(self, x1, x2):
135 | x1 = self.up(x1)
136 |
137 |         # input is NCHW; pad x1 so its spatial size matches x2
138 | diffY = x2.size()[2] - x1.size()[2]
139 | diffX = x2.size()[3] - x1.size()[3]
140 |
141 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
142 | diffY // 2, diffY - diffY // 2))
143 |
144 | # for padding issues, see
145 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
146 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
147 |
148 | x = torch.cat([x2, x1], dim=1)
149 | x = self.conv(x)
150 | return x
151 |
152 |
153 | class outconv(nn.Module):
154 | def __init__(self, in_ch, out_ch):
155 | super(outconv, self).__init__()
156 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
157 |
158 | def forward(self, x):
159 | x = self.conv(x)
160 | return x
--------------------------------------------------------------------------------
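A minimal shape-check sketch (not part of the repository) for the conditional Unet generator. The input width and height must be divisible by 16 so that the noise feature map produced by fc_noise matches the bottleneck resolution; the option object below is a hypothetical stand-in for the real argparse options.

import torch
from types import SimpleNamespace

from models.generators.Unet import Unet

# Hypothetical options; no generator_mask attribute, so masking stays disabled
opt = SimpleNamespace(nz=50, input_width=64, input_height=64,
                      generator_activation="sigmoid")

G = Unet(opt, ngf=32, nc_in=3, nc_out=3)

cond = torch.rand(4, 3, 64, 64)   # conditioning input (e.g. an image)
noise = torch.randn(4, opt.nz)    # latent noise vector
out = G([cond, noise])
print(out.shape)                  # torch.Size([4, 6, 64, 64]): condition and generated channels concatenated
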
/models/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/models/generators/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio
2 | seaborn
3 | pandas
4 | matplotlib
5 | soundfile
6 | scipy
7 | museval
8 | musdb
9 | librosa
10 | numpy
11 | tensorboard
12 | torchvision==0.5.0
13 | torch==1.4.0
14 | tqdm
--------------------------------------------------------------------------------
/training/AdversarialTraining.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tqdm import tqdm
4 |
5 | from training.DiscriminatorTraining import *
6 | from torch.utils.tensorboard import SummaryWriter
7 |
8 | def train(cfg, G, G_input, G_opt, D_marginal_setups, D_dep_pairs, device, logdir):
9 | print("START TRAINING! Writing logs to " + logdir)
10 | writer = SummaryWriter(logdir)
11 |
12 | # Create expression for overall discriminator output given a complete fake sample
13 | # Marginal disc sum output
14 | marginal_sum = lambda y: sum(setup.D(setup.crop_fake(y)) for setup in D_marginal_setups)
15 | if cfg.factorGAN == 1:
16 | # Dep disc sum output
17 | dep_sum = lambda y: sum(disc_pair.get_comb_disc()(y) for disc_pair in D_dep_pairs)
18 | jointD = lambda y: marginal_sum(y) + dep_sum(y)
19 | else:
20 | jointD = marginal_sum
21 |
22 | # START NORMAL TRAINING
23 | for epoch in range(cfg.epochs):
24 | for i in tqdm(range(cfg.epoch_iter)):
25 | total_it = epoch * cfg.epoch_iter + i
26 | # If dependency GAN active, train marginal discriminators here from both extra data and main data
27 | for j in range(cfg.disc_iter): # No. of disc iterations
28 | # Train marginal discriminators
29 | for D_setup in D_marginal_setups:
30 | errD, correct, _, _ = get_marginal_disc_output(D_setup, device, backward=True, zero_gradients=True)
31 | if j==cfg.disc_iter-1: writer.add_scalar(D_setup.name + "_acc", correct, total_it)
32 | D_setup.optim.step()
33 |
34 | if cfg.factorGAN == 1:
35 | # Additionally train dependency discriminators
36 | for D_dep_pair in D_dep_pairs:
37 | # Train REAL dependency discriminator
38 | if cfg.use_real_dep_disc == 1:
39 | # Training step for real dep disc
40 | errD, correct, _, _ = get_dep_disc_output(D_dep_pair.real_disc, device, backward=True,zero_gradients=True)
41 | D_dep_pair.real_disc.optim.step()
42 |
43 | # Logging for last discriminator update
44 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.real_disc.name + "_acc", correct, total_it)
45 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.real_disc.name + "_errD", errD, total_it)
46 |
47 |                         # Train FAKE dependency discriminator. Its output is later combined with the real dependency discriminator's, which acts as a regulariser: the fake dep. disc has to keep its gradients close to those of the real one
48 | errD, correct, _, _ = get_dep_disc_output(D_dep_pair.fake_disc, device, backward=True, zero_gradients=True)
49 | D_dep_pair.fake_disc.optim.step()
50 |
51 | # Logging for last discriminator update
52 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.fake_disc.name + "_acc", correct, total_it)
53 | if j == cfg.disc_iter - 1: writer.add_scalar(D_dep_pair.fake_disc.name + "_errD", errD, total_it)
54 |
55 | ############################
56 | # (2) Update G network:
57 | ###########################
58 |
59 | G.zero_grad()
60 |
61 | # Get fake samples from generator
62 | gen_input = G_input.__next__()
63 | gen_input = [item.to(device) for item in gen_input]
64 | fake_sample = G(gen_input)
65 |
66 | #TODO Produce log outputs for all tasks here?
67 |
68 | # Get setup information from first marginal discriminator (which is the only one in normal GAN training)
69 | real_label = D_marginal_setups[0].real_label
70 | criterion = D_marginal_setups[0].criterion
71 |
72 | label = torch.full((cfg.batchSize,), real_label, device=device) # fake labels are real for generator cost
73 | disc_output = jointD(fake_sample)
74 | writer.add_scalar("probG", torch.mean(torch.nn.Sigmoid()(disc_output)), total_it)
75 | if cfg.objective == "JSD":
76 | errG = criterion()(disc_output, label) # Normal JSD
77 | elif cfg.objective == "KL":
78 | errG = -torch.mean(disc_output) # KL[q|p]
79 | else:
80 | raise NotImplementedError
81 |
82 |             writer.add_scalar("errG", errG, total_it)  # log against the global iteration count, consistent with the other scalars
83 |
84 | errG.backward()
85 | G_opt.step()
86 |
87 | print("EPOCH FINISHED")
88 |
89 | model_output_path = os.path.join(cfg.experiment_path, "G_" + str(epoch))
90 | print("Saving generator at " + model_output_path)
91 | torch.save(G.state_dict(), model_output_path)
92 |
93 | # FINISHED
94 | print("TRAINING FINISHED")
--------------------------------------------------------------------------------
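A toy sketch (not part of the repository) of the discriminator combination that train() assembles: with factorGAN enabled, the joint output is the sum of all marginal discriminator outputs (each applied to its crop of the generated sample) plus, for every dependency pair, the combined term d_P(x) - d_Q(x). All "discriminators" below are stand-in functions used only to show the wiring.

import torch

# Stand-in discriminators returning one logit per sample
D_left  = lambda y: y[:, :2].sum(dim=1)   # marginal disc on the first half of the output
D_right = lambda y: y[:, 2:].sum(dim=1)   # marginal disc on the second half
d_P     = lambda y: y.prod(dim=1)         # real (p) dependency disc
d_Q     = lambda y: y.mean(dim=1)         # fake (q) dependency disc

marginal_sum = lambda y: D_left(y) + D_right(y)
dep_sum      = lambda y: d_P(y) - d_Q(y)
jointD       = lambda y: marginal_sum(y) + dep_sum(y)   # FactorGAN joint discriminator output

fake_sample = torch.randn(5, 4)
print(jointD(fake_sample).shape)   # torch.Size([5]): one logit per generated sample
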
/training/DiscriminatorTraining.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 |
4 | class DiscriminatorSetup(object):
5 | def __init__(self, name, D, optim, real_data, fake_data,
6 | crop_real=lambda x:x, crop_fake=lambda x:x,
7 | criterion=torch.nn.BCEWithLogitsLoss, real_label=1, fake_label=0):
8 | '''
9 |         Discriminator model including optimiser, input sources, how inputs are cropped, and how the discriminator is trained
10 | :param name: Name of discriminator
11 | :param D: Discriminator model
12 | :param optim: Optimiser for D
13 | :param real_data: Real data iterator
14 | :param fake_data: Fake data iterator
15 | :param crop_real: Function that crops real inputs
16 | :param crop_fake: Function that crops fake inputs
17 | :param criterion: Discriminator training loss to use
18 | :param real_label: Real label for training loss
19 | :param fake_label: Fake label for training loss
20 | '''
21 | self.name = name
22 | self.D = D
23 | self.optim = optim
24 | self.real_data = real_data
25 | self.fake_data = fake_data
26 |
27 | self.real_label = real_label
28 | self.fake_label = fake_label
29 |
30 | self.crop_real = crop_real
31 | self.crop_fake = crop_fake
32 |
33 | self.criterion = criterion
34 |
35 | class DependencyDiscriminatorSetup(object):
36 | def __init__(self, name, D, optim, data, shuffle_batch_func, crop_func=lambda x:x,
37 | criterion=torch.nn.BCEWithLogitsLoss, real_label=1, fake_label=0):
38 | '''
39 | Dependency discriminator model including name, optimiser, input data source, how batches are shuffled, and training criterion
40 | :param name: Name of discriminator
41 | :param D: Discriminator model
42 | :param optim: Optimiser
43 | :param data: Joint data source ("real" data) with dependencies
44 | :param shuffle_batch_func: Function that shuffles a batch taken from data to get an independent version
45 | :param crop_func: Function that crops the input before feeding it to the discriminator
46 | :param criterion: Training loss
47 | :param real_label: Real label
48 | :param fake_label: Fake label (shuffled/independent batch)
49 | '''
50 | self.name = name
51 | self.D = D
52 | self.optim = optim
53 | self.data = data
54 | self.real_label = real_label
55 | self.fake_label = fake_label
56 | self.crop_func = crop_func
57 | self.shuffle_batch_func = shuffle_batch_func
58 | self.criterion = criterion
59 |
60 | class DependencyDiscriminatorPair(object):
61 | def __init__(self, real_disc, fake_disc:DependencyDiscriminatorSetup):
62 | '''
63 |         Binds two dependency discriminators (for p and q) together to compute a combined output
64 | :param real_disc: p-dependency discriminator. Can be None in case none is used
65 | :param fake_disc: q-dependency discriminator
66 | '''
67 |         assert(real_disc is None or isinstance(real_disc, DependencyDiscriminatorSetup))
68 | self.real_disc = real_disc
69 | self.fake_disc = fake_disc
70 |
71 | def get_comb_disc(self):
72 | '''
73 | Computes combined output d_P(x) - d_Q(x) that represents the dependency-based part of the generator loss
74 | :return:
75 | '''
76 |         if self.real_disc is not None:
77 | return lambda x : self.real_disc.D(self.real_disc.crop_func(x)) - self.fake_disc.D(self.fake_disc.crop_func(x))
78 | else:
79 | return lambda x : - self.fake_disc.D(self.fake_disc.crop_func(x))
80 |
81 | def get_dep_disc_output(disc_setup:DependencyDiscriminatorSetup, device, backward=False, zero_gradients=False):
82 | '''
83 |     Computes the output of a dependency discriminator (and optionally gradients) for a new input batch, using the batch
84 |     itself as the real input and a shuffled version of it as the fake input
85 | :param disc_setup: Dependency discriminator object
86 | :param device: Device to use
87 | :param backward: Whether to compute gradients
88 | :param zero_gradients: Whether to zero gradients at the beginning
89 | :return: see get_disc_output_batch
90 | '''
91 | real_sample = disc_setup.data.__next__()
92 | real_sample = disc_setup.crop_func(real_sample)
93 |
94 | fake_sample = disc_setup.shuffle_batch_func(real_sample) # Get fake data by simply shuffling current batch
95 |
96 | return get_disc_output_batch(disc_setup.D, real_sample, fake_sample, disc_setup.real_label, disc_setup.fake_label, disc_setup.criterion, device, backward, zero_gradients)
97 |
98 | def get_marginal_disc_output(disc_setup:DiscriminatorSetup, device, backward=False, zero_gradients=False):
99 | '''
100 | Computes output of a marginal discriminator
101 | :param disc_setup: Marginal discriminator object
102 | :param device: Device to use
103 | :param backward: Whether to compute gradients
104 | :param zero_gradients: Whether to zero gradients at the beginning
105 | :return: see get_disc_output_batch
106 | '''
107 | real_sample = disc_setup.real_data.__next__().to(device)
108 | real_sample = disc_setup.crop_real(real_sample)
109 |
110 | fake_sample = disc_setup.fake_data.__next__().to(device)
111 | fake_sample = disc_setup.crop_fake(fake_sample)
112 |
113 | return get_disc_output_batch(disc_setup.D, real_sample, fake_sample, disc_setup.real_label, disc_setup.fake_label, disc_setup.criterion, device, backward, zero_gradients)
114 |
115 | def get_disc_output_batch(D, real_sample, fake_sample, real_label, fake_label, criterion, device, backward, zero_gradients):
116 | '''
117 | Compute loss, output and optionally gradients for a discriminator model with a given training loss
118 | :param D: Discriminator model
119 | :param real_sample: Batch of real samples
120 | :param fake_sample: Batch of fake samples
121 | :param real_label: Target label for real batch
122 | :param fake_label: Target label for fake batch
123 | :param criterion: Training loss
124 | :param device: Device to use
125 | :param backward: Whether to compute gradients
126 | :param zero_gradients: Whether to zero gradients at the beginning
127 | :return: Average of real and fake training loss, discriminator accuracy, discriminator outputs for real and fake
128 | '''
129 | # Never backpropagate through disc input in this function
130 | # Transfer inputs to correct device
131 | real_sample = real_sample.detach().to(device)
132 | fake_sample = fake_sample.detach().to(device)
133 |
134 | if zero_gradients:
135 | D.zero_grad()
136 |
137 | # Get real sample output
138 | real_batch_size = real_sample.size()[0]
139 |
140 | label = torch.full((real_batch_size,), real_label, device=device)
141 | real_output = D(real_sample)
142 |
143 | errD_real = criterion()(real_output, label)
144 | if backward:
145 | errD_real.backward()
146 |
147 | # Get fake sample output
148 | fake_batch_size = fake_sample.size()[0]
149 | label = torch.full((fake_batch_size,), fake_label, device=device)
150 | fake_output = D(fake_sample)
151 |
152 | errD_fake = criterion()(fake_output, label)
153 | if backward:
154 | errD_fake.backward()
155 |
156 | # Accuracy
157 | correct = 0.5 * (real_output > 0.0).sum().item() / real_batch_size + 0.5 * (fake_output < 0.0).sum().item() / fake_batch_size
158 | return 0.5*errD_real + 0.5*errD_fake, correct, real_output, fake_output
--------------------------------------------------------------------------------
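A minimal sketch (not part of the repository) of wiring a marginal DiscriminatorSetup and running a single update with get_marginal_disc_output. The tiny discriminator and the generator-based "infinite data iterators" are stand-ins for the repository's real models and InfiniteDataSampler wrappers.

import torch
import torch.nn as nn

from training.DiscriminatorTraining import DiscriminatorSetup, get_marginal_disc_output

def infinite_batches(scale):
    # Yields random batches forever, mimicking an infinite data sampler
    while True:
        yield scale * torch.rand(16, 8)

# Tiny discriminator producing one logit per sample
D = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1), nn.Flatten(start_dim=0))
optim = torch.optim.Adam(D.parameters(), lr=1e-4)

setup = DiscriminatorSetup("toy_disc", D, optim,
                           real_data=infinite_batches(1.0),
                           fake_data=infinite_batches(0.5))

errD, acc, _, _ = get_marginal_disc_output(setup, torch.device("cpu"),
                                           backward=True, zero_gradients=True)
setup.optim.step()
print(float(errD), acc)
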
/training/MNIST.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import os
6 | from torchvision import datasets, transforms
7 |
8 | from models.MNISTClassifier import MNISTModel
9 |
10 |
11 | def train(model, device, train_loader, optimizer, epoch):
12 | '''
13 | Train MNIST model for one epoch
14 | :param model: MNIST model
15 | :param device: Device to use
16 | :param train_loader: Training dataset loader
17 | :param optimizer: Optimiser to use
18 | :param epoch: Current epoch index
19 | '''
20 | model.train()
21 | for batch_idx, (data, target) in enumerate(train_loader):
22 | data, target = data.to(device), target.to(device)
23 | optimizer.zero_grad()
24 | output = model(data)
25 | loss = F.nll_loss(output, target)
26 | loss.backward()
27 | optimizer.step()
28 | if batch_idx % 10 == 0:
29 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
30 | epoch, batch_idx * len(data), len(train_loader.dataset),
31 | 100. * batch_idx / len(train_loader), loss.item()))
32 |
33 | def test(model, device, test_loader):
34 | '''
35 | Test MNIST classifier, prints out results into standard output
36 | :param model: Classifier model
37 | :param device: Device to use
38 | :param test_loader: Test dataset loader
39 | '''
40 | model.eval()
41 | test_loss = 0
42 | correct = 0
43 | with torch.no_grad():
44 | for data, target in test_loader:
45 | data, target = data.to(device), target.to(device)
46 | output = model(data)
47 |                 test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss (size_average is deprecated)
48 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
49 | correct += pred.eq(target.view_as(pred)).sum().item()
50 |
51 | test_loss /= len(test_loader.dataset)
52 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
53 | test_loss, correct, len(test_loader.dataset),
54 | 100. * correct / len(test_loader.dataset)))
55 |
56 | def main(opt):
57 | use_cuda = opt.cuda and torch.cuda.is_available()
58 |
59 | torch.manual_seed(opt.seed)
60 |
61 | device = torch.device("cuda" if use_cuda else "cpu")
62 |
63 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
64 | train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,
65 | transform=transforms.Compose([transforms.ToTensor()])),
66 | batch_size=opt.batchSize, shuffle=True, **kwargs)
67 | test_loader = torch.utils.data.DataLoader(
68 | datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor()])),
69 | batch_size=opt.batchSize, shuffle=True, **kwargs)
70 |
71 |
72 | model = MNISTModel().to(device)
73 | optimizer = optim.Adam(model.parameters(), lr=1e-4)
74 |
75 |     # If a trained classifier was saved previously, load it instead of retraining
76 | MODEL_NAME = "mnist_classifier_model"
77 | if os.path.exists(MODEL_NAME):
78 | print("Found pre-trained MNIST classifier, loading from " + MODEL_NAME)
79 | model.load_state_dict(torch.load(MODEL_NAME))
80 | return model
81 |
82 | # Train model for a certain number of epochs
83 | NUM_EPOCHS = 4
84 | print("TRAINING MNIST CLASSIFIER")
85 | for epoch in range(NUM_EPOCHS):
86 | train(model, device, train_loader, optimizer, epoch)
87 | test(model, device, test_loader)
88 | torch.save(model.state_dict(), MODEL_NAME)
89 | return model
90 |
91 | if __name__ == '__main__':
92 |     # main() needs the training options; parse them with the shared parser
93 |     from training.TrainingOptions import get_parser
94 |     main(get_parser().parse_args())
--------------------------------------------------------------------------------
/training/TrainingOptions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 |
4 | def get_parser():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--experiment_name', type=str, default=str(np.random.randint(0, 100000)), help='experiment name')
7 | parser.add_argument('--out_path', type=str, default="out", help="Output path")
8 |
9 | parser.add_argument('--eval', action='store_true', help='Perform evaluation instead of training')
10 | parser.add_argument('--eval_model', type=str, default='G', help='Name of generator checkpoint file to load for evaluation')
11 |
12 | parser.add_argument('--epochs', type=int, default=40, help='number of epochs to train for')
13 | parser.add_argument('--epoch_iter', type=int, default=5000, help="Number of generator updates per epoch")
14 |
15 |     parser.add_argument('--lr', type=float, default=1e-4, help='learning rate, default=1e-4')
16 |     parser.add_argument('--L2', type=float, default=0.0, help='L2 regularisation for discriminators (except the joint p dependency discriminator)')
17 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
18 |
19 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
20 | parser.add_argument('--seed', type=int, default=1337, help='manual seed')
21 |
22 | # Generator settings
23 | parser.add_argument('--nz', type=int, default=50, help='size of the latent z vector')
24 |     parser.add_argument('--factorGAN', type=int, default=0, help="1 to use FactorGAN instead of a normal GAN, 0 for a normal GAN")
25 |
26 |     parser.add_argument('--disc_iter', type=int, default=2, help="Number of discriminator iterations per generator update")
27 | parser.add_argument('--objective', type=str, default="JSD", help="JSD or KL as generator objective")
28 |
29 | # Dependency settings
30 | parser.add_argument('--lipschitz_q', type=int, default=1, help="Spectral norm regularisation for fake dependency discriminators")
31 | parser.add_argument('--lipschitz_p', type=int, default=1, help="Spectral norm regularisation for real dependency discriminators")
32 | parser.add_argument('--use_real_dep_disc', type=int, default=1, help="1 to use the dependency discriminator on real data normally, 0 to not use it and set the output to zero, assuming our real data dimensions are independent")
33 |
34 | # Data loading settings
35 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=1)
36 | parser.add_argument('--batchSize', type=int, default=25, help='input batch size')
37 |
38 | return parser
--------------------------------------------------------------------------------
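A minimal sketch (not part of the repository) of how an experiment script can take the shared parser, add its own task-specific flags, and parse a configuration. The extra flag added here is purely illustrative and not necessarily one the real scripts define.

from training.TrainingOptions import get_parser

parser = get_parser()
# Hypothetical task-specific flag, added on top of the shared options
parser.add_argument('--generator_activation', type=str, default="sigmoid",
                    help='Output activation of the generator (illustrative)')

opt = parser.parse_args(["--factorGAN", "1", "--cuda"])
print(opt.factorGAN, opt.nz, opt.batchSize, opt.cuda)   # 1 50 25 True
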
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f90/FactorGAN/ae57301195984092ee40742273e1034f3ae27e32/training/__init__.py
--------------------------------------------------------------------------------