├── .gitignore
├── README.md
├── btp_dataset.py
├── delta_trial.txt
├── finetune_model.py
├── generate_dataset.py
├── main.py
├── main_cgan.py
├── models
├── __init__.py
├── convolutional_models.py
└── recurrent_models.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | net*
3 | log/
4 | Dataset*
5 | clean_last_experiment.py
6 | checkpoints/
7 | images/
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Financial time series generation using GANs
2 | This repository contains the implementation of a GAN-based method for real-valued financial time series generation. See for instance [Real-valued (Medical) Time Series Generation with Recurrent Conditional GANs](https://arxiv.org/abs/1706.02633).
3 |
4 |
5 |
6 | Main features:
 7 | - [Causal Convolution](https://arxiv.org/abs/1609.03499) or LSTM architectures for both the discriminator and the generator
 8 | - Non-saturating GAN training (see this [tutorial](https://arxiv.org/abs/1701.00160) for more info); a minimal loss sketch follows this list
9 | - Generation can be unconditioned or conditioned on the difference between the last and the first element of the time series to be generated (i.e., a daily delta)
10 | - Conditional training done via supervised learning on the generator, either by alternating optimization steps or by combining the adversarial and supervised losses
11 |
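To unpack the non-saturating training mentioned above: instead of minimizing log(1 - D(G(z))), whose gradient vanishes once the discriminator confidently rejects fakes, the generator maximizes log D(G(z)). In practice this just means training the generator with "real" labels on fake samples, which is what `main.py` does. A minimal runnable sketch (the `nn.Linear` networks below are stand-ins, not the project's models):

```python
import torch
import torch.nn as nn

# Stand-in generator/discriminator just to make the snippet runnable;
# main.py uses the LSTM / causal convolution models instead.
netG = nn.Sequential(nn.Linear(100, 1), nn.Tanh())
netD = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
criterion = nn.BCELoss()

noise = torch.randn(16, 96, 100)   # (batch_size, seq_len, noise_dim)
fake = netG(noise)                 # (batch_size, seq_len, 1)
output = netD(fake)                # per-timestep probabilities

# Non-saturating trick: label the fakes as real and minimize -log(D(G(z)))
label = torch.ones_like(output)
errG = criterion(output, label)
errG.backward()
```
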
12 | During conditional training, daily deltas that are given as additional input to the generator are sampled from a Gaussian distribution estimated from real data via maximum likelihood.
13 |
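Since the maximum likelihood fit of a Gaussian reduces to taking the sample mean and standard deviation of the training deltas, the whole procedure is a few lines. A minimal sketch with a random stand-in tensor in place of the real normalized data (`BtpDataset.sample_deltas` performs the same sampling, and `main.py` does the concatenation):

```python
import torch

# Stand-in for the normalized dataset tensor of shape (n_sequences, seq_len, 1)
data = torch.rand(173, 96, 1) * 2 - 1

# Delta of each sequence: last value minus first value
deltas = data[:, -1] - data[:, 0]

# Gaussian ML estimates (torch's .std() applies Bessel's correction, negligible here)
mu, sigma = deltas.mean(), deltas.std()

# Sample new deltas and append one to the noise at every timestep,
# so the generator input has shape (batch, seq_len, noise_dim + 1)
sampled = torch.randn(32, 1) * sigma + mu
noise = torch.randn(32, 96, 100)
cond_input = torch.cat((noise, sampled.unsqueeze(2).repeat(1, 96, 1)), dim=2)
```
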
14 | ## Some words on the dataset
15 | The time series values are taken from the **btp_price** feature of the original data, which is provided in CSV format.
16 | Minimal preprocessing, including normalization in the range [-1,1], is done inside `btp_dataset.py`. The resulting dataset has 173 sequences of length 96, for an overall tensor shape of (173 x 96 x 1).
17 | If you use a dataset that is not compatible with this preprocessing, you can just write your own loader, for example:
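A minimal sketch of such a loader, assuming a hypothetical `.npy` file that holds an already normalized `(n_sequences, seq_len, 1)` array (for delta conditioning, the training scripts also expect the dataset to expose `sample_deltas` and `denormalize`, as `BtpDataset` does):

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class NpyDataset(Dataset):
    """Loads a (n_sequences, seq_len, 1) array already scaled to [-1, 1]."""
    def __init__(self, npy_path):
        self.data = torch.from_numpy(np.load(npy_path)).float()
        self.seq_len = self.data.size(1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
```
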
18 |
19 | ## Project structure
20 | The files and directories composing the project are:
21 | - `main.py`: runs the training. It can save model checkpoints and images of generated time series, and provides visualizations (losses, gradients) via tensorboard. Run `python main.py -h` to see all the options.
22 | - `generate_dataset.py`: generates a fake dataset using a trained generator. The path of the generator checkpoint and of the output \*.npy file for the dataset must be passed as options. Optionally, the path of a file containing daily deltas (one per line) for conditioning the time series generation can be provided.
23 | - `finetune_model.py`: uses purely supervised training for finetuning a trained generator. *Discouraged*: it is generally better to train jointly with the supervised and adversarial objectives.
24 | - `models/`: directory containing the model architecture for both discriminator and generator.
25 | - `utils.py`: contains some utility functions. It also contains a `DatasetGenerator` class that is used for fake dataset generation.
26 | - `main_cgan.py`: runs training with standard conditional GANs. It does not produce good results, but it is kept for reference.
27 |
28 | By default, during training, model weights are saved into the `checkpoints/` directory, snapshots of generated series into `images/` and tensorboard logs into `log/`.
29 |
30 | Use:
31 | ```
32 | tensorboard --logdir log
33 | ```
34 | from inside the project directory to run tensorboard on the default port (6006).
35 |
36 | ## Examples
37 | Run training with a recurrent generator and a convolutional discriminator, conditioning the generator on deltas and alternating adversarial and supervised optimization:
38 | ```
39 | python main.py --dataset_path some_dataset.csv --delta_condition --gen_type lstm --dis_type cnn --alternate --run_tag cnn_dis_lstm_gen_alternate_my_first_trial
40 | ```
41 |
42 | Generate fake dataset `prova.npy` using deltas contained in `delta_trial.txt` and model trained for 70 epochs:
43 | ```
44 | python generate_dataset.py --delta_path delta_trial.txt --checkpoint_path checkpoints/cnn_conditioned_alternate1_netG_epoch_70.pth --output_path prova.npy
45 | ```
46 | Finetune a generator checkpoint with supervised training:
47 | ```
48 | python finetune_model.py --checkpoint checkpoints/cnn_dis_lstm_gen_noalt_new_netG_epoch_39.pth --output_path finetuned.pth
49 | ```
50 |
51 | ## Insights and directions for improvement
52 | - As reported in several works on sequence generation with GANs, recurrent discriminators are usually less stable than convolutional ones; I therefore recommend the convolution-based discriminator.
53 | - I did not perform an extensive search over hyperparameters and training procedures, since qualitative evaluation was the only kind readily available. If a target task is defined (e.g., learning a policy), quantitative evaluations become possible and can be used to select the best model.
54 | - There is a tradeoff between realistic generation and the error with respect to the input delta. If mild imprecision on the delta is acceptable for the final task, its error can be ignored; to reduce the delta error as much as possible, either weight the supervised objective more heavily or use supervised finetuning.
55 | - Training is sometimes prone to mode collapse: the current implementation could benefit from recent GAN variations such as [Wasserstein GANs](https://arxiv.org/abs/1704.00028); see the sketch after this list. It would be sufficient to change only the adversarial part of the training.
56 | - The [standard way](https://arxiv.org/abs/1411.1784) of injecting conditions into GANs does not work out of the box for delta conditioning: as I observed in some independent experiments, causal convolutional networks can easily recover the delta of a given sequence. A discriminator that receives the delta as input can therefore trivially distinguish real sequences with correct deltas from fake sequences with incorrect deltas.
57 |
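The following is not part of the repository, just a minimal sketch of the WGAN-GP critic loss that would replace the BCE-based adversarial part, under the assumptions that the discriminator is turned into a critic by dropping its final sigmoid and that the standard penalty weight of 10 is used:

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake, gp_lambda=10.0):
    """WGAN-GP penalty, computed on random interpolates of real and fake batches."""
    alpha = torch.rand(real.size(0), 1, 1)    # one mixing weight per sequence
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    grads, = torch.autograd.grad(critic(interp).sum(), interp, create_graph=True)
    slopes = grads.view(grads.size(0), -1).norm(2, dim=1)
    return gp_lambda * ((slopes - 1) ** 2).mean()

# Stand-in critic: like the discriminator, but outputting raw (unbounded) scores
critic = nn.Linear(1, 1)
real, fake = torch.randn(16, 96, 1), torch.randn(16, 96, 1)

# Critic update: minimize E[critic(fake)] - E[critic(real)] + penalty
errD = critic(fake).mean() - critic(real).mean() + gradient_penalty(critic, real, fake)
# The generator update would minimize -critic(netG(noise)).mean() instead of the BCE loss
```
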
--------------------------------------------------------------------------------
/btp_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import pandas as pd
4 | import numpy as np
5 |
6 | class BtpDataset(Dataset):
7 | """Btp time series dataset."""
8 | def __init__(self, csv_file, normalize=True):
9 | """
10 | Args:
11 | csv_file (string): path to csv file
12 | normalize (bool): whether to normalize the data in [-1,1]
13 | """
14 | df = pd.read_csv(csv_file, sep=";")
15 | df['Timestamp'] = pd.to_datetime(df["data_column"].map(str) + " " + df["orario_column"], dayfirst=True)
16 | df = df.drop(['data_column', 'orario_column'], axis=1).set_index("Timestamp")
17 | btp_price = df.BTP_Price
18 | data = torch.from_numpy(np.expand_dims(np.array([group[1] for group in btp_price.groupby(df.index.date)]), -1)).float()
19 | self.data = self.normalize(data) if normalize else data
20 | self.seq_len = data.size(1)
21 |
 22 |         #Estimate the Gaussian parameters of the deltas, in both the original and the normalized range
23 | original_deltas = data[:, -1] - data[:, 0]
24 | self.original_deltas = original_deltas
25 | self.or_delta_max, self.or_delta_min = original_deltas.max(), original_deltas.min()
26 | deltas = self.data[:, -1] - self.data[:, 0]
27 | self.deltas = deltas
28 | self.delta_mean, self.delta_std = deltas.mean(), deltas.std()
29 | self.delta_max, self.delta_min = deltas.max(), deltas.min()
30 |
31 | def __len__(self):
32 | return len(self.data)
33 |
34 | def __getitem__(self, idx):
35 | return self.data[idx]
36 |
37 | def normalize(self, x):
 38 |         """Normalize input in [-1,1] range, saving statistics for denormalization"""
39 | self.max = x.max()
40 | self.min = x.min()
41 | return (2 * (x - x.min())/(x.max() - x.min()) - 1)
42 |
43 | def denormalize(self, x):
44 | """Revert [-1,1] normalization"""
45 | if not hasattr(self, 'max') or not hasattr(self, 'min'):
46 | raise Exception("You are calling denormalize, but the input was not normalized")
47 | return 0.5 * (x*self.max - x*self.min + self.max + self.min)
48 |
49 | def sample_deltas(self, number):
50 | """Sample a vector of (number) deltas from the fitted Gaussian"""
 51 |         return torch.randn(number, 1) * self.delta_std + self.delta_mean
52 |
53 | def normalize_deltas(self, x):
54 | return ((self.delta_max - self.delta_min) * (x - self.or_delta_min)/(self.or_delta_max - self.or_delta_min) + self.delta_min)
55 |
56 |
--------------------------------------------------------------------------------
/delta_trial.txt:
--------------------------------------------------------------------------------
1 | 0.1
2 | 0.2
3 | 0.3
4 | 0.05
5 |
--------------------------------------------------------------------------------
/finetune_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import nn, optim
4 | from utils import DatasetGenerator
5 | from btp_dataset import BtpDataset
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--checkpoint_path', required=True, help='Path of the generator checkpoint')
9 | parser.add_argument('--output_path', required=True, help='Path of the output .pth checkpoint')
10 | parser.add_argument('--dataset_path', default='DatasetDVA_2018-03-13_cleaned.csv', help="Path of the dataset for normalization")
11 | parser.add_argument('--batches', type=int, default=50, help="Number of batches to use for finetuning")
12 | parser.add_argument('--batch_size', type=int, default=32)
13 | parser.add_argument('--learning_rate', type=float, default=2e-4)
14 |
15 | opt = parser.parse_args()
16 |
17 | #Load the dataset to obtain normalization statistics and the fitted delta distribution
18 | dataset = BtpDataset(opt.dataset_path)
19 |
20 | model = torch.load(opt.checkpoint_path)
21 |
22 | #"Validation" deltas
23 | val_size = 1000
24 | fixed_noise = torch.randn(val_size, dataset.seq_len, 100)
25 | fixed_deltas = dataset.sample_deltas(val_size).unsqueeze(2).repeat(1, dataset.seq_len, 1)
26 | fixed_noise = torch.cat((fixed_noise, fixed_deltas), dim=2)
27 |
28 | delta_criterion = nn.MSELoss()
29 |
30 | with torch.no_grad():
31 | out_seqs = model(fixed_noise)
32 | delta_loss = delta_criterion(out_seqs[:, -1] - out_seqs[:, 0], fixed_deltas[:,0])
33 | print("Initial error on deltas:", delta_loss.item())
34 |
35 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
36 |
37 | for i in range(opt.batches):
38 | optimizer.zero_grad()
39 | noise = torch.randn(opt.batch_size, dataset.seq_len, 100)
40 | deltas = dataset.sample_deltas(opt.batch_size).unsqueeze(2).repeat(1, dataset.seq_len, 1)
41 | noise = torch.cat((noise, deltas), dim=2)
42 |     #Generate sequences from the noise concatenated with the deltas
43 | out_seqs = model(noise)
44 | delta_loss = delta_criterion(out_seqs[:, -1] - out_seqs[:, 0], deltas[:,0])
45 | delta_loss.backward()
46 | print("\rBatch", i, "Loss:", delta_loss.item(), end="")
47 | optimizer.step()
48 |
49 | with torch.no_grad():
50 | out_seqs = model(fixed_noise)
51 | delta_loss = delta_criterion(out_seqs[:, -1] - out_seqs[:, 0], fixed_deltas[:,0])
52 | print()
53 | print("Final error on deltas:", delta_loss.item())
54 | torch.save(model, opt.output_path)
55 |
--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from utils import DatasetGenerator
4 | from btp_dataset import BtpDataset
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--checkpoint_path', required=True, help='Path of the generator checkpoint')
8 | parser.add_argument('--output_path', required=True, help='Path of the output .npy file')
9 | parser.add_argument('--delta_path', default='', help='Path of the file containing the list of deltas for conditional generation')
10 | parser.add_argument('--dataset', default="btp", help='dataset to use for normalization (only btp for now)')
11 | parser.add_argument('--dataset_path', required=True, help="Path of the dataset for normalization")
12 | parser.add_argument('--size', type=int, default=1000, help='Size of the dataset to generate in case of unconditional generation')
13 | opt = parser.parse_args()
14 |
15 | #If an unknown option is provided for the dataset, then don't use any normalization
16 | dataset = BtpDataset(opt.dataset_path) if opt.dataset == 'btp' else None
17 |
18 | model = torch.load(opt.checkpoint_path)
19 | generator = DatasetGenerator(generator=model, dataset=dataset) #Using default params
20 |
21 | if opt.delta_path != '':
22 | delta_list = [float(line) for line in open(opt.delta_path)]
23 | else:
24 | delta_list = None
25 |
26 | #Size is ignored if delta_list is not None: it is inferred as the length of the list of deltas
27 | generator.generate_dataset(outfile=opt.output_path, delta_list=delta_list, size=opt.size)
28 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim as optim
8 | import torch.utils.data
9 | import torchvision
10 | import datetime
11 | from btp_dataset import BtpDataset
12 | from utils import time_series_to_plot
13 | from tensorboardX import SummaryWriter
14 | from models.recurrent_models import LSTMGenerator, LSTMDiscriminator
15 | from models.convolutional_models import CausalConvGenerator, CausalConvDiscriminator
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--dataset', default="btp", help='dataset to use (only btp for now)')
19 | parser.add_argument('--dataset_path', required=True, help='path to dataset')
20 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
21 | parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
22 | parser.add_argument('--nz', type=int, default=100, help='dimensionality of the latent vector z')
23 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train for')
24 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
25 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
26 | parser.add_argument('--netG', default='', help="path to netG (to continue training)")
27 | parser.add_argument('--netD', default='', help="path to netD (to continue training)")
28 | parser.add_argument('--outf', default='checkpoints', help='folder to save checkpoints')
29 | parser.add_argument('--imf', default='images', help='folder to save images')
30 | parser.add_argument('--manualSeed', type=int, help='manual seed')
31 | parser.add_argument('--logdir', default='log', help='logdir for tensorboard')
32 | parser.add_argument('--run_tag', default='', help='tags for the current run')
 33 | parser.add_argument('--checkpoint_every', type=int, default=5, help='number of epochs between checkpoint saves')
 34 | parser.add_argument('--tensorboard_image_every', type=int, default=5, help='interval in epochs for displaying images on tensorboard')
35 | parser.add_argument('--delta_condition', action='store_true', help='whether to use the mse loss for deltas')
36 | parser.add_argument('--delta_lambda', type=int, default=10, help='weight for the delta condition')
37 | parser.add_argument('--alternate', action='store_true', help='whether to alternate between adversarial and mse loss in generator')
 38 | parser.add_argument('--dis_type', default='cnn', choices=['cnn','lstm'], help='architecture to use for the discriminator')
 39 | parser.add_argument('--gen_type', default='lstm', choices=['cnn','lstm'], help='architecture to use for the generator')
40 | opt = parser.parse_args()
41 |
42 | #Create writer for tensorboard
43 | date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
44 | run_name = f"{opt.run_tag}_{date}" if opt.run_tag != '' else date
45 | log_dir_name = os.path.join(opt.logdir, run_name)
46 | writer = SummaryWriter(log_dir_name)
47 | writer.add_text('Options', str(opt), 0)
48 | print(opt)
49 |
50 | try:
51 | os.makedirs(opt.outf)
52 | except OSError:
53 | pass
54 | try:
55 | os.makedirs(opt.imf)
56 | except OSError:
57 | pass
58 |
59 | if opt.manualSeed is None:
60 | opt.manualSeed = random.randint(1, 10000)
61 | print("Random Seed: ", opt.manualSeed)
62 | random.seed(opt.manualSeed)
63 | torch.manual_seed(opt.manualSeed)
64 |
65 | cudnn.benchmark = True
66 |
67 | if torch.cuda.is_available() and not opt.cuda:
 68 |     print("You have a cuda device, so you might want to run with the --cuda option")
69 |
70 | if opt.dataset == "btp":
71 | dataset = BtpDataset(opt.dataset_path)
72 | assert dataset
73 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
74 | shuffle=True, num_workers=int(opt.workers))
75 |
76 | device = torch.device("cuda:0" if opt.cuda else "cpu")
77 | nz = int(opt.nz)
78 | #Retrieve the sequence length as first dimension of a sequence in the dataset
79 | seq_len = dataset[0].size(0)
80 | #An additional input is needed for the delta
81 | in_dim = opt.nz + 1 if opt.delta_condition else opt.nz
82 |
83 | if opt.dis_type == "lstm":
84 | netD = LSTMDiscriminator(in_dim=1, hidden_dim=256).to(device)
85 | if opt.dis_type == "cnn":
86 | netD = CausalConvDiscriminator(input_size=1, n_layers=8, n_channel=10, kernel_size=8, dropout=0).to(device)
87 | if opt.gen_type == "lstm":
88 | netG = LSTMGenerator(in_dim=in_dim, out_dim=1, hidden_dim=256).to(device)
89 | if opt.gen_type == "cnn":
90 | netG = CausalConvGenerator(noise_size=in_dim, output_size=1, n_layers=8, n_channel=10, kernel_size=8, dropout=0.2).to(device)
91 |
92 | assert netG
93 | assert netD
94 |
95 | if opt.netG != '':
96 | netG.load_state_dict(torch.load(opt.netG))
97 | if opt.netD != '':
98 | netD.load_state_dict(torch.load(opt.netD))
99 |
100 | print("|Discriminator Architecture|\n", netD)
101 | print("|Generator Architecture|\n", netG)
102 |
103 | criterion = nn.BCELoss().to(device)
104 | delta_criterion = nn.MSELoss().to(device)
105 |
106 | #Generate fixed noise to be used for visualization
107 | fixed_noise = torch.randn(opt.batchSize, seq_len, nz, device=device)
108 |
109 | if opt.delta_condition:
110 | #Sample both deltas and noise for visualization
111 | deltas = dataset.sample_deltas(opt.batchSize).unsqueeze(2).repeat(1, seq_len, 1)
112 | fixed_noise = torch.cat((fixed_noise, deltas), dim=2)
113 |
114 | real_label = 1
115 | fake_label = 0
116 |
117 | # setup optimizer
118 | optimizerD = optim.Adam(netD.parameters(), lr=opt.lr)
119 | optimizerG = optim.Adam(netG.parameters(), lr=opt.lr)
120 |
121 | for epoch in range(opt.epochs):
122 | for i, data in enumerate(dataloader, 0):
123 | niter = epoch * len(dataloader) + i
124 |
125 | #Save just first batch of real data for displaying
126 | if i == 0:
127 | real_display = data.cpu()
128 |
129 | ############################
130 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
131 | ###########################
132 |
133 | #Train with real data
134 | netD.zero_grad()
135 | real = data.to(device)
136 | batch_size, seq_len = real.size(0), real.size(1)
137 | label = torch.full((batch_size, seq_len, 1), real_label, device=device)
138 |
139 | output = netD(real)
140 | errD_real = criterion(output, label)
141 | errD_real.backward()
142 | D_x = output.mean().item()
143 |
144 | #Train with fake data
145 | noise = torch.randn(batch_size, seq_len, nz, device=device)
146 | if opt.delta_condition:
147 | #Sample a delta for each batch and concatenate to the noise for each timestep
148 | deltas = dataset.sample_deltas(batch_size).unsqueeze(2).repeat(1, seq_len, 1)
149 | noise = torch.cat((noise, deltas), dim=2)
150 | fake = netG(noise)
151 | label.fill_(fake_label)
152 | output = netD(fake.detach())
153 | errD_fake = criterion(output, label)
154 | errD_fake.backward()
155 | D_G_z1 = output.mean().item()
156 | errD = errD_real + errD_fake
157 | optimizerD.step()
158 |
159 | #Visualize discriminator gradients
160 | for name, param in netD.named_parameters():
161 | writer.add_histogram("DiscriminatorGradients/{}".format(name), param.grad, niter)
162 |
163 | ############################
164 | # (2) Update G network: maximize log(D(G(z)))
165 | ###########################
166 | netG.zero_grad()
167 | label.fill_(real_label)
168 | output = netD(fake)
169 | errG = criterion(output, label)
170 | errG.backward()
171 | D_G_z2 = output.mean().item()
172 |
173 |
174 | if opt.delta_condition:
175 | #If option is passed, alternate between the losses instead of using their sum
176 | if opt.alternate:
177 | optimizerG.step()
178 | netG.zero_grad()
179 | noise = torch.randn(batch_size, seq_len, nz, device=device)
180 | deltas = dataset.sample_deltas(batch_size).unsqueeze(2).repeat(1, seq_len, 1)
181 | noise = torch.cat((noise, deltas), dim=2)
182 |                 #Generate sequences from the noise concatenated with the deltas
183 | out_seqs = netG(noise)
184 | delta_loss = opt.delta_lambda * delta_criterion(out_seqs[:, -1] - out_seqs[:, 0], deltas[:,0])
185 | delta_loss.backward()
186 |
187 | optimizerG.step()
188 |
189 | #Visualize generator gradients
190 | for name, param in netG.named_parameters():
191 | writer.add_histogram("GeneratorGradients/{}".format(name), param.grad, niter)
192 |
196 |
197 | #Report metrics
198 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
199 | % (epoch, opt.epochs, i, len(dataloader),
200 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2), end='')
201 | if opt.delta_condition:
202 | writer.add_scalar('MSE of deltas of generated sequences', delta_loss.item(), niter)
203 | print(' DeltaMSE: %.4f' % (delta_loss.item()/opt.delta_lambda), end='')
204 | print()
205 | writer.add_scalar('DiscriminatorLoss', errD.item(), niter)
206 | writer.add_scalar('GeneratorLoss', errG.item(), niter)
207 | writer.add_scalar('D of X', D_x, niter)
208 | writer.add_scalar('D of G of z', D_G_z1, niter)
209 |
210 | ##### End of the epoch #####
211 | real_plot = time_series_to_plot(dataset.denormalize(real_display))
212 | if (epoch % opt.tensorboard_image_every == 0) or (epoch == (opt.epochs - 1)):
213 | writer.add_image("Real", real_plot, epoch)
214 |
215 | fake = netG(fixed_noise)
216 | fake_plot = time_series_to_plot(dataset.denormalize(fake))
217 | torchvision.utils.save_image(fake_plot, os.path.join(opt.imf, opt.run_tag+'_epoch'+str(epoch)+'.jpg'))
218 | if (epoch % opt.tensorboard_image_every == 0) or (epoch == (opt.epochs - 1)):
219 | writer.add_image("Fake", fake_plot, epoch)
220 |
221 | # Checkpoint
222 | if (epoch % opt.checkpoint_every == 0) or (epoch == (opt.epochs - 1)):
223 | torch.save(netG, '%s/%s_netG_epoch_%d.pth' % (opt.outf, opt.run_tag, epoch))
224 | torch.save(netD, '%s/%s_netD_epoch_%d.pth' % (opt.outf, opt.run_tag, epoch))
225 |
--------------------------------------------------------------------------------
/main_cgan.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim as optim
8 | import torch.utils.data
9 | import torchvision
10 | import datetime
11 | from btp_dataset import BtpDataset
12 | from utils import time_series_to_plot
13 | from tensorboardX import SummaryWriter
14 | from models.recurrent_models import LSTMGenerator, LSTMDiscriminator
15 | from models.convolutional_models import CausalConvGenerator, CausalConvDiscriminator
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--dataset', default="btp", help='dataset to use (only btp for now)')
 19 | parser.add_argument('--dataset_path', required=True, help='path to dataset')
20 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
21 | parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
22 | parser.add_argument('--nz', type=int, default=100, help='dimensionality of the latent vector z')
23 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train for')
24 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
25 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
26 | parser.add_argument('--netG', default='', help="path to netG (to continue training)")
27 | parser.add_argument('--netD', default='', help="path to netD (to continue training)")
28 | parser.add_argument('--outf', default='checkpoints', help='folder to save checkpoints')
29 | parser.add_argument('--imf', default='images', help='folder to save images')
30 | parser.add_argument('--manualSeed', type=int, help='manual seed')
31 | parser.add_argument('--logdir', default='log', help='logdir for tensorboard')
32 | parser.add_argument('--run_tag', default='', help='tags for the current run')
 33 | parser.add_argument('--checkpoint_every', type=int, default=5, help='number of epochs between checkpoint saves')
 34 | parser.add_argument('--tensorboard_image_every', type=int, default=5, help='interval in epochs for displaying images on tensorboard')
 35 | parser.add_argument('--dis_type', default='cnn', choices=['cnn','lstm'], help='architecture to use for the discriminator')
 36 | parser.add_argument('--gen_type', default='lstm', choices=['cnn','lstm'], help='architecture to use for the generator')
37 | opt = parser.parse_args()
38 |
39 | #Create writer for tensorboard
40 | date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
41 | run_name = f"{opt.run_tag}_{date}" if opt.run_tag != '' else date
42 | log_dir_name = os.path.join(opt.logdir, run_name)
43 | writer = SummaryWriter(log_dir_name)
44 | writer.add_text('Options', str(opt), 0)
45 | print(opt)
46 |
47 | try:
48 | os.makedirs(opt.outf)
49 | except OSError:
50 | pass
51 | try:
52 | os.makedirs(opt.imf)
53 | except OSError:
54 | pass
55 |
56 | if opt.manualSeed is None:
57 | opt.manualSeed = random.randint(1, 10000)
58 | print("Random Seed: ", opt.manualSeed)
59 | random.seed(opt.manualSeed)
60 | torch.manual_seed(opt.manualSeed)
61 |
62 | cudnn.benchmark = True
63 |
64 | if torch.cuda.is_available() and not opt.cuda:
 65 |     print("You have a cuda device, so you might want to run with the --cuda option")
66 |
67 | if opt.dataset == "btp":
68 | dataset = BtpDataset(opt.dataset_path)
69 | assert dataset
70 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
71 | shuffle=True, num_workers=int(opt.workers))
72 |
73 | device = torch.device("cuda:0" if opt.cuda else "cpu")
74 | nz = int(opt.nz)
75 | #Retrieve the sequence length as first dimension of a sequence in the dataset
76 | seq_len = dataset[0].size(0)
77 | #An additional input is needed for the delta
78 | in_dim = opt.nz + 1
79 |
80 | if opt.dis_type == "lstm":
81 | netD = LSTMDiscriminator(in_dim=2, hidden_dim=256).to(device)
82 | if opt.dis_type == "cnn":
83 | netD = CausalConvDiscriminator(input_size=2, n_layers=8, n_channel=10, kernel_size=8, dropout=0).to(device)
84 | if opt.gen_type == "lstm":
85 | netG = LSTMGenerator(in_dim=in_dim, out_dim=1, hidden_dim=256).to(device)
86 | if opt.gen_type == "cnn":
87 | netG = CausalConvGenerator(noise_size=in_dim, output_size=1, n_layers=8, n_channel=10, kernel_size=8, dropout=0.2).to(device)
88 |
89 | assert netG
90 | assert netD
91 |
92 | if opt.netG != '':
93 | netG.load_state_dict(torch.load(opt.netG))
94 | if opt.netD != '':
95 | netD.load_state_dict(torch.load(opt.netD))
96 |
97 | print("|Discriminator Architecture|\n", netD)
98 | print("|Generator Architecture|\n", netG)
99 |
100 | criterion = nn.BCELoss().to(device)
101 | delta_criterion = nn.MSELoss().to(device)
102 |
103 | #Generate fixed noise to be used for visualization
104 | fixed_noise = torch.randn(opt.batchSize, seq_len, nz, device=device)
105 |
106 | #Sample both deltas and noise for visualization
107 | deltas = dataset.sample_deltas(opt.batchSize).unsqueeze(2).repeat(1, seq_len, 1)
108 | fixed_noise = torch.cat((fixed_noise, deltas), dim=2)
109 |
110 | real_label = 1
111 | fake_label = 0
112 |
113 | # setup optimizer
114 | optimizerD = optim.Adam(netD.parameters(), lr=opt.lr)
115 | optimizerG = optim.Adam(netG.parameters(), lr=opt.lr)
116 |
117 | for epoch in range(opt.epochs):
118 | for i, data in enumerate(dataloader, 0):
119 | niter = epoch * len(dataloader) + i
120 |
121 | #Save just first batch of real data for displaying
122 | if i == 0:
123 | real_display = data.cpu()
124 |
125 | ############################
126 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
127 | ###########################
128 |
129 | #Train with real data
130 | netD.zero_grad()
131 | real = data.to(device)
132 | batch_size, seq_len = real.size(0), real.size(1)
133 | label = torch.full((batch_size, seq_len, 1), real_label, device=device)
134 |
135 | deltas = (real[:, -1] - real[:, 0]).unsqueeze(2).repeat(1, seq_len, 1)
136 | real = torch.cat((real, deltas), dim=2)
137 | output = netD(real)
138 | errD_real = criterion(output, label)
139 | errD_real.backward()
140 | D_x = output.mean().item()
141 |
142 | #Train with fake data
143 | noise = torch.randn(batch_size, seq_len, nz, device=device)
144 | #Sample a delta for each batch and concatenate to the noise for each timestep
145 | deltas = dataset.sample_deltas(batch_size).unsqueeze(2).repeat(1, seq_len, 1)
146 | noise = torch.cat((noise, deltas), dim=2)
147 |
148 | fake = netG(noise)
149 | label.fill_(fake_label)
150 | output = netD(torch.cat((fake.detach(), deltas), dim=2))
151 | errD_fake = criterion(output, label)
152 | errD_fake.backward()
153 | D_G_z1 = output.mean().item()
154 | errD = errD_real + errD_fake
155 | optimizerD.step()
156 |
157 | #Visualize discriminator gradients
158 | for name, param in netD.named_parameters():
159 | writer.add_histogram("DiscriminatorGradients/{}".format(name), param.grad, niter)
160 |
161 | ############################
162 | # (2) Update G network: maximize log(D(G(z)))
163 | ###########################
164 | netG.zero_grad()
165 | label.fill_(real_label)
166 | output = netD(torch.cat((fake, deltas), dim=2))
167 | errG = criterion(output, label)
168 | errG.backward()
169 | D_G_z2 = output.mean().item()
170 |
171 | optimizerG.step()
172 |
173 | #Visualize generator gradients
174 | for name, param in netG.named_parameters():
175 | writer.add_histogram("GeneratorGradients/{}".format(name), param.grad, niter)
176 |
180 |
181 | #Report metrics
182 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
183 | % (epoch, opt.epochs, i, len(dataloader),
184 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2), end='')
185 | print()
186 | writer.add_scalar('DiscriminatorLoss', errD.item(), niter)
187 | writer.add_scalar('GeneratorLoss', errG.item(), niter)
188 | writer.add_scalar('D of X', D_x, niter)
189 | writer.add_scalar('D of G of z', D_G_z1, niter)
190 |
191 | ##### End of the epoch #####
192 | real_plot = time_series_to_plot(dataset.denormalize(real_display))
193 | if (epoch % opt.tensorboard_image_every == 0) or (epoch == (opt.epochs - 1)):
194 | writer.add_image("Real", real_plot, epoch)
195 |
196 | fake = netG(fixed_noise)
197 | fake_plot = time_series_to_plot(dataset.denormalize(fake))
198 | torchvision.utils.save_image(fake_plot, os.path.join(opt.imf, opt.run_tag+'_epoch'+str(epoch)+'.jpg'))
199 | if (epoch % opt.tensorboard_image_every == 0) or (epoch == (opt.epochs - 1)):
200 | writer.add_image("Fake", fake_plot, epoch)
201 |
202 | # Checkpoint
203 | if (epoch % opt.checkpoint_every == 0) or (epoch == (opt.epochs - 1)):
204 | torch.save(netG, '%s/%s_netG_epoch_%d.pth' % (opt.outf, opt.run_tag, epoch))
205 | torch.save(netD, '%s/%s_netD_epoch_%d.pth' % (opt.outf, opt.run_tag, epoch))
206 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/models/convolutional_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils import weight_norm
4 |
5 | class Chomp1d(nn.Module):
6 | def __init__(self, chomp_size):
7 | super(Chomp1d, self).__init__()
8 | self.chomp_size = chomp_size
9 |
10 | def forward(self, x):
11 | return x[:, :, :-self.chomp_size].contiguous()
12 |
13 |
14 | class TemporalBlock(nn.Module):
15 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
16 | super(TemporalBlock, self).__init__()
17 | self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
18 | stride=stride, padding=padding, dilation=dilation))
19 | self.chomp1 = Chomp1d(padding)
20 | self.relu1 = nn.ReLU()
21 | self.dropout1 = nn.Dropout(dropout)
22 |
23 | self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
24 | stride=stride, padding=padding, dilation=dilation))
25 | self.chomp2 = Chomp1d(padding)
26 | self.relu2 = nn.ReLU()
27 | self.dropout2 = nn.Dropout(dropout)
28 |
29 | self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
30 | self.conv2, self.chomp2, self.relu2, self.dropout2)
31 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
32 | self.relu = nn.ReLU()
33 | self.init_weights()
34 |
35 | def init_weights(self):
36 | self.conv1.weight.data.normal_(0, 0.01)
37 | self.conv2.weight.data.normal_(0, 0.01)
38 | if self.downsample is not None:
39 | self.downsample.weight.data.normal_(0, 0.01)
40 |
41 | def forward(self, x):
42 | out = self.net(x)
43 | res = x if self.downsample is None else self.downsample(x)
44 | return self.relu(out + res)
45 |
46 |
47 | class TemporalConvNet(nn.Module):
48 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
49 | super(TemporalConvNet, self).__init__()
50 | layers = []
51 | num_levels = len(num_channels)
52 | for i in range(num_levels):
53 | dilation_size = 2 ** i
54 | in_channels = num_inputs if i == 0 else num_channels[i-1]
55 | out_channels = num_channels[i]
56 | layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
57 | padding=(kernel_size-1) * dilation_size, dropout=dropout)]
58 |
59 | self.network = nn.Sequential(*layers)
60 |
61 | def forward(self, x):
62 | return self.network(x)
63 |
64 |
65 | class TCN(nn.Module):
66 | def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
67 | super(TCN, self).__init__()
68 | self.tcn = TemporalConvNet(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
69 | self.linear = nn.Linear(num_channels[-1], output_size)
70 | self.init_weights()
71 |
72 | def init_weights(self):
73 | self.linear.weight.data.normal_(0, 0.01)
74 |
75 | def forward(self, x, channel_last=True):
76 | #If channel_last, the expected format is (batch_size, seq_len, features)
77 | y1 = self.tcn(x.transpose(1, 2) if channel_last else x)
78 | return self.linear(y1.transpose(1, 2))
79 |
80 |
81 | class CausalConvDiscriminator(nn.Module):
 82 |     """Discriminator using causal dilated convolutions; outputs a probability for each time step
 83 | 
 84 |     Args:
 85 |         input_size (int): dimensionality (channels) of the input
 86 |         n_layers (int): number of hidden layers
 87 |         n_channel (int): number of channels in the hidden layers (constant across layers)
 88 |         kernel_size (int): kernel size in all the layers
 89 |         dropout (float in [0, 1]): dropout rate
90 |
91 | Input: (batch_size, seq_len, input_size)
92 | Output: (batch_size, seq_len, 1)
93 | """
94 | def __init__(self, input_size, n_layers, n_channel, kernel_size, dropout=0):
95 | super().__init__()
96 | #Assuming same number of channels layerwise
97 | num_channels = [n_channel] * n_layers
98 | self.tcn = TCN(input_size, 1, num_channels, kernel_size, dropout)
99 |
100 | def forward(self, x, channel_last=True):
101 | return torch.sigmoid(self.tcn(x, channel_last))
102 |
103 | class CausalConvGenerator(nn.Module):
104 |     """Generator using causal dilated convolutions, expecting a noise vector for each timestep as input
105 | 
106 |     Args:
107 |         noise_size (int): dimensionality (channels) of the input noise
108 |         output_size (int): dimensionality (channels) of the output sequence
109 |         n_layers (int): number of hidden layers
110 |         n_channel (int): number of channels in the hidden layers (constant across layers)
111 |         kernel_size (int): kernel size in all the layers
112 |         dropout (float in [0, 1]): dropout rate
113 | 
114 |     Input: (batch_size, seq_len, noise_size)
115 |     Output: (batch_size, seq_len, output_size)
116 |     """
117 | def __init__(self, noise_size, output_size, n_layers, n_channel, kernel_size, dropout=0):
118 | super().__init__()
119 | num_channels = [n_channel] * n_layers
120 | self.tcn = TCN(noise_size, output_size, num_channels, kernel_size, dropout)
121 |
122 | def forward(self, x, channel_last=True):
123 | return torch.tanh(self.tcn(x, channel_last))
124 |
125 |
126 | if __name__ == "__main__":
127 |     #Batch of 8 sequences of length 32, with 30-dimensional noise per timestep
128 |     noise = torch.randn(8, 32, 30)
129 | 
130 |     gen = CausalConvGenerator(noise_size=30, output_size=1, n_layers=8, n_channel=10, kernel_size=8, dropout=0)
131 |     dis = CausalConvDiscriminator(input_size=1, n_layers=8, n_channel=10, kernel_size=8, dropout=0)
132 | 
133 |     print("Input shape:", noise.size())
134 |     fake = gen(noise)
135 | print("Generator output shape:", fake.size())
136 | dis_out = dis(fake)
137 | print("Discriminator output shape:", dis_out.size())
138 |
--------------------------------------------------------------------------------
/models/recurrent_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class LSTMGenerator(nn.Module):
5 | """An LSTM based generator. It expects a sequence of noise vectors as input.
6 |
7 | Args:
8 | in_dim: Input noise dimensionality
9 | out_dim: Output dimensionality
10 | n_layers: number of lstm layers
11 | hidden_dim: dimensionality of the hidden layer of lstms
12 |
13 | Input: noise of shape (batch_size, seq_len, in_dim)
14 | Output: sequence of shape (batch_size, seq_len, out_dim)
15 | """
16 |
17 | def __init__(self, in_dim, out_dim, n_layers=1, hidden_dim=256):
18 | super().__init__()
19 | self.n_layers = n_layers
20 | self.hidden_dim = hidden_dim
21 | self.out_dim = out_dim
22 |
23 | self.lstm = nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True)
24 | self.linear = nn.Sequential(nn.Linear(hidden_dim, out_dim), nn.Tanh())
25 |
26 | def forward(self, input):
27 | batch_size, seq_len = input.size(0), input.size(1)
 28 |         h_0 = torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=input.device)
 29 |         c_0 = torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=input.device)
30 |
31 | recurrent_features, _ = self.lstm(input, (h_0, c_0))
32 | outputs = self.linear(recurrent_features.contiguous().view(batch_size*seq_len, self.hidden_dim))
33 | outputs = outputs.view(batch_size, seq_len, self.out_dim)
34 | return outputs
35 |
36 |
37 | class LSTMDiscriminator(nn.Module):
38 | """An LSTM based discriminator. It expects a sequence as input and outputs a probability for each element.
39 |
40 | Args:
 41 |         in_dim: Input sequence dimensionality
42 | n_layers: number of lstm layers
43 | hidden_dim: dimensionality of the hidden layer of lstms
44 |
45 | Inputs: sequence of shape (batch_size, seq_len, in_dim)
46 | Output: sequence of shape (batch_size, seq_len, 1)
47 | """
48 |
49 | def __init__(self, in_dim, n_layers=1, hidden_dim=256):
50 | super().__init__()
51 | self.n_layers = n_layers
52 | self.hidden_dim = hidden_dim
53 |
54 | self.lstm = nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True)
55 | self.linear = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
56 |
57 | def forward(self, input):
58 | batch_size, seq_len = input.size(0), input.size(1)
 59 |         h_0 = torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=input.device)
 60 |         c_0 = torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=input.device)
61 |
62 | recurrent_features, _ = self.lstm(input, (h_0, c_0))
63 | outputs = self.linear(recurrent_features.contiguous().view(batch_size*seq_len, self.hidden_dim))
64 | outputs = outputs.view(batch_size, seq_len, 1)
65 | return outputs
66 |
67 |
68 | if __name__ == "__main__":
69 | batch_size = 16
70 | seq_len = 32
71 | noise_dim = 100
72 | seq_dim = 4
73 |
74 | gen = LSTMGenerator(noise_dim, seq_dim)
75 | dis = LSTMDiscriminator(seq_dim)
 76 |     noise = torch.randn(batch_size, seq_len, noise_dim)
77 | gen_out = gen(noise)
78 | dis_out = dis(gen_out)
79 |
80 | print("Noise: ", noise.size())
81 | print("Generator output: ", gen_out.size())
82 | print("Discriminator output: ", dis_out.size())
83 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.0
2 | matplotlib==2.1.0
3 | torchvision==0.2.1
4 | pandas==0.22.0
5 | numpy==1.15.4
6 | tensorboardX==1.6
7 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torchvision.utils as vutils
5 |
6 | def time_series_to_plot(time_series_batch, dpi=35, feature_idx=0, n_images_per_row=4, titles=None):
7 | """Convert a batch of time series to a tensor with a grid of their plots
8 |
9 | Args:
10 | time_series_batch (Tensor): (batch_size, seq_len, dim) tensor of time series
11 | dpi (int): dpi of a single image
12 | feature_idx (int): index of the feature that goes in the plots (the first one by default)
13 | n_images_per_row (int): number of images per row in the plot
14 | titles (list of strings): list of titles for the plots
15 |
16 | Output:
17 | single (channels, width, height)-shaped tensor representing an image
18 | """
19 | #Iterates over the time series
20 | images = []
21 | for i, series in enumerate(time_series_batch.detach()):
22 | fig = plt.figure(dpi=dpi)
23 | ax = fig.add_subplot(1,1,1)
24 | if titles:
25 | ax.set_title(titles[i])
26 | ax.plot(series[:, feature_idx].numpy()) #plots a single feature of the time series
27 | fig.canvas.draw()
 28 |         data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
29 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
30 | images.append(data)
31 | plt.close(fig)
32 |
33 | #Swap channel
34 | images = torch.from_numpy(np.stack(images)).permute(0, 3, 1, 2)
35 | #Make grid
36 | grid_image = vutils.make_grid(images.detach(), nrow=n_images_per_row)
37 | return grid_image
38 |
39 | def tensor_to_string_list(tensor):
40 | """Convert a tensor to a list of strings representing its value"""
41 | scalar_list = tensor.squeeze().numpy().tolist()
42 | return ["%.5f" % scalar for scalar in scalar_list]
43 |
44 | class DatasetGenerator:
45 | def __init__(self, generator, seq_len=96, noise_dim=100, dataset=None):
46 | """Class for fake dataset generation
47 | Args:
48 | generator (pytorch module): trained generator to use
49 | seq_len (int): length of the sequences to be generated
50 | noise_dim (int): input noise dimension for gan generator
51 | dataset (Dataset): dataset providing normalize and denormalize functions for deltas and series (by default, don't normalize)
52 | """
53 | self.generator = generator
54 | self.seq_len = seq_len
55 | self.noise_dim = noise_dim
56 | self.dataset = dataset
57 |
58 | def generate_dataset(self, outfile=None, batch_size=4, delta_list=None, size=1000):
59 | """Method for generating a dataset
60 | Args:
61 | outfile (string): name of the npy file to save the dataset. If None, it is simply returned as pytorch tensor
62 | batch_size (int): batch size for generation
 63 |             delta_list (list): list of deltas to be used in the case of conditional generation
 64 |             size (int): number of time series to generate; ignored if delta_list is provided
66 | """
67 | #If conditional generation is required, then input for generator must contain deltas
68 | if delta_list:
69 | noise = torch.randn(len(delta_list), self.seq_len, self.noise_dim)
70 | deltas = torch.FloatTensor(delta_list).view(-1, 1, 1).repeat(1, self.seq_len, 1)
71 | if self.dataset:
72 | #Deltas are provided in original range, normalization required
73 | deltas = self.dataset.normalize_deltas(deltas)
74 | noise = torch.cat((noise, deltas), dim=2)
75 | else:
76 | noise = torch.randn(size, self.seq_len, self.noise_dim)
77 |
78 | out_list = []
79 | for batch in noise.split(batch_size):
80 | out_list.append(self.generator(batch))
81 | out_tensor = torch.cat(out_list, dim=0)
82 |
83 | #Puts generated sequences in original range
84 | if self.dataset:
85 | out_tensor = self.dataset.denormalize(out_tensor)
86 |
87 | if outfile:
88 | np.save(outfile, out_tensor.detach().numpy())
89 | else:
90 | return out_tensor
91 |
92 |
93 | if __name__ == "__main__":
94 | model = torch.load('checkpoints/cnn_conditioned_alternate1_netG_epoch_85.pth')
95 | gen = DatasetGenerator(model)
96 | print("Shape of example dataset:", gen.generate_dataset(delta_list=[i for i in range(100)]).size())
97 |
--------------------------------------------------------------------------------