├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── base ├── __init__.py ├── base_data_loader.py ├── base_model.py └── base_trainer.py ├── config.json ├── data_loader ├── data_loaders.py ├── polyphonic_dataloader.py └── seq_util.py ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── model ├── loss.py ├── metric.py ├── model.py └── modules.py ├── parse_config.py ├── test.py ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py └── util.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = F401, F403 3 | max-line-length = 120 4 | exclude = 5 | .git, 6 | __pycache__, 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # project-specific artifacts 2 | __pycache__/ 3 | .data/ 4 | .vscode/ 5 | lightning_logs/ 6 | reconstructions/ 7 | *.ipynb 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | env/ 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # dotenv 91 | .env 92 | 93 | # virtualenv 94 | .venv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # input data, saved log, checkpoints 112 | data/ 113 | input/ 114 | saved/ 115 | datasets/ 116 | 117 | # editor, os cache directory 118 | .vscode/ 119 | .idea/ 120 | __MACOSX/ 121 | 122 | .data/ 123 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yin-Jyun Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-deep-markov-model 2 | PyTorch re-implementation of the Deep Markov Model (https://arxiv.org/abs/1609.09869) 3 | ``` 4 | @inproceedings{10.5555/3298483.3298543, 5 | author = {Krishnan, Rahul G. and Shalit, Uri and Sontag, David}, 6 | title = {Structured Inference Networks for Nonlinear State Space Models}, 7 | year = {2017}, 8 | publisher = {AAAI Press}, 9 | booktitle = {Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence}, 10 | pages = {2101–2109}, 11 | numpages = {9}, 12 | location = {San Francisco, California, USA}, 13 | series = {AAAI'17} 14 | } 15 | ``` 16 | **Note:** 17 | 1. The metrics calculated in `model/metric.py` do not match those reported in the paper, most likely due to differences in parameter settings and metric calculation. 18 | 2. The current implementation only supports the JSB polyphonic music dataset. 19 | 20 | ## Under development 21 | Refer to the branch `factorial-dmm` for a model described as [Factorial DMM](https://groups.csail.mit.edu/sls/publications/2019/SameerKhurana_ICASSP-2019.pdf). 22 | The other branch `refractor` aims to improve readability and to offer a wider choice of models (documentation not updated yet). 23 | 24 | ## Usage 25 | Train the model with the default `config.json`: 26 | 27 | python train.py -c config.json 28 | 29 | 30 | Add the `-i` flag to give the experiment a specific name under `saved/`. 31 | 32 | ## `config.json` 33 | This file specifies parameters and configurations. 34 | Some key parameters are explained below.
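Among the trainer settings, `min_anneal_factor` and `anneal_update` control KL annealing: the KL term of `dmm_loss` (see `model/loss.py`) is multiplied by an annealing factor that grows during training. The exact schedule is implemented in `trainer/trainer.py` (not shown here); purely as an illustration, a linear ramp of this kind can be written as:

```python
def kl_annealing_factor(step, min_anneal_factor=0.0, anneal_update=5000):
    # Illustrative sketch only -- not necessarily the exact formula used in
    # trainer/trainer.py. Ramps the weight on the KL term of `dmm_loss`
    # from `min_anneal_factor` up to 1.0 over `anneal_update` training steps.
    return min(1.0, min_anneal_factor + (1.0 - min_anneal_factor) * step / anneal_update)
```

With the default values (`0.0` and `5000`), such a ramp starts the KL weight at zero and reaches the full weight of 1.0 after 5000 updates.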
35 | 36 | **A careful fine-tuning of the parameters seems necessary to match the reported performance.** 37 | ```javascript 38 | { 39 | "arch": { 40 | "type": "DeepMarkovModel", 41 | "args": { 42 | "input_dim": 88, 43 | "z_dim": 100, 44 | "emission_dim": 100, 45 | "transition_dim": 200, 46 | "rnn_dim": 600, 47 | "rnn_type": "lstm", 48 | "rnn_layers": 1, 49 | "rnn_bidirection": false, // condition z_t on both directions of the inputs; 50 | // manually turn off `reverse_rnn_input` if set to true 51 | // (this is minor and could be quickly fixed) 52 | "use_embedding": true, // use an extra linear layer before the RNN 53 | "orthogonal_init": true, // orthogonal initialization for the RNN 54 | "gated_transition": true, // use linear/non-linear gated transition 55 | "train_init": false, // make z0 trainable 56 | "mean_field": false, // use mean-field posterior q(z_t | x) 57 | "reverse_rnn_input": true, // condition z_t on future inputs 58 | "sample": true // sample during reparameterization 59 | } 60 | }, 61 | "optimizer": { 62 | "type": "Adam", 63 | "args":{ 64 | "lr": 0.0008, // default value from the author's source code 65 | "weight_decay": 0.0, // a value of 1.0 was found to prevent training during debugging 66 | "amsgrad": true, 67 | "betas": [0.9, 0.999] 68 | } 69 | }, 70 | "trainer": { 71 | "epochs": 3000, 72 | "overfit_single_batch": false, // overfit a single batch for debugging 73 | 74 | "save_dir": "saved/", 75 | "save_period": 500, 76 | "verbosity": 2, 77 | 78 | "monitor": "min val_loss", 79 | "early_stop": 100, 80 | 81 | "tensorboard": true, 82 | 83 | "min_anneal_factor": 0.0, 84 | "anneal_update": 5000 85 | } 86 | } 87 | ``` 88 | 89 | ## References 90 | 0. Project template borrowed from [pytorch-template](https://github.com/victoresque/pytorch-template) 91 | 1. The original [source code](https://github.com/clinicalml/structuredinference/tree/master/expt-polyphonic-fast) in Theano 92 | 2. PyTorch implementation in the [Pyro](https://github.com/pyro-ppl/pyro/tree/dev/examples/dmm) framework 93 | 3.
Another PyTorch implementation by [@guxd](https://github.com/guxd/deepHMM) 94 | 95 | ## To-Do 96 | - [ ] fine-tune to match the performance reported in the paper 97 | - [ ] correct errors (if any) in the metric calculation, `model/metric.py` 98 | - [ ] optimize importance sampling 99 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import TensorboardWriter 5 | 6 | 7 | class BaseTrainer: 8 | """ 9 | Base class for all trainers 10 | """ 11 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 12 | self.config = config 13 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 14 | 15 | # setup GPU device if available, move model into configured device 16 | self.device, device_ids = self._prepare_device(config['n_gpu']) 17 | self.model = model.to(self.device) 18 | if len(device_ids) > 1: 19 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 20 | 21 | self.criterion = criterion 22 | self.metric_ftns = metric_ftns 23 | self.optimizer = optimizer 24 | 25 | cfg_trainer = config['trainer'] 26 | self.epochs = cfg_trainer['epochs'] 27 | self.save_period = cfg_trainer['save_period'] 28 | self.monitor = cfg_trainer.get('monitor', 'off') 29 | 30 | # configuration to monitor model performance and save best 31 | if self.monitor == 'off': 32 | self.mnt_mode = 'off' 33 | self.mnt_best = 0 34 | else: 35 | self.mnt_mode, self.mnt_metric = self.monitor.split() 36 | assert self.mnt_mode in ['min', 'max'] 37 | 38 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 39 | self.early_stop = cfg_trainer.get('early_stop', inf) 40 | 41 | self.start_epoch = 1 42 | 43 | self.checkpoint_dir = config.save_dir 44 | 45 | # setup visualization writer instance 46 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 47 | 48 | if config.resume is not None: 49 | self._resume_checkpoint(config.resume) 50 | 51 | @abstractmethod 52 | def _train_epoch(self, epoch): 53 | """ 54 | Training logic for an epoch 55 | 56 | :param epoch: Current epoch number 57 | """ 58 | raise NotImplementedError 59 | 60 | def train(self): 61 | """ 
62 | Full training logic 63 | """ 64 | not_improved_count = 0 65 | for epoch in range(self.start_epoch, self.epochs + 1): 66 | result = self._train_epoch(epoch) 67 | 68 | # save logged informations into log dict 69 | log = {'epoch': epoch} 70 | log.update(result) 71 | 72 | # print logged informations to the screen 73 | for key, value in log.items(): 74 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 75 | 76 | # evaluate model performance according to configured metric, save best checkpoint as model_best 77 | best = False 78 | if self.mnt_mode != 'off': 79 | try: 80 | # check whether model performance improved or not, according to specified metric(mnt_metric) 81 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 82 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 83 | except KeyError: 84 | self.logger.warning("Warning: Metric '{}' is not found. " 85 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 86 | self.mnt_mode = 'off' 87 | improved = False 88 | 89 | if improved: 90 | self.mnt_best = log[self.mnt_metric] 91 | not_improved_count = 0 92 | best = True 93 | else: 94 | not_improved_count += 1 95 | 96 | if not_improved_count > self.early_stop: 97 | self.logger.info("Validation performance didn\'t improve for {} epochs. " 98 | "Training stops.".format(self.early_stop)) 99 | break 100 | 101 | if epoch % self.save_period == 0: 102 | self._save_checkpoint(epoch, save_best=best) 103 | 104 | def _prepare_device(self, n_gpu_use): 105 | """ 106 | setup GPU device if available, move model into configured device 107 | """ 108 | n_gpu = torch.cuda.device_count() 109 | if n_gpu_use > 0 and n_gpu == 0: 110 | self.logger.warning("Warning: There\'s no GPU available on this machine," 111 | "training will be performed on CPU.") 112 | n_gpu_use = 0 113 | if n_gpu_use > n_gpu: 114 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available " 115 | "on this machine.".format(n_gpu_use, n_gpu)) 116 | n_gpu_use = n_gpu 117 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 118 | list_ids = list(range(n_gpu_use)) 119 | return device, list_ids 120 | 121 | def _save_checkpoint(self, epoch, save_best=False): 122 | """ 123 | Saving checkpoints 124 | 125 | :param epoch: current epoch number 126 | :param log: logging information of the epoch 127 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 128 | """ 129 | arch = type(self.model).__name__ 130 | state = { 131 | 'arch': arch, 132 | 'epoch': epoch, 133 | 'state_dict': self.model.state_dict(), 134 | 'optimizer': self.optimizer.state_dict(), 135 | 'monitor_best': self.mnt_best, 136 | 'config': self.config 137 | } 138 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 139 | torch.save(state, filename) 140 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 141 | if save_best: 142 | best_path = str(self.checkpoint_dir / 'model_best.pth') 143 | torch.save(state, best_path) 144 | self.logger.info("Saving current best: model_best.pth ...") 145 | 146 | def _resume_checkpoint(self, resume_path): 147 | """ 148 | Resume from saved checkpoints 149 | 150 | :param resume_path: Checkpoint path to be resumed 151 | """ 152 | resume_path = str(resume_path) 153 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 154 | checkpoint = torch.load(resume_path) 155 | self.start_epoch = checkpoint['epoch'] + 1 156 | self.mnt_best = checkpoint['monitor_best'] 157 | 158 | # 
load architecture params from checkpoint. 159 | if checkpoint['config']['arch'] != self.config['arch']: 160 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 161 | "checkpoint. This may yield an exception while state_dict is being loaded.") 162 | self.model.load_state_dict(checkpoint['state_dict']) 163 | 164 | # load optimizer state from checkpoint only when optimizer type is not changed. 165 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 166 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 167 | "Optimizer parameters not being resumed.") 168 | else: 169 | self.optimizer.load_state_dict(checkpoint['optimizer']) 170 | 171 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 172 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "DMM_rnn", 3 | "n_gpu": 1, 4 | 5 | "arch": { 6 | "type": "DeepMarkovModel", 7 | "args": { 8 | "input_dim": 88, 9 | "z_dim": 100, 10 | "emission_dim": 100, 11 | "transition_dim": 200, 12 | "rnn_dim": 600, 13 | "rnn_type": "lstm", 14 | "rnn_layers": 1, 15 | "rnn_bidirection": false, 16 | "use_embedding": true, 17 | "orthogonal_init": true, 18 | "gated_transition": true, 19 | "train_init": false, 20 | "mean_field": false, 21 | "reverse_rnn_input": true, 22 | "sample": true 23 | } 24 | }, 25 | "data_loader_train": { 26 | "type": "PolyMusicDataLoader", 27 | "args":{ 28 | "batch_size": 20, 29 | "data_dir": "jsb", 30 | "split": "train", 31 | "shuffle": true, 32 | "num_workers": 1 33 | } 34 | }, 35 | "data_loader_valid": { 36 | "type": "PolyMusicDataLoader", 37 | "args":{ 38 | "batch_size": 20, 39 | "data_dir": "jsb", 40 | "split": "valid", 41 | "shuffle": false, 42 | "num_workers": 1 43 | } 44 | }, 45 | "data_loader_test": { 46 | "type": "PolyMusicDataLoader", 47 | "args":{ 48 | "batch_size": 20, 49 | "data_dir": "jsb", 50 | "split": "test", 51 | "shuffle": false, 52 | "num_workers": 1 53 | } 54 | }, 55 | "optimizer": { 56 | "type": "Adam", 57 | "args":{ 58 | "lr": 0.0008, 59 | "weight_decay": 0.0, 60 | "amsgrad": true, 61 | "betas": [0.9, 0.999] 62 | } 63 | }, 64 | "loss": "dmm_loss", 65 | "metrics": [ 66 | "bound_eval", "importance_sample" 67 | ], 68 | "trainer": { 69 | "epochs": 3000, 70 | "overfit_single_batch": false, 71 | 72 | "save_dir": "saved/", 73 | "save_period": 500, 74 | "verbosity": 2, 75 | 76 | "monitor": "min val_loss", 77 | "early_stop": 100, 78 | 79 | "tensorboard": true, 80 | 81 | "min_anneal_factor": 0.0, 82 | "anneal_update": 5000 83 | } 84 | } 85 | 86 | -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from base import BaseDataLoader 2 | import data_loader.polyphonic_dataloader as poly 3 | from data_loader.seq_util import seq_collate_fn 4 | 5 | 6 | class PolyMusicDataLoader(BaseDataLoader): 7 | def __init__(self, 8 | batch_size, 9 | data_dir='jsb', 10 | split='train', 11 | shuffle=True, 12 | collate_fn=seq_collate_fn, 13 | num_workers=1): 14 | 15 | assert data_dir in ['jsb'] 16 | assert split in ['train', 'valid', 'test'] 17 | if data_dir == 'jsb': 18 | self.dataset = poly.PolyDataset(poly.JSB_CHORALES, split) 19 | self.data_dir = data_dir 20 | self.split = split 21 | 22 | 
super().__init__(self.dataset, 23 | batch_size, 24 | shuffle, 25 | 0.0, 26 | num_workers, 27 | seq_collate_fn) 28 | -------------------------------------------------------------------------------- /data_loader/polyphonic_dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data preparation code borrowed from 3 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py 4 | """ 5 | 6 | import os 7 | from collections import namedtuple 8 | from urllib.request import urlopen 9 | import pickle 10 | 11 | import torch 12 | from torch.nn.utils.rnn import pad_sequence 13 | from torch.utils.data import Dataset 14 | 15 | from data_loader.seq_util import get_data_directory 16 | 17 | 18 | dset = namedtuple("dset", ["name", "url", "filename"]) 19 | 20 | JSB_CHORALES = dset("jsb_chorales", 21 | "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle", 22 | "jsb_chorales.pkl") 23 | 24 | PIANO_MIDI = dset("piano_midi", 25 | "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/piano_midi.pickle", 26 | "piano_midi.pkl") 27 | 28 | MUSE_DATA = dset("muse_data", 29 | "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/muse_data.pickle", 30 | "muse_data.pkl") 31 | 32 | NOTTINGHAM = dset("nottingham", 33 | "https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/nottingham.pickle", 34 | "nottingham.pkl") 35 | 36 | 37 | # this function processes the raw data; in particular it unsparsifies it 38 | def process_data(base_path, dataset, min_note=21, note_range=88): 39 | output = os.path.join(base_path, dataset.filename) 40 | if os.path.exists(output): 41 | try: 42 | with open(output, "rb") as f: 43 | return pickle.load(f) 44 | except (ValueError, UnicodeDecodeError): 45 | # Assume python env has changed. 46 | # Recreate pickle file in this env's format.
47 | os.remove(output) 48 | 49 | print("processing raw data - {} ...".format(dataset.name)) 50 | data = pickle.load(urlopen(dataset.url)) 51 | # added this line to see the difference between the raw and processed data 52 | pickle.dump(data, open(os.path.join(base_path, 53 | '-'.join(['raw', dataset.filename])), "wb"), pickle.HIGHEST_PROTOCOL) 54 | processed_dataset = {} 55 | for split, data_split in data.items(): 56 | processed_dataset[split] = {} 57 | n_seqs = len(data_split) 58 | processed_dataset[split]['sequence_lengths'] = torch.zeros(n_seqs, dtype=torch.long) 59 | processed_dataset[split]['sequences'] = [] 60 | for seq in range(n_seqs): 61 | seq_length = len(data_split[seq]) 62 | processed_dataset[split]['sequence_lengths'][seq] = seq_length 63 | processed_sequence = torch.zeros((seq_length, note_range)) 64 | for t in range(seq_length): 65 | note_slice = torch.tensor(list(data_split[seq][t])) - min_note 66 | slice_length = len(note_slice) 67 | if slice_length > 0: 68 | processed_sequence[t, note_slice] = torch.ones(slice_length) 69 | processed_dataset[split]['sequences'].append(processed_sequence) 70 | pickle.dump(processed_dataset, open(output, "wb"), pickle.HIGHEST_PROTOCOL) 71 | print("dumped processed data to %s" % output) 72 | 73 | 74 | # this logic will be initiated upon import 75 | base_path = get_data_directory(__file__) 76 | if not os.path.exists(base_path): 77 | os.mkdir(base_path) 78 | 79 | 80 | # ingest training/validation/test data from disk 81 | def load_data(dataset): 82 | # download and process dataset if it does not exist 83 | process_data(base_path, dataset) 84 | file_loc = os.path.join(base_path, dataset.filename) 85 | 86 | with open(file_loc, "rb") as f: 87 | dset = pickle.load(f) 88 | for k, v in dset.items(): 89 | sequences = v["sequences"] 90 | dset[k]["sequences"] = pad_sequence(sequences, batch_first=True).type(torch.Tensor) 91 | dset[k]["sequence_lengths"] = v["sequence_lengths"].to(device=torch.Tensor().device) 92 | return dset 93 | 94 | 95 | class PolyDataset(Dataset): 96 | def __init__(self, dataset, split): 97 | self.dataset = dataset 98 | self.split = split 99 | 100 | self.data = load_data(dataset)[split] 101 | self.seq_lengths = self.data['sequence_lengths'] 102 | self.seq = self.data['sequences'] 103 | self.n_seq = len(self.seq_lengths) 104 | self.n_time_slices = float(torch.sum(self.seq_lengths)) 105 | 106 | def __len__(self): 107 | return self.n_seq 108 | 109 | def __getitem__(self, idx): 110 | return idx, self.seq[idx], self.seq_lengths[idx] 111 | -------------------------------------------------------------------------------- /data_loader/seq_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def reverse_sequence(x, seq_lengths): 7 | """ 8 | Brought from 9 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py 10 | 11 | Parameters 12 | ---------- 13 | x: tensor (b, T_max, input_dim) 14 | seq_lengths: tensor (b, ) 15 | 16 | Returns 17 | ------- 18 | x_reverse: tensor (b, T_max, input_dim) 19 | The input x in reversed order w.r.t. 
time-axis 20 | """ 21 | x_reverse = torch.zeros_like(x) 22 | for b in range(x.size(0)): 23 | t = seq_lengths[b] 24 | time_slice = torch.arange(t - 1, -1, -1, device=x.device) 25 | reverse_seq = torch.index_select(x[b, :, :], 0, time_slice) 26 | x_reverse[b, 0:t, :] = reverse_seq 27 | 28 | return x_reverse 29 | 30 | 31 | def pad_and_reverse(rnn_output, seq_lengths): 32 | """ 33 | Brought from 34 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py 35 | 36 | Parameters 37 | ---------- 38 | rnn_output: tensor # shape to be confirmed, should be packed rnn output 39 | seq_lengths: tensor (b, ) 40 | 41 | Returns 42 | ------- 43 | reversed_output: tensor (b, T_max, input_dim) 44 | The input sequence, unpacked and padded, 45 | in reversed order w.r.t. time-axis 46 | """ 47 | rnn_output, _ = nn.utils.rnn.pad_packed_sequence(rnn_output, 48 | batch_first=True) 49 | reversed_output = reverse_sequence(rnn_output, seq_lengths) 50 | return reversed_output 51 | 52 | 53 | def get_mini_batch_mask(x, seq_lengths): 54 | """ 55 | Brought from 56 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py 57 | 58 | Parameters 59 | ---------- 60 | x: tensor (b, T_max, input_dim) 61 | seq_lengths: tensor (b, ) 62 | 63 | Returns 64 | ------- 65 | mask: tensor (b, T_max) 66 | A binary mask generated according to `seq_lengths` 67 | """ 68 | mask = torch.zeros(x.shape[0:2]) 69 | for b in range(x.shape[0]): 70 | mask[b, 0:seq_lengths[b]] = torch.ones(seq_lengths[b]) 71 | return mask 72 | 73 | 74 | def get_mini_batch(mini_batch_indices, sequences, seq_lengths, cuda=True): 75 | """ 76 | Prepare a mini-batch (size b) from the dataset (size D) 77 | for training or evaluation 78 | 79 | Brought from 80 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py 81 | 82 | Parameters 83 | ---------- 84 | mini_batch_indices: tensor (b, ) 85 | Indices of a mini-batch of data 86 | sequences: tensor (D, D_T_max, input_dim) 87 | Padded data 88 | seq_lengths: tensor (D, ) 89 | Effective sequence lengths of each sequence in the dataset 90 | cuda: bool 91 | 92 | Returns 93 | ------- 94 | mini_batch: tensor (b, T_max, input_dim) 95 | A mini-batch from the dataset 96 | mini_batch_reversed: pytorch packed object 97 | A mini-batch in the reversed order; 98 | used as the input to RnnEncoder in DeepMarkovModel 99 | """ 100 | seq_lengths = seq_lengths[mini_batch_indices] 101 | _, sorted_seq_length_indices = torch.sort(seq_lengths) 102 | sorted_seq_length_indices = sorted_seq_length_indices.flip(0) 103 | sorted_seq_lengths = seq_lengths[sorted_seq_length_indices] 104 | sorted_mini_batch_indices = mini_batch_indices[sorted_seq_length_indices] 105 | 106 | T_max = torch.max(seq_lengths) 107 | mini_batch = sequences[sorted_mini_batch_indices, 0:T_max, :] 108 | mini_batch_reversed = reverse_sequence(mini_batch, sorted_seq_lengths) 109 | mini_batch_mask = get_mini_batch_mask(mini_batch, sorted_seq_lengths) 110 | 111 | if cuda: 112 | mini_batch = mini_batch.cuda() 113 | mini_batch_mask = mini_batch_mask.cuda() 114 | mini_batch_reversed = mini_batch_reversed.cuda() 115 | 116 | mini_batch_reversed = nn.utils.rnn.pack_padded_sequence( 117 | mini_batch_reversed, 118 | sorted_seq_lengths, 119 | batch_first=True 120 | ) 121 | 122 | return mini_batch, mini_batch_reversed, mini_batch_mask, sorted_seq_lengths 123 | 124 | 125 | def get_data_directory(filepath=None): 126 | """ 127 | Brought from 128 | 
https://github.com/pyro-ppl/pyro/blob/2b4a4013291e59f251564aeaf5815c4c3a18f4ff/pyro/contrib/examples/util.py#L66 129 | """ 130 | if 'CI' in os.environ: 131 | return os.path.expanduser('~/.data') 132 | return os.path.abspath(os.path.join(os.path.dirname(filepath), 133 | '.data')) 134 | 135 | 136 | def seq_collate_fn(batch): 137 | """ 138 | A customized `collate_fn` intented for loading padded sequential data 139 | """ 140 | idx, seq, seq_lengths = zip(*batch) 141 | idx = torch.tensor(idx) 142 | seq = torch.stack(seq) 143 | seq_lengths = torch.tensor(seq_lengths) 144 | _, sorted_seq_length_indices = torch.sort(seq_lengths) 145 | sorted_seq_length_indices = sorted_seq_length_indices.flip(0) 146 | sorted_seq_lengths = seq_lengths[sorted_seq_length_indices] 147 | 148 | T_max = torch.max(seq_lengths) 149 | mini_batch = seq[sorted_seq_length_indices, 0:T_max, :] 150 | mini_batch_reversed = reverse_sequence(mini_batch, sorted_seq_lengths) 151 | mini_batch_mask = get_mini_batch_mask(mini_batch, sorted_seq_lengths) 152 | 153 | return mini_batch, mini_batch_reversed, mini_batch_mask, sorted_seq_lengths 154 | 155 | 156 | def pack_padded_seq(seq, seq_len, batch_first=True): 157 | return nn.utils.rnn.pack_padded_sequence( 158 | seq, 159 | seq_len, 160 | batch_first=batch_first 161 | ) 162 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /logger/visualization.py: 
-------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 27 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 28 | logger.warning(message) 29 | 30 | self.step = 0 31 | self.mode = '' 32 | 33 | self.tb_writer_ftns = { 34 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 35 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding', 'add_figure' 36 | } 37 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 38 | self.timer = datetime.now() 39 | 40 | def set_step(self, step, mode='train'): 41 | self.mode = mode 42 | self.step = step 43 | if step == 0: 44 | self.timer = datetime.now() 45 | else: 46 | duration = datetime.now() - self.timer 47 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 48 | self.timer = datetime.now() 49 | 50 | def __getattr__(self, name): 51 | """ 52 | If visualization is configured to use: 53 | return add_data() methods of tensorboard with additional information (step, tag) added. 54 | Otherwise: 55 | return a blank function handle that does nothing 56 | """ 57 | if name in self.tb_writer_ftns: 58 | add_data = getattr(self.writer, name, None) 59 | 60 | def wrapper(tag, data, *args, **kwargs): 61 | if add_data is not None: 62 | # add mode(train/valid) tag 63 | if name not in self.tag_mode_exceptions: 64 | tag = '{}/{}'.format(tag, self.mode) 65 | add_data(tag, data, self.step, *args, **kwargs) 66 | return wrapper 67 | else: 68 | # default action for returning methods defined in this class, set_step() for instance. 
69 | try: 70 | attr = object.__getattr__(name) 71 | except AttributeError: 72 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 73 | return attr 74 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def kl_div(mu1, logvar1, mu2=None, logvar2=None): 6 | if mu2 is None: 7 | mu2 = torch.zeros_like(mu1) 8 | if logvar2 is None: 9 | logvar2 = torch.zeros_like(mu1) 10 | 11 | return 0.5 * ( 12 | logvar2 - logvar1 + ( 13 | torch.exp(logvar1) + (mu1 - mu2).pow(2) 14 | ) / torch.exp(logvar2) - 1) 15 | 16 | 17 | def nll_loss(x_hat, x): 18 | assert x_hat.dim() == x.dim() == 3 19 | assert x.size() == x_hat.size() 20 | return nn.BCEWithLogitsLoss(reduction='none')(x_hat, x) 21 | 22 | 23 | def dmm_loss(x, x_hat, mu1, logvar1, mu2, logvar2, kl_annealing_factor=1, mask=None): 24 | kl_raw = kl_div(mu1, logvar1, mu2, logvar2) 25 | nll_raw = nll_loss(x_hat, x) 26 | # feature-dimension reduced 27 | kl_fr = kl_raw.mean(dim=-1) 28 | nll_fr = nll_raw.mean(dim=-1) 29 | # masking 30 | if mask is not None: 31 | mask = mask.gt(0).view(-1) 32 | kl_m = kl_fr.view(-1).masked_select(mask).mean() 33 | nll_m = nll_fr.view(-1).masked_select(mask).mean() 34 | else: 35 | kl_m = kl_fr.view(-1).mean() 36 | nll_m = nll_fr.view(-1).mean() 37 | 38 | loss = kl_m * kl_annealing_factor + nll_m 39 | 40 | return kl_raw, nll_raw, \ 41 | kl_fr, nll_fr, \ 42 | kl_m, nll_m, \ 43 | loss 44 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch.distributions import Normal 4 | from model.loss import nll_loss, kl_div 5 | 6 | 7 | def accuracy(output, target): 8 | with torch.no_grad(): 9 | pred = torch.argmax(output, dim=1) 10 | assert pred.shape[0] == len(target) 11 | correct = 0 12 | correct += torch.sum(pred == target).item() 13 | return correct / len(target) 14 | 15 | 16 | def top_k_acc(output, target, k=3): 17 | with torch.no_grad(): 18 | pred = torch.topk(output, k, dim=1)[1] 19 | assert pred.shape[0] == len(target) 20 | correct = 0 21 | for i in range(k): 22 | correct += torch.sum(pred[:, i] == target).item() 23 | return correct / len(target) 24 | 25 | 26 | def nll_metric(output, target, mask): 27 | assert output.dim() == target.dim() == 3 28 | assert output.size() == target.size() 29 | assert mask.dim() == 2 30 | assert mask.size(1) == output.size(1) 31 | loss = nll_loss(output, target) # (batch_size, time_step, input_dim) 32 | loss = mask * loss.sum(dim=-1) # (batch_size, time_step) 33 | loss = loss.sum(dim=1, keepdim=True) # (batch_size, 1) 34 | return loss 35 | 36 | 37 | def kl_div_metric(output, target, mask): 38 | mu1, logvar1 = output 39 | mu2, logvar2 = target 40 | assert mu1.size() == mu2.size() 41 | assert logvar1.size() == logvar2.size() 42 | assert mu1.dim() == logvar1.dim() == 3 43 | assert mask.dim() == 2 44 | assert mask.size(1) == mu1.size(1) 45 | kl = kl_div(mu1, logvar1, mu2, logvar2) 46 | kl = mask * kl.sum(dim=-1) 47 | kl = kl.sum(dim=1, keepdim=True) 48 | return kl 49 | 50 | 51 | def bound_eval(output, target, mask): 52 | x_recon, mu_q, logvar_q = output 53 | x, mu_p, logvar_p = target 54 | # batch_size = x.size(0) 55 | neg_elbo = nll_metric(x_recon, x, mask) + \ 56 | kl_div_metric([mu_q, logvar_q], [mu_p, logvar_p], 
mask) 57 | # tsbn_bound_sum = elbo.div(mask.sum(dim=1, keepdim=True)).sum().div(batch_size) 58 | bound_sum = neg_elbo.sum().div(mask.sum()) 59 | return bound_sum 60 | 61 | 62 | def importance_sample(batch_idx, model, x, x_reversed, x_seq_lengths, mask, n_sample=500): 63 | sample_batch_size = 25 64 | n_batch = n_sample // sample_batch_size 65 | sample_left = n_sample % sample_batch_size 66 | if sample_left == 0: 67 | n_loop = n_batch 68 | else: 69 | n_loop = n_batch + 1 70 | 71 | ll_estimate = torch.zeros(n_loop).to(x.device) 72 | 73 | start_time = time.time() 74 | for i in range(n_loop): 75 | if i < n_batch: 76 | n_repeats = sample_batch_size 77 | else: 78 | n_repeats = sample_left 79 | 80 | x_tile = x.repeat_interleave(repeats=n_repeats, dim=0) 81 | x_reversed_tile = x_reversed.repeat_interleave(repeats=n_repeats, dim=0) 82 | x_seq_lengths_tile = x_seq_lengths.repeat_interleave(repeats=n_repeats, dim=0) 83 | mask_tile = mask.repeat_interleave(repeats=n_repeats, dim=0) 84 | 85 | x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \ 86 | model(x_tile, x_reversed_tile, x_seq_lengths_tile) 87 | 88 | q_dist = Normal(mu_q_seq, logvar_q_seq.exp().sqrt()) 89 | p_dist = Normal(mu_p_seq, logvar_p_seq.exp().sqrt()) 90 | log_qz = q_dist.log_prob(z_q_seq).sum(dim=-1) * mask_tile 91 | log_pz = p_dist.log_prob(z_q_seq).sum(dim=-1) * mask_tile 92 | log_px_z = -1 * nll_loss(x_recon, x_tile).sum(dim=-1) * mask_tile 93 | ll_estimate_ = log_px_z.sum(dim=1, keepdim=True) + \ 94 | log_pz.sum(dim=1, keepdim=True) - \ 95 | log_qz.sum(dim=1, keepdim=True) 96 | 97 | ll_estimate[i] = ll_estimate_.sum().div(mask.sum()) 98 | 99 | ll_estimate = ll_estimate.sum().div(n_sample) 100 | print("%s-th batch, importance sampling took %.4f seconds." % (batch_idx, time.time() - start_time)) 101 | 102 | return ll_estimate 103 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.modules import Emitter, Transition, Combiner, RnnEncoder 4 | import data_loader.polyphonic_dataloader as poly 5 | from data_loader.seq_util import seq_collate_fn, pack_padded_seq 6 | from base import BaseModel 7 | 8 | 9 | class DeepMarkovModel(BaseModel): 10 | 11 | def __init__(self, 12 | input_dim, 13 | z_dim, 14 | emission_dim, 15 | transition_dim, 16 | rnn_dim, 17 | rnn_type, 18 | rnn_layers, 19 | rnn_bidirection, 20 | orthogonal_init, 21 | use_embedding, 22 | gated_transition, 23 | train_init, 24 | mean_field=False, 25 | reverse_rnn_input=True, 26 | sample=True): 27 | super().__init__() 28 | self.input_dim = input_dim 29 | self.z_dim = z_dim 30 | self.emission_dim = emission_dim 31 | self.transition_dim = transition_dim 32 | self.rnn_dim = rnn_dim 33 | self.rnn_type = rnn_type 34 | self.rnn_layers = rnn_layers 35 | self.rnn_bidirection = rnn_bidirection 36 | self.use_embedding = use_embedding 37 | self.orthogonal_init = orthogonal_init 38 | self.gated_transition = gated_transition 39 | self.train_init = train_init 40 | self.mean_field = mean_field 41 | self.reverse_rnn_input = reverse_rnn_input 42 | self.sample = sample 43 | 44 | if use_embedding: 45 | self.embedding = nn.Linear(input_dim, rnn_dim) 46 | rnn_input_dim = rnn_dim 47 | else: 48 | rnn_input_dim = input_dim 49 | 50 | # instantiate components of DMM 51 | # generative model 52 | self.emitter = Emitter(z_dim, emission_dim, input_dim) 53 | self.transition = Transition(z_dim, 
transition_dim, 54 | gated=gated_transition, identity_init=True) 55 | # inference model 56 | self.combiner = Combiner(z_dim, rnn_dim, mean_field=mean_field) 57 | self.encoder = RnnEncoder(rnn_input_dim, rnn_dim, 58 | n_layer=rnn_layers, drop_rate=0.0, 59 | bd=rnn_bidirection, nonlin='relu', 60 | rnn_type=rnn_type, 61 | reverse_input=reverse_rnn_input) 62 | 63 | # initialize hidden states 64 | self.mu_p_0, self.logvar_p_0 = self.transition.init_z_0(trainable=train_init) 65 | self.z_q_0 = self.combiner.init_z_q_0(trainable=train_init) 66 | # h_0 = self.encoder.init_hidden(trainable=train_init) 67 | # if self.encoder.rnn_type == 'lstm': 68 | # self.h_0, self.c_0 = h_0 69 | # else: 70 | # self.h_0 = h_0 71 | 72 | def reparameterization(self, mu, logvar): 73 | if not self.sample: 74 | return mu 75 | std = torch.exp(0.5 * logvar) 76 | eps = torch.randn_like(std) 77 | return mu + eps * std 78 | 79 | def forward(self, x, x_reversed, x_seq_lengths): 80 | T_max = x.size(1) 81 | batch_size = x.size(0) 82 | 83 | if self.encoder.reverse_input: 84 | input = x_reversed 85 | else: 86 | input = x 87 | 88 | if self.use_embedding: 89 | input = self.embedding(input) 90 | 91 | input = pack_padded_seq(input, x_seq_lengths) 92 | h_rnn = self.encoder(input, x_seq_lengths) 93 | z_q_0 = self.z_q_0.expand(batch_size, self.z_dim) 94 | mu_p_0 = self.mu_p_0.expand(batch_size, 1, self.z_dim) 95 | logvar_p_0 = self.logvar_p_0.expand(batch_size, 1, self.z_dim) 96 | z_prev = z_q_0 97 | 98 | x_recon = torch.zeros([batch_size, T_max, self.input_dim]).to(x.device) 99 | mu_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device) 100 | logvar_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device) 101 | mu_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device) 102 | logvar_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device) 103 | z_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device) 104 | z_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device) 105 | for t in range(T_max): 106 | # q(z_t | z_{t-1}, x_{t:T}) 107 | mu_q, logvar_q = self.combiner(h_rnn[:, t, :], z_prev, 108 | rnn_bidirection=self.rnn_bidirection) 109 | zt_q = self.reparameterization(mu_q, logvar_q) 110 | z_prev = zt_q 111 | # p(z_t | z_{t-1}) 112 | mu_p, logvar_p = self.transition(z_prev) 113 | zt_p = self.reparameterization(mu_p, logvar_p) 114 | 115 | xt_recon = self.emitter(zt_q).contiguous() 116 | 117 | mu_q_seq[:, t, :] = mu_q 118 | logvar_q_seq[:, t, :] = logvar_q 119 | z_q_seq[:, t, :] = zt_q 120 | mu_p_seq[:, t, :] = mu_p 121 | logvar_p_seq[:, t, :] = logvar_p 122 | z_p_seq[:, t, :] = zt_p 123 | x_recon[:, t, :] = xt_recon 124 | 125 | mu_p_seq = torch.cat([mu_p_0, mu_p_seq[:, :-1, :]], dim=1) 126 | logvar_p_seq = torch.cat([logvar_p_0, logvar_p_seq[:, :-1, :]], dim=1) 127 | z_p_0 = self.reparameterization(mu_p_0, logvar_p_0) 128 | z_p_seq = torch.cat([z_p_0, z_p_seq[:, :-1, :]], dim=1) 129 | 130 | return x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq 131 | 132 | def generate(self, batch_size, seq_len): 133 | mu_p = self.mu_p_0.expand(batch_size, self.z_dim) 134 | logvar_p = self.logvar_p_0.expand(batch_size, self.z_dim) 135 | z_p_seq = torch.zeros([batch_size, seq_len, self.z_dim]).to(mu_p.device) 136 | mu_p_seq = torch.zeros([batch_size, seq_len, self.z_dim]).to(mu_p.device) 137 | logvar_p_seq = torch.zeros([batch_size, seq_len, self.z_dim]).to(mu_p.device) 138 | output_seq = torch.zeros([batch_size, seq_len, self.input_dim]).to(mu_p.device) 139 | for t in 
range(seq_len): 140 | mu_p_seq[:, t, :] = mu_p 141 | logvar_p_seq[:, t, :] = logvar_p 142 | z_p = self.reparameterization(mu_p, logvar_p) 143 | xt = self.emitter(z_p) 144 | mu_p, logvar_p = self.transition(z_p) 145 | 146 | output_seq[:, t, :] = xt 147 | z_p_seq[:, t, :] = z_p 148 | return output_seq, z_p_seq, mu_p_seq, logvar_p_seq 149 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as weight_init 4 | from data_loader.seq_util import pad_and_reverse 5 | 6 | """ 7 | Generative modules 8 | """ 9 | 10 | 11 | class Emitter(nn.Module): 12 | """ 13 | Parameterize the Bernoulli observation likelihood `p(x_t | z_t)` 14 | 15 | Parameters 16 | ---------- 17 | z_dim: int 18 | Dim. of latent variables 19 | emission_dim: int 20 | Dim. of emission hidden units 21 | input_dim: int 22 | Dim. of inputs 23 | 24 | Returns 25 | ------- 26 | A valid probability that parameterizes the 27 | Bernoulli distribution `p(x_t | z_t)` 28 | """ 29 | def __init__(self, z_dim, emission_dim, input_dim): 30 | super().__init__() 31 | self.z_dim = z_dim 32 | self.emission_dim = emission_dim 33 | self.input_dim = input_dim 34 | 35 | self.lin1 = nn.Linear(z_dim, emission_dim) 36 | self.lin2 = nn.Linear(emission_dim, emission_dim) 37 | self.lin3 = nn.Linear(emission_dim, input_dim) 38 | self.act = nn.ReLU() 39 | # self.out = nn.Sigmoid() 40 | 41 | def forward(self, z_t): 42 | h1 = self.act(self.lin1(z_t)) 43 | h2 = self.act(self.lin2(h1)) 44 | # return self.out(self.lin3(h2)) 45 | return self.lin3(h2) 46 | 47 | 48 | class Transition(nn.Module): 49 | """ 50 | Parameterize the diagonal Gaussian latent transition probability 51 | `p(z_t | z_{t-1})` 52 | 53 | Parameters 54 | ---------- 55 | z_dim: int 56 | Dim. of latent variables 57 | transition_dim: int 58 | Dim. 
of transition hidden units 59 | gated: bool 60 | Use the gated mechanism to consider both linearity and non-linearity 61 | identity_init: bool 62 | Initialize the linearity transform as an identity matrix; 63 | ignored if `gated == False` 64 | 65 | Returns 66 | ------- 67 | mu: tensor (b, z_dim) 68 | Mean that parameterizes the Gaussian 69 | logvar: tensor (b, z_dim) 70 | Log-variance that parameterizes the Gaussian 71 | """ 72 | def __init__(self, z_dim, transition_dim, gated=True, identity_init=True): 73 | super().__init__() 74 | self.z_dim = z_dim 75 | self.transition_dim = transition_dim 76 | self.gated = gated 77 | self.identity_init = identity_init 78 | 79 | # compute the corresponding mean parameterizing the Gaussian 80 | self.lin1p = nn.Linear(z_dim, transition_dim) 81 | self.lin2p = nn.Linear(transition_dim, z_dim) 82 | 83 | if gated: 84 | # compute the gain (gate) of non-linearity 85 | self.lin1 = nn.Linear(z_dim, transition_dim) 86 | self.lin2 = nn.Linear(transition_dim, z_dim) 87 | # compute the linearity part 88 | self.lin_n = nn.Linear(z_dim, z_dim) 89 | 90 | # compute the logvar 91 | self.lin_v = nn.Linear(z_dim, z_dim) 92 | 93 | if gated and identity_init: 94 | self.lin_n.weight.data = torch.eye(z_dim) 95 | self.lin_n.bias.data = torch.zeros(z_dim) 96 | 97 | self.act_weight = nn.Sigmoid() 98 | self.act = nn.ReLU() 99 | 100 | def init_z_0(self, trainable=True): 101 | return nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable), \ 102 | nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable) 103 | 104 | def forward(self, z_t_1): 105 | _mu = self.act(self.lin1p(z_t_1)) 106 | mu = self.lin2p(_mu) 107 | logvar = self.lin_v(self.act(mu)) 108 | 109 | if self.gated: 110 | _gain = self.act(self.lin1(z_t_1)) 111 | gain = self.act_weight(self.lin2(_gain)) 112 | mu = (1 - gain) * self.lin_n(z_t_1) + gain * mu 113 | 114 | return mu, logvar 115 | 116 | 117 | """ 118 | Inference modules 119 | """ 120 | 121 | 122 | class Combiner(nn.Module): 123 | """ 124 | Parameterize variational distribution `q(z_t | z_{t-1}, x_{t:T})` 125 | a diagonal Gaussian distribution 126 | 127 | Parameters 128 | ---------- 129 | z_dim: int 130 | Dim. of latent variables 131 | rnn_dim: int 132 | Dim. 
of RNN hidden states 133 | 134 | Returns 135 | ------- 136 | mu: tensor (b, z_dim) 137 | Mean that parameterizes the variational Gaussian distribution 138 | logvar: tensor (b, z_dim) 139 | Log-var that parameterizes the variational Gaussian distribution 140 | """ 141 | def __init__(self, z_dim, rnn_dim, mean_field=False): 142 | super().__init__() 143 | self.z_dim = z_dim 144 | self.rnn_dim = rnn_dim 145 | self.mean_field = mean_field 146 | 147 | if not mean_field: 148 | self.lin1 = nn.Linear(z_dim, rnn_dim) 149 | self.act = nn.Tanh() 150 | 151 | self.lin2 = nn.Linear(rnn_dim, z_dim) 152 | self.lin_v = nn.Linear(rnn_dim, z_dim) 153 | 154 | def init_z_q_0(self, trainable=True): 155 | return nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable) 156 | 157 | def forward(self, h_rnn, z_t_1=None, rnn_bidirection=False): 158 | """ 159 | z_t_1: tensor (b, z_dim) 160 | h_rnn: tensor (b, rnn_dim) 161 | """ 162 | if not self.mean_field: 163 | assert z_t_1 is not None 164 | h_comb_ = self.act(self.lin1(z_t_1)) 165 | if rnn_bidirection: 166 | h_comb = (1.0 / 3) * (h_comb_ + h_rnn[:, :self.rnn_dim] + h_rnn[:, self.rnn_dim:]) 167 | else: 168 | h_comb = 0.5 * (h_comb_ + h_rnn) 169 | else: 170 | h_comb = h_rnn 171 | mu = self.lin2(h_comb) 172 | logvar = self.lin_v(h_comb) 173 | 174 | return mu, logvar 175 | 176 | 177 | class RnnEncoder(nn.Module): 178 | """ 179 | RNN encoder that outputs hidden states h_t using x_{t:T} 180 | 181 | Parameters 182 | ---------- 183 | input_dim: int 184 | Dim. of inputs 185 | rnn_dim: int 186 | Dim. of RNN hidden states 187 | n_layer: int 188 | Number of layers of RNN 189 | drop_rate: float [0.0, 1.0] 190 | RNN dropout rate between layers 191 | bd: bool 192 | Use bi-directional RNN or not 193 | 194 | Returns 195 | ------- 196 | h_rnn: tensor (b, T_max, rnn_dim * n_direction) 197 | RNN hidden states at every time-step 198 | """ 199 | def __init__(self, input_dim, rnn_dim, n_layer=1, drop_rate=0.0, bd=False, 200 | nonlin='relu', rnn_type='rnn', orthogonal_init=False, 201 | reverse_input=True): 202 | super().__init__() 203 | self.n_direction = 1 if not bd else 2 204 | self.input_dim = input_dim 205 | self.rnn_dim = rnn_dim 206 | self.n_layer = n_layer 207 | self.drop_rate = drop_rate 208 | self.bd = bd 209 | self.nonlin = nonlin 210 | self.reverse_input = reverse_input 211 | 212 | if not isinstance(rnn_type, str): 213 | raise ValueError("`rnn_type` should be type str.") 214 | self.rnn_type = rnn_type 215 | if rnn_type == 'rnn': 216 | self.rnn = nn.RNN(input_size=input_dim, hidden_size=rnn_dim, 217 | nonlinearity=nonlin, batch_first=True, 218 | bidirectional=bd, num_layers=n_layer, 219 | dropout=drop_rate) 220 | elif rnn_type == 'gru': 221 | self.rnn = nn.GRU(input_size=input_dim, hidden_size=rnn_dim, 222 | batch_first=True, 223 | bidirectional=bd, num_layers=n_layer, 224 | dropout=drop_rate) 225 | elif rnn_type == 'lstm': 226 | self.rnn = nn.LSTM(input_size=input_dim, hidden_size=rnn_dim, 227 | batch_first=True, 228 | bidirectional=bd, num_layers=n_layer, 229 | dropout=drop_rate) 230 | else: 231 | raise ValueError("`rnn_type` must instead be ['rnn', 'gru', 'lstm'] %s" 232 | % rnn_type) 233 | 234 | if orthogonal_init: 235 | self.init_weights() 236 | 237 | def init_weights(self): 238 | for w in self.rnn.parameters(): 239 | if w.dim() > 1: 240 | weight_init.orthogonal_(w) 241 | 242 | def calculate_effect_dim(self): 243 | return self.rnn_dim * self.n_direction 244 | 245 | def init_hidden(self, trainable=True): 246 | if self.rnn_type == 'lstm': 247 | h0 = 
nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.rnn_dim), requires_grad=trainable) 248 | c0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.rnn_dim), requires_grad=trainable) 249 | return h0, c0 250 | else: 251 | h0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.rnn_dim), requires_grad=trainable) 252 | return h0 253 | 254 | def forward(self, x, seq_lengths): 255 | """ 256 | x: pytorch packed object 257 | input packed data; this can be obtained from 258 | `util.get_mini_batch()` 259 | h0: tensor (n_layer * n_direction, b, rnn_dim) 260 | seq_lengths: tensor (b, ) 261 | """ 262 | # if self.rnn_type == 'lstm': 263 | # _h_rnn, _ = self.rnn(x, (h0, c0)) 264 | # else: 265 | # _h_rnn, _ = self.rnn(x, h0) 266 | _h_rnn, _ = self.rnn(x) 267 | if self.reverse_input: 268 | h_rnn = pad_and_reverse(_h_rnn, seq_lengths) 269 | else: 270 | h_rnn, _ = nn.utils.rnn.pad_packed_sequence(_h_rnn, batch_first=True) 271 | return h_rnn 272 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, config, resume=None, modification=None, run_id=None): 13 | """ 14 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 15 | and logging module. 16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 17 | :param resume: String, path to the checkpoint being loaded. 18 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 19 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default 20 | """ 21 | # load config file and apply modification 22 | self._config = _update_config(config, modification) 23 | self.resume = resume 24 | 25 | # set save_dir where trained model and log will be saved. 26 | save_dir = Path(self.config['trainer']['save_dir']) 27 | 28 | exper_name = self.config['name'] 29 | if run_id is None: # use timestamp as default run-id 30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 31 | self._save_dir = save_dir / 'models' / exper_name / run_id 32 | self._log_dir = save_dir / 'log' / exper_name / run_id 33 | 34 | # make directory for saving checkpoints and log. 35 | exist_ok = run_id == '' 36 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 37 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 38 | 39 | # save updated config file to the checkpoint dir 40 | write_json(self.config, self.save_dir / 'config.json') 41 | 42 | # configure logging module 43 | setup_logging(self.log_dir) 44 | self.log_levels = { 45 | 0: logging.WARNING, 46 | 1: logging.INFO, 47 | 2: logging.DEBUG 48 | } 49 | 50 | @classmethod 51 | def from_args(cls, args, options=''): 52 | """ 53 | Initialize this class from some cli arguments. Used in train, test. 
54 | """ 55 | for opt in options: 56 | args.add_argument(*opt.flags, default=None, type=opt.type) 57 | if not isinstance(args, tuple): 58 | args = args.parse_args() 59 | 60 | if args.device is not None: 61 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 62 | if args.resume is not None: 63 | resume = Path(args.resume) 64 | cfg_fname = resume.parent / 'config.json' 65 | else: 66 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 67 | assert args.config is not None, msg_no_cfg 68 | resume = None 69 | cfg_fname = Path(args.config) 70 | 71 | config = read_json(cfg_fname) 72 | if args.config and resume: 73 | # update new config for fine-tuning 74 | config.update(read_json(args.config)) 75 | 76 | # parse custom cli options into dictionary 77 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options} 78 | return cls(config, resume, modification, run_id=args.identifier) 79 | 80 | def init_obj(self, name, module, *args, **kwargs): 81 | """ 82 | Finds a function handle with the name given as 'type' in config, and returns the 83 | instance initialized with corresponding arguments given. 84 | 85 | `object = config.init_obj('name', module, a, b=1)` 86 | is equivalent to 87 | `object = module.name(a, b=1)` 88 | """ 89 | module_name = self[name]['type'] 90 | module_args = dict(self[name]['args']) 91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 92 | module_args.update(kwargs) 93 | return getattr(module, module_name)(*args, **module_args) 94 | 95 | def init_ftn(self, name, module, *args, **kwargs): 96 | """ 97 | Finds a function handle with the name given as 'type' in config, and returns the 98 | function with given arguments fixed with functools.partial. 99 | 100 | `function = config.init_ftn('name', module, a, b=1)` 101 | is equivalent to 102 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 103 | """ 104 | module_name = self[name]['type'] 105 | module_args = dict(self[name]['args']) 106 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 107 | module_args.update(kwargs) 108 | return partial(getattr(module, module_name), *args, **module_args) 109 | 110 | def __getitem__(self, name): 111 | """Access items like ordinary dict.""" 112 | return self.config[name] 113 | 114 | def get_logger(self, name, verbosity=2): 115 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, self.log_levels.keys()) 116 | assert verbosity in self.log_levels, msg_verbosity 117 | logger = logging.getLogger(name) 118 | logger.setLevel(self.log_levels[verbosity]) 119 | return logger 120 | 121 | # setting read-only attributes 122 | @property 123 | def config(self): 124 | return self._config 125 | 126 | @property 127 | def save_dir(self): 128 | return self._save_dir 129 | 130 | @property 131 | def log_dir(self): 132 | return self._log_dir 133 | 134 | # helper functions to update config dict with custom cli options 135 | def _update_config(config, modification): 136 | if modification is None: 137 | return config 138 | 139 | for k, v in modification.items(): 140 | if v is not None: 141 | _set_by_path(config, k, v) 142 | return config 143 | 144 | def _get_opt_name(flags): 145 | for flg in flags: 146 | if flg.startswith('--'): 147 | return flg.replace('--', '') 148 | return flags[0].replace('--', '') 149 | 150 | def _set_by_path(tree, keys, value): 151 | """Set a value in a nested object in tree by sequence of keys.""" 152 | keys = keys.split(';') 153 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 154 | 155 | def _get_by_path(tree, keys): 156 | """Access a nested object in tree by sequence of keys.""" 157 | return reduce(getitem, keys, tree) 158 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import data_loader.data_loaders as module_data 5 | import model.loss as module_loss 6 | import model.metric as module_metric 7 | import model.model as module_arch 8 | from parse_config import ConfigParser 9 | 10 | 11 | def main(config): 12 | logger = config.get_logger('test') 13 | 14 | # setup data_loader instances 15 | data_loader = getattr(module_data, config['data_loader']['type'])( 16 | config['data_loader']['args']['data_dir'], 17 | batch_size=512, 18 | shuffle=False, 19 | validation_split=0.0, 20 | training=False, 21 | num_workers=2 22 | ) 23 | 24 | # build model architecture 25 | model = config.init_obj('arch', module_arch) 26 | logger.info(model) 27 | 28 | # get function handles of loss and metrics 29 | loss_fn = getattr(module_loss, config['loss']) 30 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 31 | 32 | logger.info('Loading checkpoint: {} ...'.format(config.resume)) 33 | checkpoint = torch.load(config.resume) 34 | state_dict = checkpoint['state_dict'] 35 | if config['n_gpu'] > 1: 36 | model = torch.nn.DataParallel(model) 37 | model.load_state_dict(state_dict) 38 | 39 | # prepare model for testing 40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | model = model.to(device) 42 | model.eval() 43 | 44 | total_loss = 0.0 45 | total_metrics = torch.zeros(len(metric_fns)) 46 | 47 | with torch.no_grad(): 48 | for i, (data, target) in enumerate(tqdm(data_loader)): 49 | data, target = data.to(device), target.to(device) 50 | output = model(data) 51 | 52 | # 53 | # save sample images, or do something with output here 54 | # 55 | 56 | # computing loss, metrics on test set 57 | loss = loss_fn(output, target) 58 | batch_size = data.shape[0] 59 | total_loss += loss.item() * batch_size 60 | for i, metric in enumerate(metric_fns): 61 | total_metrics[i] += metric(output, target) * batch_size 62 | 63 | n_samples = len(data_loader.sampler) 64 | log = {'loss': total_loss / n_samples} 65 | log.update({ 66 | met.__name__: 
total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns) 67 | }) 68 | logger.info(log) 69 | 70 | 71 | if __name__ == '__main__': 72 | args = argparse.ArgumentParser(description='PyTorch Template') 73 | args.add_argument('-c', '--config', default=None, type=str, 74 | help='config file path (default: None)') 75 | args.add_argument('-r', '--resume', default=None, type=str, 76 | help='path to latest checkpoint (default: None)') 77 | args.add_argument('-d', '--device', default=None, type=str, 78 | help='indices of GPUs to enable (default: all)') 79 | 80 | config = ConfigParser.from_args(args) 81 | main(config) 82 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | import collections 4 | import torch 5 | import numpy as np 6 | import data_loader.data_loaders as module_data 7 | import model.loss as module_loss 8 | import model.metric as module_metric 9 | import model.model as module_arch 10 | from parse_config import ConfigParser 11 | from trainer import Trainer 12 | 13 | 14 | # fix random seeds for reproducibility 15 | SEED = 123 16 | torch.manual_seed(SEED) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | np.random.seed(SEED) 20 | 21 | 22 | def main(config): 23 | logger = config.get_logger('train') 24 | 25 | # setup data_loader instances 26 | data_loader = config.init_obj('data_loader_train', module_data) 27 | try: 28 | valid_data_loader = config.init_obj('data_loader_valid', module_data) 29 | except Exception: 30 | warnings.warn("Validation dataloader not given.") 31 | valid_data_loader = None 32 | try: 33 | test_data_loader = config.init_obj('data_loader_test', module_data) 34 | except Exception: 35 | warnings.warn("Test dataloader not given.") 36 | test_data_loader = None 37 | 38 | # build model architecture, then print to console 39 | model = config.init_obj('arch', module_arch) 40 | logger.info(model) 41 | 42 | # get function handles of loss and metrics 43 | criterion = getattr(module_loss, config['loss']) 44 | try: 45 | metrics = [getattr(module_metric, met) for met in config['metrics']] 46 | # ------------------------------------------------- 47 | # add flexibility to allow no metric in config.json 48 | except Exception: 49 | warnings.warn("No metrics are configured.") 50 | metrics = None 51 | # ------------------------------------------------- 52 | 53 | # build optimizer, learning rate scheduler. 
delete every lines containing lr_scheduler for disabling scheduler 54 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 55 | optimizer = config.init_obj('optimizer', torch.optim, trainable_params) 56 | 57 | # ------------------------------------------------- 58 | # add flexibility to allow no lr_scheduler in config.json 59 | try: 60 | lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) 61 | except Exception: 62 | warnings.warn("No learning scheduler is configured.") 63 | lr_scheduler = None 64 | # ------------------------------------------------- 65 | 66 | trainer = Trainer(model, criterion, metrics, optimizer, 67 | config=config, 68 | data_loader=data_loader, 69 | valid_data_loader=valid_data_loader, 70 | test_data_loader=test_data_loader, 71 | lr_scheduler=lr_scheduler, 72 | overfit_single_batch=config['trainer']['overfit_single_batch']) 73 | 74 | trainer.train() 75 | 76 | 77 | if __name__ == '__main__': 78 | args = argparse.ArgumentParser(description='PyTorch Template') 79 | args.add_argument('-c', '--config', default=None, type=str, 80 | help='config file path (default: None)') 81 | args.add_argument('-r', '--resume', default=None, type=str, 82 | help='path to latest checkpoint (default: None)') 83 | args.add_argument('-d', '--device', default=None, type=str, 84 | help='indices of GPUs to enable (default: all)') 85 | args.add_argument('-i', '--identifier', default=None, type=str, 86 | help='unique identifier of the experiment (default: None)') 87 | 88 | # custom cli options to modify configuration from default values given in json file. 89 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 90 | options = [ 91 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 92 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 93 | ] 94 | config = ConfigParser.from_args(args, options) 95 | main(config) 96 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from torchvision.utils import make_grid 5 | from base import BaseTrainer 6 | from utils import inf_loop, MetricTracker 7 | 8 | 9 | class Trainer(BaseTrainer): 10 | """ 11 | Trainer class 12 | """ 13 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader, 14 | valid_data_loader=None, test_data_loader=None, 15 | lr_scheduler=None, len_epoch=None, overfit_single_batch=False): 16 | super().__init__(model, criterion, metric_ftns, optimizer, config) 17 | self.config = config 18 | self.data_loader = data_loader 19 | if len_epoch is None: 20 | # epoch-based training 21 | self.len_epoch = len(self.data_loader) 22 | else: 23 | # iteration-based training 24 | self.data_loader = inf_loop(data_loader) 25 | self.len_epoch = len_epoch 26 | self.valid_data_loader = valid_data_loader if not overfit_single_batch else None 27 | self.test_data_loader = test_data_loader if not overfit_single_batch else None 28 | self.do_validation = self.valid_data_loader is not None 29 | self.do_test = self.test_data_loader is not None 30 | self.lr_scheduler = lr_scheduler 31 | self.log_step = 
int(np.sqrt(data_loader.batch_size)) 32 | self.overfit_single_batch = overfit_single_batch 33 | 34 | # ------------------------------------------------- 35 | # add flexibility to allow no metric in config.json 36 | self.log_loss = ['loss', 'nll', 'kl'] 37 | if self.metric_ftns is None: 38 | self.train_metrics = MetricTracker(*self.log_loss, writer=self.writer) 39 | self.valid_metrics = MetricTracker(*self.log_loss, writer=self.writer) 40 | # ------------------------------------------------- 41 | else: 42 | self.train_metrics = MetricTracker(*self.log_loss, *[m.__name__ for m in self.metric_ftns], writer=self.writer) 43 | self.valid_metrics = MetricTracker(*self.log_loss, *[m.__name__ for m in self.metric_ftns], writer=self.writer) 44 | self.test_metrics = MetricTracker(*[m.__name__ for m in self.metric_ftns], writer=self.writer) 45 | 46 | def _train_epoch(self, epoch): 47 | """ 48 | Training logic for an epoch 49 | 50 | :param epoch: Integer, current training epoch. 51 | :return: A log that contains average loss and metric in this epoch. 52 | """ 53 | self.model.train() 54 | self.train_metrics.reset() 55 | 56 | # ---------------- 57 | # add logging grad 58 | dict_grad = {} 59 | for name, p in self.model.named_parameters(): 60 | if p.requires_grad and 'bias' not in name: 61 | dict_grad[name] = np.zeros(self.len_epoch) 62 | # ---------------- 63 | 64 | for batch_idx, batch in enumerate(self.data_loader): 65 | x, x_reversed, x_mask, x_seq_lengths = batch 66 | 67 | x = x.to(self.device) 68 | x_reversed = x_reversed.to(self.device) 69 | x_mask = x_mask.to(self.device) 70 | x_seq_lengths = x_seq_lengths.to(self.device) 71 | 72 | self.optimizer.zero_grad() 73 | x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \ 74 | self.model(x, x_reversed, x_seq_lengths) 75 | kl_annealing_factor = \ 76 | determine_annealing_factor(self.config['trainer']['min_anneal_factor'], 77 | self.config['trainer']['anneal_update'], 78 | epoch - 1, self.len_epoch, batch_idx) 79 | kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, loss = \ 80 | self.criterion(x, x_recon, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq, kl_annealing_factor, x_mask) 81 | loss.backward() 82 | 83 | # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10) 84 | # ------------ 85 | # accumulate gradients that are to be logged later after epoch ends 86 | for name, p in self.model.named_parameters(): 87 | if p.requires_grad and 'bias' not in name: 88 | val = 0 if p.grad is None else p.grad.abs().mean() 89 | dict_grad[name][batch_idx] = val 90 | # ------------ 91 | 92 | self.optimizer.step() 93 | 94 | for l_i, l_i_val in zip(self.log_loss, [loss, nll_m, kl_m]): 95 | self.train_metrics.update(l_i, l_i_val.item()) 96 | if self.metric_ftns is not None: 97 | for met in self.metric_ftns: 98 | if met.__name__ == 'bound_eval': 99 | self.train_metrics.update(met.__name__, 100 | met([x_recon, mu_q_seq, logvar_q_seq], 101 | [x, mu_p_seq, logvar_p_seq], mask=x_mask)) 102 | 103 | if batch_idx % self.log_step == 0: 104 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( 105 | epoch, 106 | self._progress(batch_idx), 107 | loss.item())) 108 | 109 | if batch_idx == self.len_epoch or self.overfit_single_batch: 110 | break 111 | 112 | # --------------------------------------------------- 113 | if self.writer is not None: 114 | self.writer.set_step(epoch, 'train') 115 | # log losses 116 | for l_i in self.log_loss: 117 | self.train_metrics.write_to_logger(l_i) 118 | # log metrics 119 | if self.metric_ftns is not None: 120 | if 
met.__name__ == 'bound_eval': 121 | self.train_metrics.write_to_logger(met.__name__) 122 | # log gradients 123 | for name, p in dict_grad.items(): 124 | self.writer.add_histogram(name + '/grad', p, bins='auto') 125 | # log parameters 126 | for name, p in self.model.named_parameters(): 127 | self.writer.add_histogram(name, p, bins='auto') 128 | # log kl annealing factors 129 | self.writer.add_scalar('anneal_factor', kl_annealing_factor) 130 | # --------------------------------------------------- 131 | 132 | if epoch % 50 == 0: 133 | fig = create_reconstruction_figure(x, torch.sigmoid(x_recon)) 134 | # debug_fig = create_debug_figure(x, x_reversed, x_mask) 135 | # debug_fig_loss = create_debug_loss_figure(kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, x_mask) 136 | self.writer.set_step(epoch, 'train') 137 | self.writer.add_figure('reconstruction', fig) 138 | # self.writer.add_figure('debug', debug_fig) 139 | # self.writer.add_figure('debug_loss', debug_fig_loss) 140 | 141 | log = self.train_metrics.result() 142 | 143 | if self.do_validation: 144 | val_log = self._valid_epoch(epoch) 145 | log.update(**{'val_' + k: v for k, v in val_log.items()}) 146 | 147 | if self.do_test and epoch % 50 == 0: 148 | test_log = self._test_epoch(epoch) 149 | log.update(**{'test_' + k: v for k, v in test_log.items()}) 150 | 151 | if self.lr_scheduler is not None: 152 | self.lr_scheduler.step() 153 | return log 154 | 155 | def _valid_epoch(self, epoch): 156 | """ 157 | Validate after training an epoch 158 | 159 | :param epoch: Integer, current training epoch. 160 | :return: A log that contains information about validation 161 | """ 162 | self.model.eval() 163 | self.valid_metrics.reset() 164 | with torch.no_grad(): 165 | for batch_idx, batch in enumerate(self.valid_data_loader): 166 | x, x_reversed, x_mask, x_seq_lengths = batch 167 | 168 | x = x.to(self.device) 169 | x_reversed = x_reversed.to(self.device) 170 | x_mask = x_mask.to(self.device) 171 | x_seq_lengths = x_seq_lengths.to(self.device) 172 | 173 | x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \ 174 | self.model(x, x_reversed, x_seq_lengths) 175 | kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, loss = \ 176 | self.criterion(x, x_recon, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq, 1, x_mask) 177 | 178 | for l_i, l_i_val in zip(self.log_loss, [loss, nll_m, kl_m]): 179 | self.valid_metrics.update(l_i, l_i_val.item()) 180 | if self.metric_ftns is not None: 181 | for met in self.metric_ftns: 182 | if met.__name__ == 'bound_eval': 183 | self.valid_metrics.update(met.__name__, 184 | met([x_recon, mu_q_seq, logvar_q_seq], 185 | [x, mu_p_seq, logvar_p_seq], mask=x_mask)) 186 | 187 | # --------------------------------------------------- 188 | if self.writer is not None: 189 | self.writer.set_step(epoch, 'valid') 190 | for l_i in self.log_loss: 191 | self.valid_metrics.write_to_logger(l_i) 192 | if self.metric_ftns is not None: 193 | for met in self.metric_ftns: 194 | if met.__name__ == 'bound_eval': 195 | self.valid_metrics.write_to_logger(met.__name__) 196 | # --------------------------------------------------- 197 | 198 | if epoch % 10 == 0: 199 | x_recon = torch.nn.functional.sigmoid(x_recon.view(x.size(0), x.size(1), -1)) 200 | fig = create_reconstruction_figure(x, x_recon) 201 | # debug_fig = create_debug_figure(x, x_reversed_unpack, x_mask) 202 | # debug_fig_loss = create_debug_loss_figure(kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, x_mask) 203 | self.writer.set_step(epoch, 'valid') 204 | self.writer.add_figure('reconstruction', 
fig) 205 | # self.writer.add_figure('debug', debug_fig) 206 | # self.writer.add_figure('debug_loss', debug_fig_loss) 207 | 208 | return self.valid_metrics.result() 209 | 210 | def _test_epoch(self, epoch): 211 | self.model.eval() 212 | self.test_metrics.reset() 213 | with torch.no_grad(): 214 | for batch_idx, batch in enumerate(self.test_data_loader): 215 | x, x_reversed, x_mask, x_seq_lengths = batch 216 | 217 | x = x.to(self.device) 218 | x_reversed = x_reversed.to(self.device) 219 | x_mask = x_mask.to(self.device) 220 | x_seq_lengths = x_seq_lengths.to(self.device) 221 | 222 | x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \ 223 | self.model(x, x_reversed, x_seq_lengths) 224 | 225 | if self.metric_ftns is not None: 226 | for met in self.metric_ftns: 227 | if met.__name__ == 'bound_eval': 228 | self.test_metrics.update(met.__name__, 229 | met([x_recon, mu_q_seq, logvar_q_seq], 230 | [x, mu_p_seq, logvar_p_seq], mask=x_mask)) 231 | if met.__name__ == 'importance_sample': 232 | self.test_metrics.update(met.__name__, 233 | met(batch_idx, self.model, x, x_reversed, x_seq_lengths, x_mask, n_sample=500)) 234 | # --------------------------------------------------- 235 | if self.writer is not None: 236 | self.writer.set_step(epoch, 'test') 237 | if self.metric_ftns is not None: 238 | for met in self.metric_ftns: 239 | self.test_metrics.write_to_logger(met.__name__) 240 | 241 | n_sample = 3 242 | output_seq, z_p_seq, mu_p_seq, logvar_p_seq = self.model.generate(n_sample, 100) 243 | output_seq = torch.sigmoid(output_seq) 244 | plt.close() 245 | fig, ax = plt.subplots(n_sample, 1, figsize=(10, n_sample * 10)) 246 | for i in range(n_sample): 247 | ax[i].imshow(output_seq[i].T.cpu().detach().numpy(), origin='lower') 248 | self.writer.add_figure('generation', fig) 249 | # --------------------------------------------------- 250 | return self.test_metrics.result() 251 | 252 | def _progress(self, batch_idx): 253 | base = '[{}/{} ({:.0f}%)]' 254 | if hasattr(self.data_loader, 'n_samples'): 255 | current = batch_idx * self.data_loader.batch_size 256 | total = self.data_loader.n_samples 257 | else: 258 | current = batch_idx 259 | total = self.len_epoch 260 | return base.format(current, total, 100.0 * current / total) 261 | 262 | 263 | def determine_annealing_factor(min_anneal_factor, 264 | anneal_update, 265 | epoch, n_batch, batch_idx): 266 | n_updates = epoch * n_batch + batch_idx 267 | 268 | if anneal_update > 0 and n_updates < anneal_update: 269 | anneal_factor = min_anneal_factor + \ 270 | (1.0 - min_anneal_factor) * ( 271 | (n_updates / anneal_update) 272 | ) 273 | else: 274 | anneal_factor = 1.0 275 | return anneal_factor 276 | 277 | 278 | def create_reconstruction_figure(x, x_recon, sample=True): 279 | plt.close() 280 | if sample: 281 | idx = np.random.choice(x.shape[0], 1)[0] 282 | else: 283 | idx = 0 284 | x = x[idx].cpu().detach().numpy() 285 | x_recon = x_recon[idx].cpu().detach().numpy() 286 | fig, ax = plt.subplots(2, 1, sharex=True, figsize=(10, 20)) 287 | ax[0].imshow(x.T, origin='lower') 288 | ax[1].imshow(x_recon.T, origin='lower') 289 | return fig 290 | 291 | 292 | def create_debug_figure(x, x_reversed_unpack, x_mask, sample=True): 293 | plt.close() 294 | if sample: 295 | idx = np.random.choice(x.shape[0], 1)[0] 296 | else: 297 | idx = 0 298 | x = x[idx].cpu().detach().numpy() 299 | x_reversed_unpack = x_reversed_unpack[idx].cpu().detach().numpy() 300 | x_mask = x_mask[idx].cpu().detach().numpy() 301 | fig, ax = plt.subplots(3, 1, sharex=True, figsize=(10, 
30)) 302 | ax[0].imshow(x.T, origin='lower') 303 | ax[1].imshow(x_reversed_unpack.T, origin='lower') 304 | ax[2].imshow(np.tile(x_mask, (x.shape[0], 1)), origin='lower') 305 | return fig 306 | 307 | 308 | def create_debug_loss_figure(kl_raw, nll_raw, 309 | kl_fr, nll_fr, 310 | kl_m, nll_m, 311 | mask, sample=True): 312 | plt.close() 313 | if sample: 314 | idx = np.random.choice(kl_raw.shape[0], 1)[0] 315 | else: 316 | idx = 0 317 | mask = tensor2np(mask[idx]) 318 | kl_raw, nll_raw = tensor2np(kl_raw[idx]), tensor2np(nll_raw[idx]) # (t, f) 319 | kl_fr, nll_fr = tensor2np(kl_fr[idx]), tensor2np(nll_fr[idx]) # (t, ) 320 | kl_m, nll_m = tensor2np(kl_m[idx]), tensor2np(nll_m[idx]) # (t, ) 321 | # kl_aggr, nll_aggr = tensor2np(kl_aggr[idx]), tensor2np(nll_aggr[idx]) # () 322 | fig, ax = plt.subplots(4, 2, sharex=True, figsize=(20, 40)) 323 | ax[0][0].imshow(kl_raw.T, origin='lower') 324 | ax[1][0].plot(kl_fr) 325 | ax[2][0].plot(kl_m) 326 | ax[3][0].plot(mask) 327 | ax[0][1].imshow(nll_raw.T, origin='lower') 328 | ax[1][1].plot(nll_fr) 329 | ax[2][1].plot(nll_m) 330 | ax[3][1].plot(mask) 331 | return fig 332 | 333 | 334 | def tensor2np(t): 335 | return t.cpu().detach().numpy() 336 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from pathlib import Path 4 | from itertools import repeat 5 | from collections import OrderedDict 6 | 7 | 8 | def ensure_dir(dirname): 9 | dirname = Path(dirname) 10 | if not dirname.is_dir(): 11 | dirname.mkdir(parents=True, exist_ok=False) 12 | 13 | def read_json(fname): 14 | fname = Path(fname) 15 | with fname.open('rt') as handle: 16 | return json.load(handle, object_hook=OrderedDict) 17 | 18 | def write_json(content, fname): 19 | fname = Path(fname) 20 | with fname.open('wt') as handle: 21 | json.dump(content, handle, indent=4, sort_keys=False) 22 | 23 | def inf_loop(data_loader): 24 | ''' wrapper function for endless data loader. ''' 25 | for loader in repeat(data_loader): 26 | yield from loader 27 | 28 | 29 | class MetricTracker: 30 | def __init__(self, *keys, writer=None): 31 | self.writer = writer 32 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 33 | self.reset() 34 | 35 | def reset(self): 36 | for col in self._data.columns: 37 | self._data[col].values[:] = 0 38 | 39 | def update(self, key, value, n=1): 40 | # if self.writer is not None: 41 | # self.writer.add_scalar(key, value) 42 | self._data.total[key] += value * n 43 | self._data.counts[key] += n 44 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 45 | 46 | def avg(self, key): 47 | return self._data.average[key] 48 | 49 | def result(self): 50 | self._data = self._data[self._data.counts != 0] 51 | return dict(self._data.average) 52 | 53 | def write_to_logger(self, key, value=None): 54 | assert self.writer is not None 55 | if value is None: 56 | self.writer.add_scalar(key, self._data.average[key]) 57 | else: 58 | self.writer.add_scalar(key, value) 59 | --------------------------------------------------------------------------------
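The encoder's `forward()` in `model/modules.py` consumes a PackedSequence and either reverses and pads the RNN output (`pad_and_reverse`, defined in `data_loader/seq_util.py`) or pads it directly with `pad_packed_sequence`. The sketch below shows only the pack -> RNN -> pad round trip with a toy GRU; the feature and hidden dimensions, and the GRU itself, are placeholders chosen for illustration, not the module's actual configuration, and the `reverse_input` branch is omitted.
```
import torch
import torch.nn as nn

# toy batch: 3 variable-length sequences of 4-dim frames, sorted by length and zero-padded
seq_lengths = torch.tensor([5, 3, 2])
x = torch.zeros(3, 5, 4)  # (batch, t_max, feature)

rnn = nn.GRU(input_size=4, hidden_size=8, batch_first=True)

packed = nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True)
_h_rnn, _ = rnn(packed)  # no h0 passed, so the initial hidden state defaults to zeros
h_rnn, lens = nn.utils.rnn.pad_packed_sequence(_h_rnn, batch_first=True)
print(h_rnn.shape)  # torch.Size([3, 5, 8])
```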
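`init_obj` and `_set_by_path` in `parse_config.py` implement a small convention: each config section holds a `type` (class or function name) and an `args` dict, and CLI overrides address nested keys with a `;`-separated keychain (the `CustomArgs` targets in `train.py`, e.g. `optimizer;args;lr`). The following is a minimal, self-contained sketch of that convention; it re-implements the two helpers as free functions purely for illustration, and `demo_config` is a made-up snippet, not the repository's `config.json`.
```
from functools import reduce
from operator import getitem

import torch

# toy config in the same shape as a config.json section ('type' + 'args')
demo_config = {
    "optimizer": {"type": "Adam", "args": {"lr": 1e-3, "weight_decay": 0.0}},
}

def init_obj(config, name, module, *args, **kwargs):
    """Build module.<type>(*args, **config[name]['args'], **kwargs), mirroring ConfigParser.init_obj."""
    module_name = config[name]["type"]
    module_args = dict(config[name]["args"])
    assert all(k not in module_args for k in kwargs), "Overwriting kwargs given in config is not allowed"
    module_args.update(kwargs)
    return getattr(module, module_name)(*args, **module_args)

def set_by_path(tree, keychain, value):
    """Mirror _set_by_path: 'optimizer;args;lr' addresses config['optimizer']['args']['lr']."""
    keys = keychain.split(";")
    reduce(getitem, keys[:-1], tree)[keys[-1]] = value

if __name__ == "__main__":
    # this is where a `--lr 3e-4` CLI flag from train.py would land
    set_by_path(demo_config, "optimizer;args;lr", 3e-4)

    params = [torch.nn.Parameter(torch.zeros(2, 2))]
    optimizer = init_obj(demo_config, "optimizer", torch.optim, params)
    # equivalent to: torch.optim.Adam(params, lr=3e-4, weight_decay=0.0)
    print(type(optimizer).__name__, optimizer.defaults["lr"])
```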
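`determine_annealing_factor` in `trainer/trainer.py` linearly ramps the KL weight from `config['trainer']['min_anneal_factor']` up to 1.0 over the first `anneal_update` parameter updates, where the update count is `epoch * n_batch + batch_idx`. The standalone copy below reproduces that schedule so the warm-up can be inspected in isolation; the numbers used in the demo (`min_anneal_factor=0.1`, `anneal_update=1000`, 100 batches per epoch) are arbitrary and not taken from `config.json`.
```
def determine_annealing_factor(min_anneal_factor, anneal_update, epoch, n_batch, batch_idx):
    # same logic as trainer/trainer.py: linear warm-up of the KL term
    n_updates = epoch * n_batch + batch_idx
    if anneal_update > 0 and n_updates < anneal_update:
        return min_anneal_factor + (1.0 - min_anneal_factor) * (n_updates / anneal_update)
    return 1.0

if __name__ == "__main__":
    # note: the Trainer passes `epoch - 1` because its epochs are 1-based; here epochs are 0-based
    for update in (0, 250, 500, 999, 1000, 5000):
        epoch, batch_idx = divmod(update, 100)
        factor = determine_annealing_factor(0.1, 1000, epoch, 100, batch_idx)
        print(f"update {update:5d} -> anneal factor {factor:.3f}")
    # prints 0.100, 0.325, 0.550, 0.999, then 1.000 once the warm-up ends
```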
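`MetricTracker` in `utils/util.py` keeps a running total, count, and average per key in a small pandas DataFrame: `update(key, value, n)` adds a batch-weighted value and `result()` returns averages for the keys that were actually updated. A short usage sketch follows, assuming it is run from the repository root so `utils` is importable; the loss values are made up, and `writer=None` is used, so `write_to_logger` (which requires a writer) is not exercised.
```
from utils.util import MetricTracker

# track the three losses the Trainer logs
tracker = MetricTracker("loss", "nll", "kl", writer=None)

tracker.reset()  # the Trainer calls this at the start of every epoch
for batch_loss, batch_nll, batch_kl in [(1.2, 0.9, 0.3), (1.0, 0.8, 0.2)]:
    tracker.update("loss", batch_loss)
    tracker.update("nll", batch_nll)
    tracker.update("kl", batch_kl)

print(tracker.avg("loss"))  # ~1.1
print(tracker.result())     # approximately {'loss': 1.1, 'nll': 0.85, 'kl': 0.25}
```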
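`Trainer.__init__` switches between epoch-based and iteration-based training: with `len_epoch=None` it iterates the loader once per epoch, otherwise it wraps the loader in `utils.inf_loop` and relies on the `if batch_idx == self.len_epoch ...: break` check in `_train_epoch` to stop. A tiny sketch of that wrapping, using a plain list as a stand-in for a real DataLoader, is shown below.
```
from itertools import islice

from utils.util import inf_loop

loader = [f"batch_{i}" for i in range(3)]  # stand-in for a DataLoader

# epoch-based: one pass, len_epoch == len(loader) == 3
print(list(loader))

# iteration-based: endless cycling, stop after len_epoch batches
len_epoch = 5
print(list(islice(inf_loop(loader), len_epoch)))
# ['batch_0', 'batch_1', 'batch_2', 'batch_0', 'batch_1']
```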