├── .gitignore
├── LICENSE
├── README.md
├── setup.py
└── tensor_rnn
    ├── example
    │   ├── decomposition
    │   │   └── decomposition_cptensor.py
    │   └── polymusic
    │       ├── data
    │       │   ├── README.md
    │       │   └── download_data.sh
    │       ├── extract_info.py
    │       ├── loader.py
    │       ├── poly_allrnn.py
    │       ├── run_cpgru.py
    │       ├── run_gru.py
    │       ├── run_ttgru.py
    │       └── run_tuckergru.py
    ├── modules
    │   ├── __init__.py
    │   ├── candecomp.py
    │   ├── composite
    │   │   ├── __init__.py
    │   │   ├── cprnn.py
    │   │   ├── ttrnn.py
    │   │   └── tuckerrnn.py
    │   ├── loss.py
    │   ├── rnn.py
    │   ├── tensor_train.py
    │   └── tucker.py
    └── utils
        ├── __init__.py
        ├── data_util.py
        └── helper.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# CUSTOM
*.pickle
*.log

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Andros Tjandra

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tensor RNN
An implementation of various tensor-based decompositions for NN & RNN parameters

## Quick Start
1. Install `python >= 3.0`
2. Install `pytorch >= 0.3.0`
3. `pip install -e .` or `python setup.py install`

### Run example scripts
1. Go to the example folder

   `cd example/polymusic`

2. Go to the data folder, download the pickled datasets and return.

   `cd data && ./download_data.sh && cd ..`

3. Run any example script

   `python run_ttgru.py`

For the usage, see the code inside `example/polymusic/poly_allrnn.py`
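
As a quick sanity check, the factorized layers can also be used directly as
drop-in replacements for `nn.Linear`. Below is a minimal sketch (the mode and
rank values are illustrative, not tuned); the library targets the PyTorch
0.3/0.4-era API, so on very old versions you may need to wrap the input in
`torch.autograd.Variable`:

```python
import torch
from tensor_rnn.modules.tensor_train import TTLinear

# Factorize a 256 -> 512 dense layer over modes 4x4x4x4 -> 8x4x4x4.
# For TT, len(ranks) == len(in_modes) + 1 and the boundary ranks are 1.
layer = TTLinear(in_modes=[4, 4, 4, 4], out_modes=[8, 4, 4, 4],
                 ranks=[1, 3, 3, 3, 1])

x = torch.randn(16, 256)   # batch of 16 input vectors
y = layer(x)               # shape: (16, 512)
```

`CPLinear(in_modes, out_modes, order)` and `TuckerLinear(in_modes, out_modes, ranks)`
follow the same pattern; only the low-rank structure of the weight tensor differs.
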
## Modules
### Linear Layer
* `TuckerLinear`
* `CPLinear`
* `TTLinear`

### Bilinear Layer
* `CPBilinear`
* `TuckerBilinear` (TODO)

### RNN Layer
* `StatefulCPLSTMCell`
* `StatefulCPGRUCell`
* `StatefulTuckerLSTMCell`
* `StatefulTuckerGRUCell`
* `StatefulTTLSTMCell`
* `StatefulTTGRUCell`
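
The `Stateful*Cell` modules keep their recurrent state internally, so a
sequence is processed by calling the cell once per timestep and calling
`reset()` before each new batch of sequences. A minimal sketch (modes and
ranks are again illustrative):

```python
import torch
from tensor_rnn.modules.composite import StatefulTTGRUCell

cell = StatefulTTGRUCell(in_modes=[4, 4, 4, 4], out_modes=[8, 4, 4, 4],
                         ranks=[1, 3, 3, 3, 1])

x = torch.randn(16, 20, 256)   # (batch, time, features)
cell.reset()                   # clear the hidden state before a new sequence
for t in range(x.size(1)):
    h_t = cell(x[:, t].contiguous())   # the cell carries h_t to the next call
```

The LSTM variants return a `(h_t, c_t)` tuple instead of a single hidden
state. The CP and Tucker cells also cache the reconstructed weight matrix, so
their `reset()` additionally invalidates that cache between parameter updates.
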
"win32", "cygwin", "osx" 9 | license="MIT", 10 | url="", 11 | python_requires='>=3', 12 | packages=find_packages(), 13 | install_requires=['numpy', 'scipy', 'torch', 'scikit-learn', 'requests']); 14 | -------------------------------------------------------------------------------- /tensor_rnn/example/decomposition/decomposition_cptensor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | from torch import tensor 6 | from tensor_rnn.modules import candecomp 7 | 8 | if __name__ == '__main__': 9 | MAIN_DEVICE = torch.device('cuda') 10 | M = [100, 100, 100] 11 | mu = [0.6, -0.3, 2,0] 12 | std = [1, 1, 1] 13 | EPOCHS = 10000 14 | INTERVAL = 10 15 | ORDER = 10 16 | # real data # 17 | true_factors = candecomp._create_candecomp_cores_unconstrained(M, order=ORDER) 18 | for ii, factor in enumerate(true_factors) : 19 | factor.data.normal_(mu[ii], std[ii]) 20 | true_tensors = candecomp._cpcores_to_tensor(true_factors) 21 | true_tensors = true_tensors.detach().to(MAIN_DEVICE) 22 | 23 | # pred data # 24 | pred_factors = candecomp._create_candecomp_cores_unconstrained(M, order=ORDER) 25 | for ii, factor in enumerate(pred_factors) : 26 | # Eq. refer here -> Tensor Decomposition for Compressing Recurrent Neural Network 27 | # initialization variance is 0.5 here 28 | factor.data.normal_(0.0, (0.5/(ORDER**0.5))**(1.0/len(M))) 29 | pred_factors.to(MAIN_DEVICE) 30 | opt = torch.optim.Adam(pred_factors, lr=5e-3, amsgrad=True) 31 | for ee in range(1, EPOCHS+1) : 32 | pred_tensors = candecomp._cpcores_to_tensor(pred_factors) 33 | loss = (true_tensors - pred_tensors).pow(2).mean() 34 | 35 | opt.zero_grad() 36 | loss.backward() 37 | opt.step() 38 | if ee % INTERVAL == 0 : 39 | print('Epoch {}: MSE {:g}'.format(ee, loss.item())) 40 | pass 41 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/data/README.md: -------------------------------------------------------------------------------- 1 | Dataset are downloaded from here : 2 | http://www-etud.iro.umontreal.ca/~boulanni/icml2012 3 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/data/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://www-etud.iro.umontreal.ca/~boulanni/Piano-midi.de.pickle 4 | wget http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.pickle 5 | wget http://www-etud.iro.umontreal.ca/~boulanni/MuseData.pickle 6 | wget http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle 7 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/extract_info.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | INF = 2**64 4 | 5 | files = [line.strip() for line in sys.stdin] 6 | 7 | param = None 8 | best_nll = INF 9 | best_acc = -INF 10 | for ff in files : 11 | txt = open(ff).read() 12 | _param = re.findall('RNN parameters: ([0-9]+)', txt) 13 | if _param == [] : 14 | print('[WARN] {} crashed!'.format(ff)) 15 | continue 16 | _param = int(_param[0]) 17 | if param is not None : 18 | assert _param == param 19 | else : 20 | param = _param 21 | _nll_acc = re.findall('Best test loss: ([0-9\.]+), acc: ([0-9\.]+)', txt) 22 | if _nll_acc != [] : 23 | _nll, _acc = _nll_acc[0] 24 | _nll, _acc = float(_nll), float(_acc) 25 | else : 26 | 
print('[WARN] {} crashed!'.format(ff)) 27 | continue 28 | best_nll = min(best_nll, _nll) 29 | best_acc = max(best_acc, _acc) 30 | 31 | print('PARAMS\tNLL\tACC') 32 | print('{}\t{}\t{}'.format(param, best_nll, best_acc)) 33 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | LIST_DATASET = ['jsb', 'musedata', 'nottingham', 'pianomidi'] 6 | 7 | def parse_midi_pkl(path) : 8 | OFFSET = 21 9 | MAXDIM = 88 10 | data = pickle.load(open(path, 'rb')) 11 | def convert_int_to_onehot(items) : 12 | mat_onehot = np.zeros((len(items), MAXDIM), dtype='float32') 13 | for ii, item in enumerate(items) : 14 | item = [jj-OFFSET for jj in item] 15 | mat_onehot[ii, item] = 1.0 16 | return mat_onehot 17 | mat = {'train':[convert_int_to_onehot(item) for item in data['train']], 18 | 'valid':[convert_int_to_onehot(item) for item in data['valid']], 19 | 'test':[convert_int_to_onehot(item) for item in data['test']]} 20 | return mat 21 | 22 | def load_dataset_pickle(dataset) : 23 | assert dataset in LIST_DATASET 24 | if dataset == 'jsb' : 25 | mat = parse_midi_pkl('data/JSB Chorales.pickle') 26 | elif dataset == 'musedata' : 27 | mat = parse_midi_pkl('data/MuseData.pickle') 28 | elif dataset == 'nottingham' : 29 | mat = parse_midi_pkl('data/Nottingham.pickle') 30 | elif dataset == 'pianomidi' : 31 | mat = parse_midi_pkl('data/Piano-midi.de.pickle') 32 | train_all = mat['train'] 33 | val_all = mat['valid'] 34 | test_all = mat['test'] 35 | return train_all, val_all, test_all 36 | 37 | def load_dataset(dataset) : 38 | assert dataset in LIST_DATASET 39 | if dataset == 'jsb' : 40 | mat = loadmat('data/JSB_Chorales.mat') 41 | elif dataset == 'musedata' : 42 | mat = loadmat('data/MuseData.mat') 43 | elif dataset == 'nottingham' : 44 | mat = loadmat('data/Nottingham.mat') 45 | elif dataset == 'pianomidi' : 46 | mat = loadmat('data/Piano_midi.mat') 47 | else : 48 | raise ValueError('dataset not available') 49 | train_all = mat['traindata'][0] 50 | val_all = mat['validdata'][0] 51 | test_all = mat['testdata'][0] 52 | return train_all, val_all, test_all 53 | 54 | def batch_data(list_data) : 55 | batch = len(list_data) 56 | seq_len = [len(x)-1 for x in list_data] 57 | ndim = list_data[0].shape[1] 58 | input = np.zeros((batch, max(seq_len), ndim), dtype='float32') 59 | target = np.zeros_like(input) 60 | mask = np.zeros((batch, max(seq_len)), dtype='float32') 61 | for ii in range(batch) : 62 | input[ii, 0:seq_len[ii]] = list_data[ii][0:-1] 63 | target[ii, 0:seq_len[ii]] = list_data[ii][1:] 64 | mask[ii, 0:seq_len[ii]] = 1 65 | return input, target, mask 66 | 67 | 68 | # TODO # 69 | def acc_polymusic(pred, label, mask) : 70 | """ Ref : http://web.eecs.umich.edu/~honglak/ismir2011-PolyphonicTranscription.pdf 71 | ACC = TP/(FP+FN+TP) 72 | TP = number of note correctly predicted 73 | FP = number of note-off predicted as note-on 74 | FN = number of note-on predicted as note-off 75 | """ 76 | pred = np.round(pred) #sigmoid threshold 0.5 77 | 78 | TP = np.float((np.logical_and(pred==1, label==1) * mask[:, :, np.newaxis]).sum()) 79 | FP = np.float((np.logical_and(pred==1, label==0) * mask[:, :, np.newaxis]).sum()) 80 | FN = np.float((np.logical_and(pred==0, label==1) * mask[:, :, np.newaxis]).sum()) 81 | 82 | TN = np.float((np.logical_and(pred==0, label==0) * mask[:, :, np.newaxis]).sum()) 83 | denom = 
TP+FP+FN
    nom = TP
    return nom, denom # numerator, denominator

--------------------------------------------------------------------------------
/tensor_rnn/example/polymusic/poly_allrnn.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import time
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from tensor_rnn.modules import StatefulLSTMCell, StatefulGRUCell
from tensor_rnn.modules.composite import StatefulCPLSTMCell, StatefulCPGRUCell,\
        StatefulTuckerLSTMCell, StatefulTuckerGRUCell, \
        StatefulTTLSTMCell, StatefulTTGRUCell
from tensor_rnn.modules import elementwise_bce, elementwise_bce_with_logits

from tensor_rnn.utils import tensorauto, torchauto
from tensor_rnn.utils.data_util import iter_minibatches

from loader import load_dataset_pickle, batch_data, LIST_DATASET, acc_polymusic

# rename
load_dataset = load_dataset_pickle

# Training settings
parser = argparse.ArgumentParser(description='PyTorch polyphonic music example')
parser.add_argument('--data', type=str, choices=LIST_DATASET, help='dataset {}'.format(str(LIST_DATASET)))
parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                    help='input batch size for training (default: 16)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate (default: 1e-3)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--inmodes', type=int, nargs='+', default=None)
parser.add_argument('--outmodes', type=int, nargs='+', default=None)
parser.add_argument('--ranks', type=int, nargs='+', default=None)
parser.add_argument('--order', type=int, default=5, help='order for CP decomposition')
parser.add_argument('--clip', type=float, default=5, help='clip grad norm')
parser.add_argument('--nlayers', type=int, default=1, help='number of layers')
parser.add_argument('--do', type=float, default=0.2, help='dropout rnn layer')
parser.add_argument('--rnntype', type=str, choices=['gru', 'lstm',
                    'ttgru', 'ttlstm',
                    'tuckergru', 'tuckerlstm',
                    'cpgru', 'cplstm'])
parser.add_argument('--opt', type=str, default='Adam')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

train_data, val_data, test_data = load_dataset(args.data)

in_modes = args.inmodes
out_modes = args.outmodes
in_sizes = int(np.prod(in_modes))
out_sizes = int(np.prod(out_modes))
order = args.order
ranks = args.ranks
rnn_type = args.rnntype
nlayers = args.nlayers
dropout = args.do

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.prenet = nn.Linear(88, in_sizes)
        self.nlayers = nlayers
        self.rnn = nn.ModuleList()
        for ii in
range(nlayers) : 79 | if rnn_type == 'gru' : 80 | self.rnn.append(StatefulGRUCell(in_sizes if ii == 0 else out_sizes, out_sizes)) 81 | elif rnn_type == 'lstm' : 82 | self.rnn.append(StatefulLSTMCell(in_sizes if ii == 0 else out_sizes, out_sizes)) 83 | elif rnn_type == 'ttlstm' : 84 | self.rnn.append(StatefulTTLSTMCell(in_modes if ii == 0 else out_modes, out_modes, ranks)) 85 | elif rnn_type == 'ttgru' : 86 | self.rnn.append(StatefulTTGRUCell(in_modes if ii == 0 else out_modes, out_modes, ranks)) 87 | elif rnn_type == 'cplstm' : 88 | self.rnn.append(StatefulCPLSTMCell(in_modes if ii == 0 else out_modes, out_modes, order)) 89 | elif rnn_type == 'cpgru' : 90 | self.rnn.append(StatefulCPGRUCell(in_modes if ii == 0 else out_modes, out_modes, order)) 91 | elif rnn_type == 'tuckerlstm' : 92 | self.rnn.append(StatefulTuckerLSTMCell(in_modes if ii == 0 else out_modes, out_modes, ranks)) 93 | elif rnn_type == 'tuckergru' : 94 | self.rnn.append(StatefulTuckerGRUCell(in_modes if ii == 0 else out_modes, out_modes, ranks)) 95 | else : 96 | raise ValueError() 97 | self.postnet = nn.Linear(out_sizes, 88) 98 | 99 | def reset(self) : 100 | for rnn in self.rnn : 101 | rnn.reset() 102 | 103 | def forward(self, x): 104 | # x = [batch, max_seq_len, 88] # 105 | batch, max_seq_len, _ = x.shape 106 | res = F.leaky_relu(self.prenet(x.view(-1, 88)).view(batch, max_seq_len, -1), 0.1) 107 | list_res = [] 108 | for ii in range(max_seq_len) : # seq_len # 109 | hidden = res[:, ii].contiguous() 110 | for jj in range(len(self.rnn)) : 111 | hidden = self.rnn[jj](hidden) 112 | if isinstance(hidden, (list, tuple)) : 113 | hidden = hidden[0] 114 | if dropout > 0 : 115 | hidden = F.dropout(hidden, p=dropout, training=self.training) 116 | list_res.append(hidden) 117 | res = torch.stack(list_res, dim=1) 118 | res = self.postnet(res.view(batch*max_seq_len, -1)).view(batch, max_seq_len, -1) # use last h_t # 119 | # res = F.sigmoid(res) 120 | return res 121 | 122 | model = Net() 123 | if args.cuda: 124 | model.cuda() 125 | # optimizer = optim.Adam(model.parameters(), lr=args.lr) 126 | optimizer = getattr(torch.optim, args.opt)(model.parameters(), lr=args.lr) 127 | 128 | total_params = sum([np.prod(x.size()) for x in model.rnn.parameters()]) 129 | for rnn in model.rnn : 130 | # minus 1 bias # 131 | if isinstance(rnn, (StatefulGRUCell, StatefulLSTMCell)) : 132 | total_params -= np.prod(rnn.bias_hh.size()) 133 | else : 134 | total_params -= np.prod(rnn.weight_hh.bias.size()) 135 | print(vars(args)) 136 | print('RNN parameters: {}'.format(total_params)) 137 | 138 | def train(epoch, data): 139 | model.train() 140 | data_size = len(data) 141 | train_loss = 0 142 | acc_nom = 0 143 | acc_denom = 0 144 | count = 0 145 | for rr in iter_minibatches(data_size, args.batch_size, shuffle=True, pad=False) : 146 | curr_input, curr_target, curr_mask = batch_data([data[rrii] for rrii in rr]) 147 | curr_input = Variable(tensorauto(model, torch.from_numpy(curr_input))) 148 | curr_target = Variable(tensorauto(model, torch.from_numpy(curr_target))) 149 | curr_mask = Variable(tensorauto(model, torch.from_numpy(curr_mask))) 150 | curr_count = curr_mask.data.sum() 151 | model.reset() 152 | optimizer.zero_grad() 153 | output = model(curr_input) 154 | loss = elementwise_bce_with_logits(output, curr_target) * curr_mask.unsqueeze(-1) 155 | loss = loss.sum() / curr_count 156 | loss.backward() 157 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) 158 | optimizer.step() 159 | train_loss += loss.data.sum() * curr_count 160 | curr_acc_nom, curr_acc_denom 
= acc_polymusic(F.sigmoid(output).data.cpu().numpy(), curr_target.data.cpu().numpy(), curr_mask.data.cpu().numpy()) 161 | acc_nom += curr_acc_nom 162 | acc_denom += curr_acc_denom 163 | count += curr_count 164 | pass 165 | 166 | train_loss /= count 167 | acc = acc_nom / acc_denom 168 | return train_loss, acc * 100 169 | 170 | def test(data): 171 | model.eval() 172 | model.reset() 173 | data_size = len(data) 174 | test_loss = 0 175 | acc_nom = 0 176 | acc_denom = 0 177 | count = 0 178 | 179 | for rr in iter_minibatches(data_size, args.batch_size, shuffle=False, pad=False) : 180 | curr_input, curr_target, curr_mask = batch_data([data[rrii] for rrii in rr]) 181 | curr_input = Variable(tensorauto(model, torch.from_numpy(curr_input))) 182 | curr_target = Variable(tensorauto(model, torch.from_numpy(curr_target))) 183 | curr_mask = Variable(tensorauto(model, torch.from_numpy(curr_mask))) 184 | curr_count = curr_mask.data.sum() 185 | model.reset() 186 | output = model(curr_input) 187 | loss = elementwise_bce_with_logits(output, curr_target) * curr_mask.unsqueeze(-1) 188 | loss = loss.sum() / curr_count 189 | test_loss += loss.data.sum() * curr_count 190 | curr_acc_nom, curr_acc_denom = acc_polymusic(F.sigmoid(output).data.cpu().numpy(), curr_target.data.cpu().numpy(), curr_mask.data.cpu().numpy()) 191 | acc_nom += curr_acc_nom 192 | acc_denom += curr_acc_denom 193 | count += curr_count 194 | 195 | test_loss /= count 196 | acc = acc_nom / acc_denom 197 | return test_loss, acc * 100 198 | 199 | INF = 2**32 200 | best_val_loss, best_val_loss_idx = INF, 0 201 | best_val_acc, best_val_acc_idx = -INF, 0 202 | 203 | hist_loss = {'train':[], 'val':[], 'test':[]} 204 | hist_acc = {'train':[], 'val':[], 'test':[]} 205 | 206 | for epoch in range(1, args.epochs + 1): 207 | start = time.time() 208 | train_loss, train_acc = train(epoch, train_data) 209 | end = time.time() - start 210 | print('Epoch {} -- time {:.1f} s'.format(epoch, end)) 211 | print('\tTrain set: loss: {:.4f}, acc: {:.2f}'.format(train_loss, train_acc)) 212 | val_loss, val_acc = test(val_data) 213 | print('\tVal set: loss: {:.4f}, acc: {:.2f}'.format(val_loss, val_acc)) 214 | test_loss, test_acc = test(test_data) 215 | print('\tTest set: loss: {:.4f}, acc: {:.2f}'.format(test_loss, test_acc)) 216 | 217 | hist_loss['train'].append(train_loss) 218 | hist_loss['val'].append(val_loss) 219 | hist_loss['test'].append(test_loss) 220 | hist_acc['train'].append(train_acc) 221 | hist_acc['val'].append(val_acc) 222 | hist_acc['test'].append(test_acc) 223 | 224 | if best_val_loss > val_loss : 225 | best_val_loss = val_loss 226 | best_val_loss_idx = epoch-1 227 | if best_val_acc < val_acc : 228 | best_val_acc = val_acc 229 | best_val_acc_idx = epoch-1 230 | 231 | print('Best val loss: {:.4f}, acc: {:.2f}'.format(hist_loss['val'][best_val_loss_idx], hist_acc['val'][best_val_acc_idx])) 232 | print('Best test loss: {:.4f}, acc: {:.2f}'.format(hist_loss['test'][best_val_loss_idx], hist_acc['test'][best_val_acc_idx])) 233 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/run_cpgru.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from sklearn.model_selection import ParameterGrid 3 | 4 | _cmd = 'CUDA_VISIBLE_DEVICES=0 python -u poly_allrnn.py --data {data} --inmodes {inmodes} --outmodes {outmodes} --order {order} --rnntype {rnntype} --lr {lr} --do {do} | tee {log}' 5 | 6 | hparams = [{ 7 | 'data': ['jsb', 'nottingham', 'pianomidi', 
'musedata'], 8 | 'inmodes': ["4 4 4 4"], 9 | 'outmodes': ["8 4 4 4"], 10 | 'rnntype': ['cpgru'], 11 | 'order':[50, 80, 110], 12 | 'lr': [5e-3, 1e-2], 13 | 'do':[0.2, 0.5] 14 | }] 15 | 16 | if __name__ == '__main__' : 17 | list_param = list(ParameterGrid(hparams)) 18 | for item in list_param : 19 | _log = 'log/{data}-{rnntype}-inmodes_{inmodes}-outmodes_{outmodes}-order_{order}-lr_{lr}-do_{do}.log'.format(data=item['data'], rnntype=item['rnntype'], inmodes=item['inmodes'].replace(' ', '_'), outmodes=item['outmodes'].replace(' ', '_'), lr=item['lr'], do=item['do'], order=item['order']) 20 | subprocess.run(_cmd.format(log=_log, **item), shell=True) 21 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/run_gru.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from sklearn.model_selection import ParameterGrid 3 | 4 | _cmd = 'CUDA_VISIBLE_DEVICES=0 python -u poly_allrnn.py --data {data} --inmodes {inmodes} --outmodes {outmodes} --rnntype {rnntype} --lr {lr} --do {do} --epochs {epochs} | tee {log}' 5 | 6 | hparams = [{ 7 | 'data': ['jsb', 'nottingham'], 8 | 'inmodes': [256], 9 | 'outmodes': [512], 10 | 'rnntype': ['gru'], 11 | 'lr': [2.5e-3], 12 | 'do':[0.3], 13 | 'epochs':[100] 14 | }] 15 | 16 | if __name__ == '__main__' : 17 | list_param = list(ParameterGrid(hparams)) 18 | for item in list_param : 19 | _log = 'log/{data}-{rnntype}-inmodes_{inmodes}-outmodes_{outmodes}-lr_{lr}-do_{do}-ep_{epochs}.log'.format(data=item['data'], rnntype=item['rnntype'], inmodes=item['inmodes'], outmodes=item['outmodes'], lr=item['lr'], do=item['do'], epochs=item['epochs']) 20 | subprocess.run(_cmd.format(log=_log, **item), shell=True) 21 | pass 22 | pass 23 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/run_ttgru.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from sklearn.model_selection import ParameterGrid 3 | 4 | _cmd = 'CUDA_VISIBLE_DEVICES=0 python -u poly_allrnn.py --data {data} --inmodes {inmodes} --outmodes {outmodes} --ranks {ranks} --rnntype {rnntype} --lr {lr} --do {do} --batch-size {batchsize} | tee {log}' 5 | 6 | hparams = [ 7 | { 8 | 'data': ['pianomidi', 'musedata'], 9 | 'inmodes': ["4 4 4 4"], 10 | 'outmodes': ["8 4 4 4"], 11 | 'rnntype': ['ttgru'], 12 | 'ranks':["1 11 11 11 1"], 13 | 'lr': [5e-3, 1e-2], 14 | 'do':[0.2, 0.5], 15 | 'batchsize':[8]} 16 | ] 17 | 18 | if __name__ == '__main__' : 19 | list_param = list(ParameterGrid(hparams)) 20 | for item in list_param : 21 | _log = 'log/{data}-{rnntype}-inmodes_{inmodes}-outmodes_{outmodes}-ranks_{ranks}-lr_{lr}-do_{do}-bsize_{batchsize}.log'.format(data=item['data'], rnntype=item['rnntype'], inmodes=item['inmodes'].replace(' ', '_'), outmodes=item['outmodes'].replace(' ', '_'), lr=item['lr'], do=item['do'], ranks=item['ranks'].replace(' ', '_'), batchsize=item['batchsize']) 22 | print('CMD : {}'.format(_cmd.format(log=_log, **item))) 23 | subprocess.run(_cmd.format(log=_log, **item), shell=True) 24 | -------------------------------------------------------------------------------- /tensor_rnn/example/polymusic/run_tuckergru.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from sklearn.model_selection import ParameterGrid 3 | 4 | _cmd = 'CUDA_VISIBLE_DEVICES=1 python -u poly_allrnn.py --data {data} --inmodes {inmodes} --outmodes {outmodes} 
--ranks {ranks} --rnntype {rnntype} --lr {lr} --do {do} | tee {log}' 5 | 6 | hparams = [{ 7 | 'data': ['jsb', 'nottingham', 'pianomidi', 'musedata'], 8 | 'inmodes': ["4 4 4 4"], 9 | 'outmodes': ["8 4 4 4", "8 4 8 4"], 10 | 'rnntype': ['tuckergru'], 11 | 'ranks':["2 2 2 2", "2 4 2 4"], 12 | 'lr': [5e-3, 1e-2], 13 | 'do':[0.2, 0.5] 14 | }] 15 | 16 | if __name__ == '__main__' : 17 | list_param = list(ParameterGrid(hparams)) 18 | for item in list_param : 19 | _log = 'log/{data}-{rnntype}-inmodes_{inmodes}-outmodes_{outmodes}-ranks_{ranks}-lr_{lr}-do_{do}.log'.format(data=item['data'], rnntype=item['rnntype'], inmodes=item['inmodes'].replace(' ', '_'), outmodes=item['outmodes'].replace(' ', '_'), lr=item['lr'], do=item['do'], ranks=item['ranks'].replace(' ', '_')) 20 | subprocess.run(_cmd.format(log=_log, **item), shell=True) 21 | -------------------------------------------------------------------------------- /tensor_rnn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .rnn import * 2 | from .loss import * 3 | -------------------------------------------------------------------------------- /tensor_rnn/modules/candecomp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import Module, Parameter, ParameterList 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | 10 | def _create_candecomp_cores(in_modes, out_modes, order) : 11 | assert len(in_modes) == len(out_modes) 12 | assert order > 0 13 | list_cores = [] 14 | modes = in_modes + out_modes # extend list 15 | for mm in modes : 16 | list_cores.append(Parameter(torch.Tensor(mm, order).zero_())) 17 | list_cores = ParameterList(list_cores) 18 | return list_cores 19 | 20 | def _create_candecomp_cores_unconstrained(tensor_modes, order) : 21 | list_cores = [] 22 | modes = tensor_modes 23 | for mm in modes : 24 | list_cores.append(Parameter(torch.Tensor(mm, order).zero_())) 25 | list_cores = ParameterList(list_cores) 26 | return list_cores 27 | 28 | def _tensor_to_matrix(in_modes, out_modes, tensor) : 29 | return tensor.view(int(np.prod(in_modes)), int(np.prod(out_modes))) 30 | 31 | def _cpcores_to_tensor(list_factors) : 32 | assert len(list_factors) > 2 33 | tensor_out = None 34 | list_tensor_shape = [list_factors[ii].shape[0] for ii in range(len(list_factors))] 35 | for ii in range(len(list_factors)) : 36 | if ii == 0 : 37 | tensor_out = list_factors[ii] 38 | else : 39 | t_r, t_c = tensor_out.shape 40 | f_r, f_c = list_factors[ii].shape 41 | assert t_c == f_c, "tensor core order should be same" 42 | tensor_out = tensor_out.view(t_r, 1, t_c) * list_factors[ii].view(1, f_r, f_c) 43 | tensor_out = tensor_out.view(t_r * f_r, f_c) 44 | # sum across all order # 45 | tensor_out = tensor_out.sum(-1) 46 | tensor_out = tensor_out.view(*list_tensor_shape) 47 | return tensor_out 48 | 49 | class CPLinear(Module) : 50 | def __init__(self, in_modes, out_modes, order, bias=True, cache=True) : 51 | """ 52 | cache: if cache is True, pre calculated W_tsr until user reset the variable 53 | """ 54 | super().__init__() 55 | self.in_modes = in_modes 56 | self.out_modes = out_modes 57 | self.order = order 58 | self.cache = cache 59 | self._W_linear = None 60 | 61 | self.factors = _create_candecomp_cores(in_modes, out_modes, order) 62 | 63 | if bias : 64 | self.bias = Parameter(torch.Tensor(int(np.prod(out_modes)))) 65 | else : 66 | 
self.register_parameter('bias', None) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self) : 70 | CONST = (0.05 / (self.order**0.5)) ** (1.0/(len(self.in_modes)+len(self.out_modes))) 71 | for ii in range(len(self.factors)) : 72 | init.normal(self.factors[ii], 0, CONST) 73 | pass 74 | if self.bias is not None : 75 | self.bias.data.zero_() 76 | 77 | def reset(self) : 78 | self._W_linear = None 79 | 80 | @property 81 | def W_linear(self) : 82 | if not self.cache : 83 | return _tensor_to_matrix(self.in_modes, self.out_modes, _cpcores_to_tensor(list(self.factors))) 84 | if self._W_linear is None : 85 | self._W_linear = _tensor_to_matrix(self.in_modes, self.out_modes, _cpcores_to_tensor(list(self.factors))) 86 | else : 87 | pass 88 | return self._W_linear 89 | 90 | def forward(self, input) : 91 | return F.linear(input, self.W_linear.t(), self.bias) 92 | 93 | class CPBilinear(Module) : 94 | def __init__(self, in1_features, in2_features, out_features, order, bias=True) : 95 | """ 96 | order: rank for each factor matrix (order << out_features to maximize parameter efficiency) 97 | """ 98 | super().__init__() 99 | self.in1_features = in1_features 100 | self.in2_features = in2_features 101 | self.out_features = out_features 102 | self.order = order 103 | self.factors = _create_candecomp_cores_unconstrained([out_features, in1_features, in2_features], order) 104 | 105 | if bias : 106 | self.bias = Parameter(torch.Tensor(out_features)) 107 | else : 108 | self.register_parameter('bias', None) 109 | self.reset_parameters() 110 | pass 111 | 112 | def reset_parameters(self) : 113 | CONST = (0.05 / (self.order**0.5)) ** (1.0/3.0) # based on the theorem derived in latest paper 114 | for ii in range(len(self.factors)) : 115 | init.normal(self.factors[ii], 0, CONST) 116 | pass 117 | if self.bias is not None : 118 | self.bias.data.zero_() 119 | 120 | @property 121 | def weight(self) : 122 | return _cpcores_to_tensor(list(self.factors)) 123 | 124 | def forward(self, input1, input2) : 125 | return F.bilinear(input1, input2, self.weight, self.bias) 126 | 127 | pass 128 | -------------------------------------------------------------------------------- /tensor_rnn/modules/composite/__init__.py: -------------------------------------------------------------------------------- 1 | from .cprnn import * 2 | from .tuckerrnn import * 3 | from .ttrnn import * 4 | -------------------------------------------------------------------------------- /tensor_rnn/modules/composite/cprnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module, Parameter, ParameterList 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | from torch.autograd import Variable 9 | 10 | from ..rnn import StatefulBaseCell 11 | from ..candecomp import CPLinear 12 | from ...utils.helper import torchauto, tensorauto 13 | 14 | class StatefulCPGRUCell(StatefulBaseCell) : 15 | def __init__(self, in_modes, out_modes, order, bias=True, cache=True, 16 | compress_in=True, compress_out=True) : 17 | super().__init__() 18 | self.in_modes = in_modes 19 | self.out_modes = out_modes 20 | 21 | self.input_size = int(np.prod(in_modes)) 22 | self.hidden_size = int(np.prod(out_modes)) 23 | 24 | self.compress_in = compress_in 25 | self.compress_out = compress_out 26 | 27 | self.bias = bias 28 | self.order = order 29 | self.out_modes_Mx = list(out_modes) 30 | self.out_modes_Mx[-1] *= 3 31 | if compress_in : 32 | 
            self.weight_ih = CPLinear(in_modes, self.out_modes_Mx, order, bias=self.bias, cache=cache)
        else :
            self.weight_ih = nn.Linear(self.input_size, self.hidden_size*3, bias=self.bias)

        if compress_out :
            self.weight_hh = CPLinear(out_modes, self.out_modes_Mx, order, bias=self.bias, cache=cache)
        else :
            self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size*3, bias=self.bias)

        self.reset_parameters()
        pass

    def reset_parameters(self) :
        self.weight_hh.reset_parameters()
        self.weight_ih.reset_parameters()

    def reset(self) :
        super().reset()
        if self.compress_out :
            self.weight_hh.reset()
        if self.compress_in :
            self.weight_ih.reset()

    def forward(self, input) :
        batch = input.size(0)
        if self.state is None :
            h0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_())
        else :
            h0 = self.state
        pre_rih, pre_zih, pre_nih = torch.split(self.weight_ih(input), self.hidden_size, dim=1)
        pre_rhh, pre_zhh, pre_nhh = torch.split(self.weight_hh(h0), self.hidden_size, dim=1)
        r_t = F.sigmoid(pre_rih + pre_rhh)
        z_t = F.sigmoid(pre_zih + pre_zhh)
        c_t = F.tanh(pre_nih + r_t * (pre_nhh))
        h_t = (1-z_t) * c_t + (z_t * h0)
        self.state = h_t
        return h_t
    pass

class StatefulCPLSTMCell(StatefulBaseCell) :
    def __init__(self, in_modes, out_modes, order, bias=True, cache=True,
                 compress_in=True, compress_out=True) :
        super().__init__()
        self.in_modes = in_modes
        self.out_modes = out_modes

        self.input_size = int(np.prod(in_modes))
        self.hidden_size = int(np.prod(out_modes))

        self.compress_in = compress_in
        self.compress_out = compress_out

        self.bias = bias
        self.order = order
        self.out_modes_Mx = list(out_modes)
        self.out_modes_Mx[-1] *= 4
        if compress_in :
            self.weight_ih = CPLinear(in_modes, self.out_modes_Mx, order, bias=self.bias, cache=cache)
        else :
            self.weight_ih = nn.Linear(self.input_size, self.hidden_size*4, bias=self.bias)

        if compress_out :
            self.weight_hh = CPLinear(out_modes, self.out_modes_Mx, order, bias=self.bias, cache=cache)
        else :
            self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size*4, bias=self.bias)

        self.reset_parameters()
        pass

    def reset_parameters(self) :
        self.weight_hh.reset_parameters()
        self.weight_ih.reset_parameters()

    def reset(self) :
        super().reset()
        if self.compress_out :
            self.weight_hh.reset()
        if self.compress_in :
            self.weight_ih.reset()

    def forward(self, input) :
        batch = input.size(0)
        if self.state is None :
            h0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_())
            c0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_())
        else :
            h0, c0 = self.state
        pre_iih, pre_fih, pre_gih, pre_oih = torch.split(self.weight_ih(input), self.hidden_size, dim=1)
        pre_ihh, pre_fhh, pre_ghh, pre_ohh = torch.split(self.weight_hh(h0), self.hidden_size, dim=1)
        i_t = F.sigmoid(pre_iih + pre_ihh)
        f_t = F.sigmoid(pre_fih + pre_fhh)
        o_t = F.sigmoid(pre_oih + pre_ohh)
        g_t = F.tanh(pre_gih + pre_ghh)
        c_t = f_t * c0 + i_t * g_t
        h_t = o_t * F.tanh(c_t)
        self.state = (h_t, c_t)
        return (h_t, c_t)

--------------------------------------------------------------------------------
/tensor_rnn/modules/composite/ttrnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module, Parameter, ParameterList 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | from torch.autograd import Variable 9 | 10 | from ...utils.helper import torchauto, tensorauto 11 | from ...modules.rnn import StatefulBaseCell 12 | from ...modules.tensor_train import TTLinear 13 | 14 | class StatefulTTGRUCell(StatefulBaseCell) : 15 | def __init__(self, in_modes, out_modes, ranks, bias=True, 16 | compress_in=True, compress_out=True) : 17 | super().__init__() 18 | self.in_modes = in_modes 19 | self.out_modes = out_modes 20 | 21 | self.input_size = int(np.prod(in_modes)) 22 | self.hidden_size = int(np.prod(out_modes)) 23 | 24 | self.compress_in = compress_in 25 | self.compress_out = compress_out 26 | 27 | self.bias = bias 28 | self.ranks = ranks 29 | out_modes_3x = list(out_modes) 30 | out_modes_3x[-1] *= 3 31 | if compress_in : 32 | self.weight_ih = TTLinear(in_modes, out_modes_3x, ranks, bias=self.bias) 33 | else : 34 | self.weight_ih = nn.Linear(self.input_size, self.hidden_size*3, bias=self.bias) 35 | if compress_out : 36 | self.weight_hh = TTLinear(out_modes, out_modes_3x, ranks, bias=self.bias) 37 | else : 38 | self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size*3, bias=self.bias) 39 | 40 | self.reset_parameters() 41 | pass 42 | 43 | def reset_parameters(self) : 44 | self.weight_hh.reset_parameters() 45 | self.weight_ih.reset_parameters() 46 | 47 | def forward(self, input) : 48 | batch = input.size(0) 49 | if self.state is None : 50 | h0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_()) 51 | else : 52 | h0 = self.state 53 | pre_rih, pre_zih, pre_nih = torch.split(self.weight_ih(input), self.hidden_size, dim=1) 54 | pre_rhh, pre_zhh, pre_nhh = torch.split(self.weight_hh(h0), self.hidden_size, dim=1) 55 | r_t = F.sigmoid(pre_rih + pre_rhh) 56 | z_t = F.sigmoid(pre_zih + pre_zhh) 57 | c_t = F.tanh(pre_nih + r_t * (pre_nhh)) 58 | h_t = (1-z_t) * c_t + (z_t * h0) 59 | self.state = h_t 60 | return h_t 61 | 62 | class StatefulTTLSTMCell(StatefulBaseCell) : 63 | def __init__(self, in_modes, out_modes, ranks, bias=True, 64 | compress_in=True, compress_out=True) : 65 | super().__init__() 66 | self.in_modes = in_modes 67 | self.out_modes = out_modes 68 | 69 | self.input_size = int(np.prod(in_modes)) 70 | self.hidden_size = int(np.prod(out_modes)) 71 | 72 | self.compress_in = compress_in 73 | self.compress_out = compress_out 74 | 75 | self.bias = bias 76 | self.ranks = ranks 77 | out_modes_4x = list(out_modes) 78 | out_modes_4x[-1] *= 4 79 | if compress_in : 80 | self.weight_ih = TTLinear(in_modes, out_modes_4x, ranks, bias=self.bias) 81 | else : 82 | self.weight_ih = nn.Linear(self.input_size, self.hidden_size*4, bias=self.bias) 83 | 84 | if compress_out : 85 | self.weight_hh = TTLinear(out_modes, out_modes_4x, ranks, bias=self.bias) 86 | else : 87 | self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size*4, bias=self.bias) 88 | 89 | self.reset_parameters() 90 | 91 | def reset_parameters(self) : 92 | self.weight_hh.reset_parameters() 93 | self.weight_ih.reset_parameters() 94 | 95 | def forward(self, input) : 96 | batch = input.size(0) 97 | if self.state is None : 98 | h0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_()) 99 | c0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_()) 
        else :
            h0, c0 = self.state
        pre_iih, pre_fih, pre_gih, pre_oih = torch.split(self.weight_ih(input), self.hidden_size, dim=1)
        pre_ihh, pre_fhh, pre_ghh, pre_ohh = torch.split(self.weight_hh(h0), self.hidden_size, dim=1)
        i_t = F.sigmoid(pre_iih + pre_ihh)
        f_t = F.sigmoid(pre_fih + pre_fhh)
        o_t = F.sigmoid(pre_oih + pre_ohh)
        g_t = F.tanh(pre_gih + pre_ghh)
        c_t = f_t * c0 + i_t * g_t
        h_t = o_t * F.tanh(c_t)
        self.state = (h_t, c_t)
        return (h_t, c_t)

--------------------------------------------------------------------------------
/tensor_rnn/modules/composite/tuckerrnn.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
from torch import nn
from torch.nn import Module, Parameter, ParameterList
from torch.nn import functional as F
from torch.nn import init
from torch.autograd import Variable

from ...modules.rnn import StatefulBaseCell
from ...modules.tucker import TuckerLinear
from ...utils.helper import torchauto, tensorauto

class StatefulTuckerGRUCell(StatefulBaseCell) :
    def __init__(self, in_modes, out_modes, ranks, bias=True, cache=True,
                 compress_in=True, compress_out=True) :
        super().__init__()
        self.in_modes = in_modes
        self.out_modes = out_modes

        self.input_size = int(np.prod(in_modes))
        self.hidden_size = int(np.prod(out_modes))

        self.compress_in = compress_in
        self.compress_out = compress_out

        self.bias = bias
        self.ranks = ranks
        self.out_modes_Mx = list(out_modes)
        self.out_modes_Mx[-1] *= 3
        if compress_in :
            self.weight_ih = TuckerLinear(in_modes, self.out_modes_Mx, ranks, bias=self.bias, cache=cache)
        else :
            self.weight_ih = nn.Linear(self.input_size, self.hidden_size*3, bias=self.bias)
        if compress_out :
            self.weight_hh = TuckerLinear(out_modes, self.out_modes_Mx, ranks, bias=self.bias, cache=cache)
        else :
            self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size*3, bias=self.bias)

        self.reset_parameters()
        pass

    def reset_parameters(self) :
        self.weight_hh.reset_parameters()
        self.weight_ih.reset_parameters()

    def reset(self) :
        super().reset()
        if self.compress_out :
            self.weight_hh.reset()
        if self.compress_in :
            self.weight_ih.reset()

    def forward(self, input) :
        batch = input.size(0)
        if self.state is None :
            h0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_())
        else :
            h0 = self.state
        pre_rih, pre_zih, pre_nih = torch.split(self.weight_ih(input), self.hidden_size, dim=1)
        pre_rhh, pre_zhh, pre_nhh = torch.split(self.weight_hh(h0), self.hidden_size, dim=1)
        r_t = F.sigmoid(pre_rih + pre_rhh)
        z_t = F.sigmoid(pre_zih + pre_zhh)
        c_t = F.tanh(pre_nih + r_t * (pre_nhh))
        h_t = (1-z_t) * c_t + (z_t * h0)
        self.state = h_t
        return h_t
    pass

class StatefulTuckerLSTMCell(StatefulBaseCell) :
    def __init__(self, in_modes, out_modes, ranks, bias=True, cache=True,
                 compress_in=True, compress_out=True) :
        super().__init__()
        self.in_modes = in_modes
        self.out_modes = out_modes

        self.input_size = int(np.prod(in_modes))
        self.hidden_size = int(np.prod(out_modes))

        self.compress_in = compress_in
        self.compress_out = compress_out

        self.bias = bias
        self.ranks = ranks
        self.out_modes_Mx = list(out_modes)
        self.out_modes_Mx[-1] *= 4
        if compress_in :
            self.weight_ih = TuckerLinear(in_modes, self.out_modes_Mx, ranks, bias=self.bias, cache=cache)
        else :
            self.weight_ih = nn.Linear(self.input_size, self.hidden_size*4, bias=self.bias)
        if compress_out :
            self.weight_hh = TuckerLinear(out_modes, self.out_modes_Mx, ranks, bias=self.bias, cache=cache)
        else :
            self.weight_hh = nn.Linear(self.hidden_size, self.hidden_size*4, bias=self.bias)

        self.reset_parameters()
        pass

    def reset_parameters(self) :
        self.weight_hh.reset_parameters()
        self.weight_ih.reset_parameters()

    def reset(self) :
        super().reset()
        if self.compress_out :
            self.weight_hh.reset()
        if self.compress_in :
            self.weight_ih.reset()

    def forward(self, input) :
        batch = input.size(0)
        if self.state is None :
            h0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_())
            c0 = Variable(torchauto(self).FloatTensor(batch, self.hidden_size).zero_())
        else :
            h0, c0 = self.state
        pre_iih, pre_fih, pre_gih, pre_oih = torch.split(self.weight_ih(input), self.hidden_size, dim=1)
        pre_ihh, pre_fhh, pre_ghh, pre_ohh = torch.split(self.weight_hh(h0), self.hidden_size, dim=1)
        i_t = F.sigmoid(pre_iih + pre_ihh)
        f_t = F.sigmoid(pre_fih + pre_fhh)
        o_t = F.sigmoid(pre_oih + pre_ohh)
        g_t = F.tanh(pre_gih + pre_ghh)
        c_t = f_t * c0 + i_t * g_t
        h_t = o_t * F.tanh(c_t)
        self.state = (h_t, c_t)
        return (h_t, c_t)

--------------------------------------------------------------------------------
/tensor_rnn/modules/loss.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import Module

class ElementwiseBCE(Module) :
    # assumed implementation: element-wise BCE on probabilities (post-sigmoid
    # inputs), mirroring ElementwiseBCEWithLogits below
    def __init__(self) :
        super().__init__()
        pass

    def forward(self, input, target) :
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
        eps = 1e-12
        return -(target * (input + eps).log() + (1 - target) * (1 - input + eps).log())

def elementwise_bce(input, target) :
    return ElementwiseBCE()(input, target)

class ElementwiseBCEWithLogits(Module) :
    def __init__(self) :
        super().__init__()
        pass

    def forward(self, input, target) :
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
        return loss

def elementwise_bce_with_logits(input, target) :
    return ElementwiseBCEWithLogits()(input, target)

--------------------------------------------------------------------------------
/tensor_rnn/modules/rnn.py:
--------------------------------------------------------------------------------
import math
import torch

from torch.nn import Module
from torch.nn import Parameter
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

from ..utils.helper import torchauto

##### WRAPPER #####
class StatefulBaseCell(Module) :
    def __init__(self) :
        super(StatefulBaseCell, self).__init__()
        self._state = None
        pass

    def reset(self) :
        self._state = None

    @property
    def state(self) :
        return self._state

    @state.setter
    def state(self, value) :
        self._state = value

class StatefulLSTMCell(StatefulBaseCell) :
    def __init__(self, input_size, hidden_size, bias=True) :
        super(StatefulLSTMCell, self).__init__()
        self.rnn_cell =
nn.LSTMCell(input_size, hidden_size, bias) 34 | pass 35 | 36 | @property 37 | def weight_hh(self) : 38 | return self.rnn_cell.weight_hh.t() 39 | 40 | @property 41 | def weight_ih(self) : 42 | return self.rnn_cell.weight_ih.t() 43 | 44 | @property 45 | def bias_hh(self) : 46 | return self.rnn_cell.bias_hh 47 | 48 | @property 49 | def bias_ih(self) : 50 | return self.rnn_cell.bias_ih 51 | 52 | def forward(self, input) : 53 | batch = input.size(0) 54 | if self.state is None : 55 | h0 = Variable(torchauto(self).FloatTensor(batch, self.rnn_cell.hidden_size).zero_()) 56 | c0 = Variable(torchauto(self).FloatTensor(batch, self.rnn_cell.hidden_size).zero_()) 57 | # h0, c0 # 58 | self.state = (h0, c0) 59 | 60 | self.state = self.rnn_cell(input, self.state) 61 | return self.state 62 | 63 | class StatefulGRUCell(StatefulBaseCell) : 64 | def __init__(self, input_size, hidden_size, bias=True) : 65 | super(StatefulGRUCell, self).__init__() 66 | self.rnn_cell = nn.GRUCell(input_size, hidden_size, bias) 67 | pass 68 | 69 | @property 70 | def weight_hh(self) : 71 | return self.rnn_cell.weight_hh.t() 72 | 73 | @property 74 | def weight_ih(self) : 75 | return self.rnn_cell.weight_ih.t() 76 | 77 | @property 78 | def bias_hh(self) : 79 | return self.rnn_cell.bias_hh 80 | 81 | @property 82 | def bias_ih(self) : 83 | return self.rnn_cell.bias_ih 84 | 85 | def forward(self, input) : 86 | batch = input.size(0) 87 | if self.state is None : 88 | h0 = Variable(torchauto(self).FloatTensor(batch, self.rnn_cell.hidden_size).zero_()) 89 | # h0, c0 # 90 | self.state = h0 91 | 92 | self.state = self.rnn_cell(input, self.state) 93 | return self.state 94 | ################### 95 | -------------------------------------------------------------------------------- /tensor_rnn/modules/tensor_train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | from torch.nn import Module, Parameter, ParameterList 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | 9 | def _create_tt_cores(in_modes, out_modes, ranks) : 10 | assert len(in_modes) == len(out_modes) == len(ranks)-1 11 | dim = len(in_modes) 12 | list_tt_cores = [] 13 | for ii in range(dim) : 14 | list_tt_cores.append(Parameter(torch.Tensor(out_modes[ii] * ranks[ii+1], in_modes[ii] * ranks[ii]))) 15 | weight = ParameterList(list_tt_cores) 16 | return weight 17 | 18 | def tt_dot(in_modes, out_modes, ranks, input, weight, bias=None) : 19 | assert len(in_modes) == len(out_modes) == len(ranks)-1 20 | assert input.shape[1] == np.prod(in_modes) 21 | res = input 22 | res = res.view(-1, int(np.prod(in_modes))) 23 | res = res.transpose(1, 0) 24 | res = res.contiguous() 25 | dim = len(in_modes) 26 | for ii in range(dim) : 27 | res = res.view(ranks[ii] * in_modes[ii], -1) 28 | res = torch.matmul(weight[ii], res) 29 | res = res.view(out_modes[ii], -1) 30 | res = res.transpose(1, 0) 31 | res = res.contiguous() 32 | res = res.view(-1, int(np.prod(out_modes))) 33 | 34 | if bias is not None : 35 | res += bias 36 | return res 37 | 38 | class TTLinear(Module): 39 | 40 | def __init__(self, in_modes, out_modes, ranks, bias=True): 41 | super().__init__() 42 | self.in_modes = in_modes 43 | self.out_modes = out_modes 44 | self.ranks = ranks 45 | dim = len(self.in_modes) 46 | 47 | assert len(self.in_modes) == len(self.out_modes) == len(self.ranks)-1 48 | 49 | self.weight = _create_tt_cores(self.in_modes, self.out_modes, self.ranks) 50 | 51 | if bias: 52 | self.bias = 
Parameter(torch.Tensor(int(np.prod(out_modes)))) 53 | else: 54 | self.register_parameter('bias', None) 55 | self.reset_parameters() 56 | 57 | def reset_xavier(self) : 58 | for ii in range(len(self.weight)) : 59 | init.xavier_normal(self.weight[ii]) 60 | 61 | def reset_normal(self) : 62 | CONST = ((((0.05**2)/np.prod(self.ranks)))**(1/(len(self.ranks)-1))) ** 0.5 63 | for ii in range(len(self.weight)) : 64 | init.normal(self.weight[ii], 0, CONST) 65 | 66 | def reset_parameters(self) : 67 | self.reset_normal() 68 | if self.bias is not None: 69 | self.bias.data.zero_() 70 | 71 | def forward(self, input): 72 | return tt_dot(self.in_modes, self.out_modes, self.ranks, input, self.weight, self.bias) 73 | 74 | def __repr__(self): 75 | return self.__class__.__name__ + 'in: ' \ 76 | + str(self.in_modes) + ' -> out:' \ 77 | + str(self.out_modes) + ' | ' \ 78 | + 'rank: {}'.format(str(self.ranks)) 79 | 80 | -------------------------------------------------------------------------------- /tensor_rnn/modules/tucker.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | from torch.nn import Module, Parameter, ParameterList 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | 9 | def _create_tucker_params(in_modes, out_modes, ranks) : 10 | assert len(in_modes) == len(out_modes) == len(ranks) 11 | modes = in_modes + out_modes # extend list 12 | core = Parameter(torch.Tensor(*list(ranks+ranks)).normal_()) 13 | factors = [] 14 | for mm, rr in zip(modes, ranks+ranks) : 15 | factors.append(Parameter(torch.Tensor(mm, rr).normal_())) 16 | factors = ParameterList(factors) 17 | return core, factors 18 | 19 | def _tensor_to_matrix(in_modes, out_modes, tensor) : 20 | return tensor.contiguous().view(int(np.prod(in_modes)), int(np.prod(out_modes))) 21 | 22 | def _n_mode_product(core, factor, mode) : 23 | assert factor.dim() == 2 24 | # core = [i_1,..,i_j,..,i_D] 25 | core_shape = list(core.shape) # j = mode 26 | core_tmp = core.transpose(mode, -1) # [i_1,..,i_D,i_j] 27 | new_core_shape = list(core_tmp.shape) 28 | core_tmp = core_tmp.contiguous().view(-1, core_shape[mode]) # [prod([i_1,..,i_D]), i_j ] 29 | core_tmp = core_tmp.mm(factor.t()) # [prod([i_1,..,i_D]), m_j] 30 | core_tmp = core_tmp.view(*new_core_shape[0:-1], factor.shape[0]) # [i_1,..,i_D,m_j] 31 | core_tmp = core_tmp.transpose(mode, -1) # [i_1,..,m_j,..,i_D] 32 | return core_tmp 33 | 34 | def _tuckercores_to_tensor(core, list_factors) : 35 | n_dim = len(list_factors) 36 | assert n_dim == core.dim() 37 | tensor_out = core.contiguous() 38 | for ii in range(n_dim) : 39 | tensor_out = _n_mode_product(tensor_out, list_factors[ii], ii) 40 | return tensor_out 41 | 42 | class TuckerLinear(Module) : 43 | def __init__(self, in_modes, out_modes, ranks, bias=True, cache=True) : 44 | """ 45 | cache: if cache is True, pre calculated W_tsr until user reset the variable 46 | """ 47 | super().__init__() 48 | assert len(in_modes) == len(out_modes) == len(ranks) 49 | self.in_modes = in_modes 50 | self.out_modes = out_modes 51 | self.ranks = ranks 52 | self.cache = cache 53 | self._W_linear = None 54 | 55 | self.core, self.factors = _create_tucker_params(in_modes, out_modes, ranks) 56 | 57 | if bias : 58 | self.bias = Parameter(torch.Tensor(int(np.prod(out_modes)))) 59 | else : 60 | self.register_parameter('bias', None) 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self) : 64 | CONST = (0.05 / np.prod(self.ranks+self.ranks)**0.5) ** 
(1.0/(len(self.in_modes)+len(self.out_modes)+1))
        init.normal(self.core, 0, CONST)
        for ii in range(len(self.factors)) :
            init.normal(self.factors[ii], 0, CONST)
            pass
        if self.bias is not None :
            self.bias.data.zero_()

    def reset(self) :
        self._W_linear = None

    @property
    def W_linear(self) :
        if not self.cache :
            return _tensor_to_matrix(self.in_modes, self.out_modes, _tuckercores_to_tensor(self.core, list(self.factors)))
        if self._W_linear is None :
            self._W_linear = _tensor_to_matrix(self.in_modes, self.out_modes, _tuckercores_to_tensor(self.core, list(self.factors)))
        else :
            pass
        return self._W_linear

    def forward(self, input) :
        return F.linear(input, self.W_linear.t(), self.bias)

--------------------------------------------------------------------------------
/tensor_rnn/utils/__init__.py:
--------------------------------------------------------------------------------
from .helper import *

--------------------------------------------------------------------------------
/tensor_rnn/utils/data_util.py:
--------------------------------------------------------------------------------
import numpy as np

def pad_idx(indices, batchsize):
    # assumed helper: pad by repeating indices from the start so that
    # len(indices) becomes divisible by batchsize
    rem = len(indices) % batchsize
    if rem != 0 :
        indices = indices + indices[:batchsize - rem]
    return indices

def iter_minibatches(indices, batchsize, shuffle=True, pad=False, excludes=None):
    """
    Args:
        indices : total number of data (int) or list of indices
        batchsize : mini-batch size
        shuffle : shuffle the indices every epoch
        pad : pad the indices if they can't be divided equally by batchsize
        excludes : indices to skip

    Return :
        generator over lists of indices for the current epoch (randomized or not, depending on shuffle)
    """
    if isinstance(indices, list) :
        indices = list(indices)  # copy, so shuffling doesn't mutate the caller's list
    elif isinstance(indices, int) :
        indices = list(range(indices))
    if excludes is not None :
        indices = [x for x in indices if x not in excludes]
    if shuffle:
        np.random.shuffle(indices)

    if pad :
        indices = pad_idx(indices, batchsize)

    for ii in range(0, len(indices), batchsize):
        yield indices[ii:ii + batchsize]
    pass

--------------------------------------------------------------------------------
/tensor_rnn/utils/helper.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable

def is_cuda_module(module) :
    return next(module.parameters()).is_cuda

def _auto_detect_cuda(module) :
    if isinstance(module, torch.nn.Module) :
        return is_cuda_module(module)
    if isinstance(module, bool) :
        return module
    if isinstance(module, int) :
        return module >= 0
    if isinstance(module, torch.autograd.Variable) :
        return module.data.is_cuda
    if isinstance(module, torch.tensor._TensorBase) :
        return module.is_cuda
    raise NotImplementedError()

def torchauto(module) :
    return torch.cuda if _auto_detect_cuda(module) else torch

def tensorauto(module, tensor) :
    return tensor.cuda() if _auto_detect_cuda(module) else tensor.cpu()

--------------------------------------------------------------------------------