├── paper
│   ├── .gitignore
│   ├── figures
│   │   ├── gammas.pdf
│   │   ├── ptb_valid.pdf
│   │   ├── tanh_grad.pdf
│   │   ├── tanh_grad.png
│   │   ├── attr_valid.pdf
│   │   ├── attr_valid2.pdf
│   │   ├── ptb_lengths.pdf
│   │   ├── permuted_valid.pdf
│   │   ├── rnn_grad_prop.pdf
│   │   ├── attr_full_valid.pdf
│   │   ├── unpermuted_valid.pdf
│   │   ├── popstat_stationarity.pdf
│   │   ├── popstat_stationarity.png
│   │   ├── plot_popstat_stationarity.py
│   │   └── plot_tanh_grad.py
│   ├── iclr2017_conference.bst
│   ├── Makefile
│   ├── minipage.sty
│   ├── nips15submit_e.sty
│   ├── iclr2017_conference.sty
│   ├── nips_2016.sty
│   ├── index.bib
│   └── fancyhdr.sty
├── experiments
│   ├── sequentialmnist_lstm.sh
│   ├── sequentialmnist_bnlstm.sh
│   ├── sequentialpmnist_bnlstm.sh
│   ├── sequentialpmnist_lstm.sh
│   ├── text8_bn-lstm.sh
│   ├── text8_lstm.sh
│   ├── penntreebank_bn-lstm.sh
│   └── penntreebank_lstm.sh
├── README.md
├── .gitignore
├── plot_attr.py
├── plotgradients.py
├── util.py
├── results
│   └── iclr
│       ├── attr
│       │   ├── plot_full.py
│       │   └── plot.py
│       ├── rnn_hidden_norms
│       │   └── plot_rnn_gradnorm.py
│       └── plot.py
├── penntreebank_inspect.py
├── penntreebank_evaluate.py
├── sequential_mnist_evaluate.py
├── extensions.py
├── sequential_mnist.py
├── penntreebank.py
├── text8.py
└── memory.py

/paper/.gitignore:
--------------------------------------------------------------------------------
1 | *.aux
2 | *.out
3 | *.log
4 | *.bbl
5 | *.blg
--------------------------------------------------------------------------------
/paper/figures/gammas.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/gammas.pdf
--------------------------------------------------------------------------------
/paper/figures/ptb_valid.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/ptb_valid.pdf
--------------------------------------------------------------------------------
/paper/figures/tanh_grad.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/tanh_grad.pdf
--------------------------------------------------------------------------------
/paper/figures/tanh_grad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/tanh_grad.png
--------------------------------------------------------------------------------
/paper/figures/attr_valid.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/attr_valid.pdf
--------------------------------------------------------------------------------
/paper/figures/attr_valid2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/attr_valid2.pdf
--------------------------------------------------------------------------------
/paper/figures/ptb_lengths.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/ptb_lengths.pdf
--------------------------------------------------------------------------------
/paper/iclr2017_conference.bst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/iclr2017_conference.bst -------------------------------------------------------------------------------- /paper/figures/permuted_valid.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/permuted_valid.pdf -------------------------------------------------------------------------------- /paper/figures/rnn_grad_prop.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/rnn_grad_prop.pdf -------------------------------------------------------------------------------- /paper/figures/attr_full_valid.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/attr_full_valid.pdf -------------------------------------------------------------------------------- /paper/figures/unpermuted_valid.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/unpermuted_valid.pdf -------------------------------------------------------------------------------- /paper/figures/popstat_stationarity.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/popstat_stationarity.pdf -------------------------------------------------------------------------------- /paper/figures/popstat_stationarity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cooijmanstim/recurrent-batch-normalization/HEAD/paper/figures/popstat_stationarity.png -------------------------------------------------------------------------------- /paper/Makefile: -------------------------------------------------------------------------------- 1 | all : index.tex 2 | pdflatex $< 3 | bibtex `echo $< | cut -d'.' -f 1` || echo "" 4 | pdflatex $< 5 | pdflatex $< 6 | 7 | clean : 8 | rm -rf index.pdf *.log *.aux *.bbl *.blg *.toc *.tex~ *.out 9 | -------------------------------------------------------------------------------- /experiments/sequentialmnist_lstm.sh: -------------------------------------------------------------------------------- 1 | command="../../sequential_mnist.py --lstm --num-epoch 200 --learning-rate 0.001 --init id --batch-size 100 --baseline" 2 | directory=sequentialmnist-lstm 3 | mkdir -p $directory 4 | cd $directory 5 | python $command 6 | cd .. 7 | -------------------------------------------------------------------------------- /experiments/sequentialmnist_bnlstm.sh: -------------------------------------------------------------------------------- 1 | command="../../sequential_mnist.py --lstm --num-epoch 200 --learning-rate 0.001 --init id --batch-size 100 --noise 0.1" 2 | directory=sequentialmnist-bnlstm 3 | mkdir -p $directory 4 | cd $directory 5 | python $command 6 | cd .. 
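The scripts in this directory drive the top-level training scripts purely through command-line flags. sequential_mnist.py itself is not part of this excerpt; judging from the flags used in the scripts above and below, its interface is roughly the following argparse sketch (the flag names come from the scripts; defaults and comments are guesses, not the actual source):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lstm", action="store_true")        # LSTM transition (vs. plain RNN)
parser.add_argument("--baseline", action="store_true")    # unnormalized LSTM baseline
parser.add_argument("--permuted", action="store_true")    # permuted-pixel MNIST (pMNIST)
parser.add_argument("--num-epoch", type=int, default=200)
parser.add_argument("--learning-rate", type=float, default=1e-3)
parser.add_argument("--init", default="id")               # recurrent weight initialization
parser.add_argument("--batch-size", type=int, default=100)
parser.add_argument("--noise", type=float, default=None)  # as in sequentialmnist_bnlstm.sh above
args = parser.parse_args()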
7 | 
--------------------------------------------------------------------------------
/experiments/sequentialpmnist_bnlstm.sh:
--------------------------------------------------------------------------------
1 | command="../../sequential_mnist.py --lstm --num-epoch 200 --learning-rate 0.001 --init id --batch-size 100 --permuted"
2 | directory=sequentialpmnist-bnlstm
3 | mkdir -p $directory
4 | cd $directory
5 | python $command
6 | cd ..
7 | 
--------------------------------------------------------------------------------
/experiments/sequentialpmnist_lstm.sh:
--------------------------------------------------------------------------------
1 | command="../../sequential_mnist.py --lstm --num-epoch 200 --learning-rate 0.001 --init id --batch-size 100 --baseline --permuted"
2 | directory=sequentialpmnist-lstm
3 | mkdir -p $directory
4 | cd $directory
5 | python $command
6 | cd ..
7 | 
--------------------------------------------------------------------------------
/experiments/text8_bn-lstm.sh:
--------------------------------------------------------------------------------
1 | command="../../text8.py --length 180 --initialization orthogonal --optimizer adam --learning-rate 0.001 --batch-size 128 --num-epochs 50"
2 | nh=2000
3 | directory=text8-bn-lstm-nh$nh
4 | mkdir -p $directory
5 | cd $directory
6 | python $command --num-hidden $nh
7 | cd ..
8 | 
--------------------------------------------------------------------------------
/experiments/text8_lstm.sh:
--------------------------------------------------------------------------------
1 | command="../../text8.py --length 180 --initialization orthogonal --optimizer adam --learning-rate 0.001 --batch-size 128 --num-epochs 50 --baseline"
2 | nh=2000
3 | directory=text8-lstm-nh$nh
4 | mkdir -p $directory
5 | cd $directory
6 | python $command --num-hidden $nh
7 | cd ..
8 | 
--------------------------------------------------------------------------------
/experiments/penntreebank_bn-lstm.sh:
--------------------------------------------------------------------------------
1 | command="../../penntreebank.py --optimizer adam --initialization orthogonal --num-epochs 50 --length 100 --batch-size 64"
2 | lr=0.002
3 | nh=1000
4 | directory=bn-lstm-nh$nh-lr$lr
5 | mkdir -p $directory
6 | cd $directory
7 | python $command --num-hidden $nh --learning-rate $lr
8 | cd ..
9 | 
--------------------------------------------------------------------------------
/experiments/penntreebank_lstm.sh:
--------------------------------------------------------------------------------
1 | command="../../penntreebank.py --optimizer adam --initialization orthogonal --num-epochs 50 --length 100 --batch-size 64"
2 | lr=0.002
3 | nh=1000
4 | directory=lstm-nh$nh-lr$lr
5 | mkdir -p $directory
6 | cd $directory
7 | python $command --num-hidden $nh --learning-rate $lr --baseline
8 | cd ..
9 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains the code that was used for the Sequential
2 | MNIST, Penn Treebank and text8 experiments in the paper Recurrent
3 | Batch Normalization (http://arxiv.org/abs/1603.09025). For the
4 | Attentive Reader, see
5 | https://github.com/cooijmanstim/Attentive_reader/tree/bn.
6 | 
7 | The `experiments` directory contains shell scripts that demonstrate
8 | how to launch the experiments with the hyperparameters from the paper.
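For reference, here is what the BN-LSTM trained by these scripts computes at each step, as a minimal numpy sketch: the input-to-hidden and hidden-to-hidden contributions are batch-normalized separately, each with its own gain gamma (the paper initializes gamma to 0.1), and the cell state is normalized once more before the output tanh. All names below are illustrative, not the repo's actual API.

import numpy as np

def bn(x, gamma, eps=1e-5):
    # normalize over the batch axis; gamma is the gain whose
    # initialization the gammas figure studies
    return gamma * (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def bn_lstm_step(x, h, c, Wx, Wh, b, gamma_x, gamma_h, gamma_c):
    # input and recurrent contributions are batch-normalized separately,
    # then summed with a single bias
    gates = bn(x.dot(Wx), gamma_x) + bn(h.dot(Wh), gamma_h) + b
    i, f, o, g = np.split(gates, 4, axis=1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    # the cell state is normalized as well before the output nonlinearity
    h_new = sigmoid(o) * np.tanh(bn(c_new, gamma_c))
    return h_new, c_new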
9 | 10 | Depends on [Theano](https://github.com/Theano/Theano), [Blocks](https://github.com/mila-udem/blocks) and [Fuel](https://github.com/mila-udem/fuel). 11 | 12 | Other implementations: 13 | - https://github.com/fchollet/keras/pull/2183 14 | - https://github.com/iassael/torch-bnlstm 15 | - https://github.com/OlavHN/bnlstm 16 | - https://gist.github.com/spitis/27ab7d2a30bbaf5ef431b4a02194ac60 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | #Ipython Notebook 62 | .ipynb_checkpoints 63 | -------------------------------------------------------------------------------- /paper/figures/plot_popstat_stationarity.py: -------------------------------------------------------------------------------- 1 | import cPickle, numpy as np 2 | 3 | pkl = cPickle.load(open("/data/lisatmp3/cooijmat/run/batchnorm/ptb/repeatpopstat_independentbatchstat/checkpoint.zip_popstat_results.pkl")) 4 | 5 | new_popstats = dict((k.name, v) for k, v in pkl["new_popstats"].items()) 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | statlabels = dict( 10 | a_mean="mean of recurrent term", 11 | b_mean="mean of input term", 12 | c_mean="mean of cell state", 13 | a_var="variance of recurrent term", 14 | b_var="variance of input term", 15 | c_var="variance of cell state") 16 | 17 | fig, axess = plt.subplots(2, 2, sharex='col') 18 | statss = [["%s_%s" % (key, stat) for key in "ac"] 19 | for stat in "mean var".split()] 20 | for axes, stats in zip(axess, statss): 21 | for axis, stat in zip(axes, stats): 22 | popstat = new_popstats[stat] 23 | 24 | # random subset of popstats 25 | subset = np.random.choice(popstat.shape[1], size=30, replace=False) 26 | 27 | axis.plot(popstat[:, subset]) 28 | axis.set_title(statlabels[stat]) 29 | 30 | # set xlabel only on bottom subplots since x axis is shared 31 | for axis in axess[1]: 32 | axis.set_xlabel("time steps") 33 | 34 | plt.show() 35 | 36 | -------------------------------------------------------------------------------- /paper/figures/plot_tanh_grad.py: -------------------------------------------------------------------------------- 1 | import numpy as np, matplotlib.pyplot as plt 2 | import matplotlib.cm as cm 3 | import matplotlib 4 | matplotlib.rcParams.update({"font.size": 20}) 5 | 6 | def tanh(x): return np.tanh(x) 7 | def dtanh(x): return 1 - tanh(x)**2 8 | def sigmoid(x): return 1. / (1. 
+ np.exp(-x)) 9 | def dsigmoid(x): return sigmoid(x) * (1 - sigmoid(x)) 10 | 11 | sample_size = 1000 12 | sigmas = np.linspace(0, 1, 1000) 13 | x = np.random.randn(sample_size, len(sigmas)) * sigmas 14 | 15 | colors = cm.viridis([0.3]) 16 | 17 | for fn_name, fn, dfn in [("tanh", tanh, dtanh), 18 | ("logistic function", sigmoid, dsigmoid)]: 19 | y = fn(x) 20 | dydx = dfn(x) 21 | 22 | #plt.figure() 23 | #plt.plot(sigmas, sigmas) 24 | #plt.plot(sigmas, y.std(axis=0)) 25 | #plt.gca().set_aspect("equal") 26 | #plt.xlim([0, 1]) 27 | #plt.ylim([0, 1]) 28 | #plt.xlabel("input standard deviation") 29 | #plt.ylabel("output standard deviation") 30 | #plt.title("%s variance propagation" % fn_name) 31 | 32 | plt.figure() 33 | plt.plot(sigmas, dydx.mean(axis=0), linewidth=3, color=colors[0])# color='#CC4F1B') 34 | plt.fill_between(sigmas, 35 | np.percentile(dydx, 25, axis=0), 36 | np.percentile(dydx, 75, axis=0), 37 | facecolor=colors[0], #facecolor='#FF9848', 38 | alpha=0.3, 39 | linewidth=0) 40 | #plt.gca().set_aspect("equal") 41 | plt.xlim([0, 1]) 42 | plt.ylim([0, 1]) 43 | plt.xlabel("input standard deviation") 44 | plt.ylabel("expected derivative (and IQR range)") 45 | plt.title("derivative through %s" % fn_name) 46 | 47 | plt.tight_layout() 48 | fig = plt.gcf() 49 | fig.set_size_inches(800 / fig.dpi, 600 / fig.dpi) 50 | plt.savefig("tanh_grad.pdf", bbox_inches="tight") 51 | 52 | break 53 | 54 | #plt.show() 55 | -------------------------------------------------------------------------------- /plot_attr.py: -------------------------------------------------------------------------------- 1 | import sys, numpy as np 2 | import cPickle as pkl 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm 5 | 6 | paths = sys.argv[1:] 7 | instances = [] 8 | for i, path in enumerate(paths): 9 | try: 10 | label, path = path.split(":") 11 | except ValueError: 12 | label = i 13 | print(label, path) 14 | instances.append(dict(label=label, 15 | path=path, 16 | data=pkl.load(open(path)))) 17 | 18 | def dump(path): 19 | data = dict((instance["label"], 20 | dict(train=instance["data"]["train_err_ave"], 21 | valid=instance["data"]["valid_errs"])) 22 | for instance in instances) 23 | pkl.dump(data, open(path, "wb")) 24 | 25 | import pdb; pdb.set_trace() 26 | 27 | colors = "r b g purple maroon darkslategray darkolivegreen orangered".split() 28 | colors = cm.rainbow(np.linspace(0, 1, len(instances))) 29 | 30 | channel_labels = dict(train_err_ave="train", 31 | valid_errs="valid") 32 | 33 | plt.figure() 34 | for channel_name, kwargs in [ 35 | ("train_err_ave", dict(linestyle="dotted")), 36 | ("valid_errs", dict(linestyle="solid"))]: 37 | for color, instance in zip(colors, instances): 38 | label = "%s %s" % (instance["label"], channel_labels[channel_name]) 39 | plt.plot(np.asarray(instance["data"][channel_name]), label=label, color=color, linewidth=3, **kwargs) 40 | plt.legend() 41 | plt.xlim((0, 800)) 42 | plt.ylabel("error rate") 43 | plt.xlabel("training steps (thousands)") 44 | 45 | plt.figure() 46 | for channel_name, kwargs in [ 47 | ("train_cost_ave", dict(linestyle="dotted")), 48 | ("valid_costs", dict(linestyle="solid"))]: 49 | for color, instance in zip(colors, instances): 50 | label = "%s %s" % (instance["label"], channel_name) 51 | plt.plot(instance["data"][channel_name], label=label, color=color, linewidth=3, **kwargs) 52 | plt.legend() 53 | plt.xlim((0, 800)) 54 | plt.ylabel("error rate") 55 | plt.xlabel("training steps") 56 | 57 | plt.show() 58 | 59 | 
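A note on plot_tanh_grad.py above: the quantity it estimates by sampling can be stated directly. For a pre-activation x = sigma * Z with Z ~ N(0, 1), the plotted curve is

\mathbb{E}\big[\tanh'(x)\big] \;=\; \mathbb{E}_{Z \sim \mathcal{N}(0,1)}\big[\,1 - \tanh^2(\sigma Z)\,\big],

which equals 1 at sigma = 0 and decreases monotonically as sigma grows, since tanh^2(sigma z) is nondecreasing in sigma for every fixed z. This is the observation behind the paper's advice to initialize the batch normalization gain gamma to a small value such as 0.1: a small gain keeps the pre-activation variance low and hence keeps the derivative through tanh close to 1.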
-------------------------------------------------------------------------------- /plotgradients.py: -------------------------------------------------------------------------------- 1 | import sys, os, pprint, math 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.axes_grid1 import make_axes_locatable 5 | 6 | from blocks.serialization import load 7 | 8 | statistics = dict(batchmean=lambda x: x.mean(axis=1), 9 | batchvar=lambda x: x.var(axis=1), 10 | #minbatchvar=lambda x: x.var(axis=1).min(axis=1), 11 | #maxbatchvar=lambda x: x.var(axis=1).max(axis=1), 12 | norm=lambda x: np.sqrt((x**2).sum(axis=(1, 2)))) 13 | 14 | paths = dict(enumerate(sys.argv[1:])) 15 | pprint.pprint(paths) 16 | subplotrows = int(math.ceil(math.sqrt(len(paths)))) 17 | 18 | instances = dict((k, load(v)) for k, v in paths.items()) 19 | 20 | keys = list(next(iter(instances.values())).keys()) 21 | print keys 22 | for key in keys: 23 | for statlabel, statistic in statistics.items(): 24 | figure, axes = plt.subplots(subplotrows, subplotrows) 25 | if subplotrows == 1: 26 | # morons 27 | axes = [[axes]] 28 | label = 0 29 | for i in range(subplotrows): 30 | for j in range(subplotrows): 31 | try: 32 | instance = instances[label] 33 | except: 34 | continue 35 | result = statistic(instance[key]) 36 | axis = axes[i][j] 37 | if result.ndim == 1: 38 | # logarithmic line plot 39 | axis.plot(result) 40 | axis.set_yscale("log") 41 | elif result.ndim == 2: 42 | # heatmap 43 | mappable = axis.imshow(result.T, cmap="bone", interpolation="none", aspect="auto") 44 | divider = make_axes_locatable(axis) 45 | figure.colorbar(mappable, cax=divider.append_axes("right", size="5%", pad=0.05)) 46 | axis.set_title("#%i" % label) 47 | title = "%s %s" % (key, statlabel) 48 | figure.suptitle(title) 49 | figure.canvas.set_window_title(title) 50 | import pdb; pdb.set_trace() 51 | plt.show() 52 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np, theano, theano.tensor as T 3 | import itertools as it 4 | from picklable_itertools.extras import equizip 5 | 6 | # scan with arguments in dicts rather than lists 7 | def scan(fn, 8 | sequences=None, 9 | outputs_info=None, 10 | non_sequences=None, 11 | **scan_kwargs): 12 | # we don't care about the order, as long as it's consistent 13 | sequences = OrderedDict(sequences or []) 14 | outputs_info = OrderedDict(outputs_info or []) 15 | non_sequences = OrderedDict(non_sequences or []) 16 | 17 | # make sure names are unique 18 | assert not (set(sequences) & set(outputs_info) & set(non_sequences)) 19 | 20 | def listified_fn(*input_list): 21 | input_dict = OrderedDict() 22 | input_it = iter(input_list) 23 | input_dict.update(equizip(sequences.keys(), 24 | it.islice(input_it, len(sequences)))) 25 | for name, info in outputs_info.items(): 26 | if info is None: 27 | continue # no inputs 28 | elif isinstance(info, (dict, OrderedDict)): 29 | ntaps = len(info.get("taps", [-1])) 30 | else: 31 | # assume some kind of tensor variable or numpy array 32 | ntaps = 1 33 | taps = [next(input_it) for _ in range(ntaps)] 34 | input_dict[name] = taps if ntaps > 1 else taps[0] 35 | input_dict.update(equizip(non_sequences.keys(), 36 | it.islice(input_it, len(non_sequences)))) 37 | 38 | # input_list should be exactly empty here 39 | try: 40 | next(input_it) 41 | except StopIteration: 42 | pass 43 | else: 44 | assert False 45 | 46 | 
output_dict = fn(**input_dict) 47 | output_list = [output_dict[output_name].copy(name=output_name) 48 | for output_name in outputs_info.keys()] 49 | return output_list 50 | 51 | outputs, updates = theano.scan( 52 | listified_fn, 53 | sequences=sequences.values(), 54 | outputs_info=outputs_info.values(), 55 | non_sequences=non_sequences.values(), 56 | **scan_kwargs) 57 | outputs = OrderedDict(equizip(outputs_info.keys(), outputs)) 58 | return outputs, updates 59 | -------------------------------------------------------------------------------- /paper/minipage.sty: -------------------------------------------------------------------------------- 1 | %%% ==================================================================== 2 | %%% @LaTeX-style-file{ 3 | %%% author = "Mario Wolczko", 4 | %%% version = "2", 5 | %%% date = "21 May 1992", 6 | %%% time = "20:55:01 BST", 7 | %%% filename = "boxedminipage.sty", 8 | %%% email = "mario@acm.org", 9 | %%% codetable = "ISO/ASCII", 10 | %%% keywords = "LaTeX, minipage, framebox", 11 | %%% supported = "no", 12 | %%% docstring = "LaTeX document-style option which defines 13 | %%% the boxedminipage environment -- just like minipage, but with 14 | %%% a box around it.", 15 | %%% } 16 | %%% ==================================================================== 17 | % 18 | % This file is in the public domain 19 | % 20 | % The thickness of the rules around the box is controlled by 21 | % \fboxrule, and the distance between the rules and the edges of the 22 | % inner box is governed by \fboxsep. 23 | % 24 | % This code is based on Lamport's minipage code. 25 | % 26 | % Fixed, 7 Jun 89 by Jerry Leichter 27 | % Leave \fboxsep worth of separation at top and bottom, not just at 28 | % the sides! 29 | % 30 | \def\boxedminipage{\@ifnextchar [{\@iboxedminipage}{\@iboxedminipage[c]}} 31 | 32 | \def\@iboxedminipage[#1]#2{\leavevmode \@pboxswfalse 33 | \if #1b\vbox 34 | \else \if #1t\vtop 35 | \else \ifmmode \vcenter 36 | \else \@pboxswtrue $\vcenter 37 | \fi 38 | \fi 39 | \fi\bgroup % start of outermost vbox/vtop/vcenter 40 | \hsize #2 41 | \hrule\@height\fboxrule 42 | \hbox\bgroup % inner hbox 43 | \vrule\@width\fboxrule \hskip\fboxsep \vbox\bgroup % innermost vbox 44 | \vskip\fboxsep 45 | \advance\hsize -2\fboxrule \advance\hsize-2\fboxsep 46 | \textwidth\hsize \columnwidth\hsize 47 | \@parboxrestore 48 | \def\@mpfn{mpfootnote}\def\thempfn{\thempfootnote}\c@mpfootnote\z@ 49 | \let\@footnotetext\@mpfootnotetext 50 | \let\@listdepth\@mplistdepth \@mplistdepth\z@ 51 | \@minipagerestore\@minipagetrue 52 | \everypar{\global\@minipagefalse\everypar{}}} 53 | 54 | \def\endboxedminipage{% 55 | \par\vskip-\lastskip 56 | \ifvoid\@mpfootins\else 57 | \vskip\skip\@mpfootins\footnoterule\unvbox\@mpfootins\fi 58 | \vskip\fboxsep 59 | \egroup % ends the innermost \vbox 60 | \hskip\fboxsep \vrule\@width\fboxrule 61 | \egroup % ends the \hbox 62 | \hrule\@height\fboxrule 63 | \egroup% ends the vbox/vtop/vcenter 64 | \if@pboxsw $\fi} 65 | 66 | -------------------------------------------------------------------------------- /results/iclr/attr/plot_full.py: -------------------------------------------------------------------------------- 1 | import sys, numpy as np 2 | import cPickle as pkl 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm 5 | 6 | friendly_labels = { 7 | "Unidir": "BN-e unidir tweaked", 8 | "Bidir": "BN-e bidir tweaked", 9 | } 10 | 11 | paths = """ 12 | LSTM:full/baseline.npz.pkl 13 | 
BN-e**:full/120_uni_bs64_lr8e-4_use_dq_sims1_use_desc_skip_c_g1_sequencewise/stats_bn-bidir-dropout.npz.pkl 14 | """.split() 15 | 16 | #Unidir:full/bn_280_bi_bs64_lr8e-4_use_dq_sims1_use_desc_skip_c_g1_sequensewisenorm/stats_bn-bidir-dropout.npz.pkl 17 | 18 | instances = [] 19 | for i, path in enumerate(paths): 20 | try: 21 | label, path = path.split(":") 22 | except ValueError: 23 | label = i 24 | label = friendly_labels.get(label, label) 25 | print(label, path) 26 | instances.append(dict(label=label, 27 | path=path, 28 | data=pkl.load(open(path)))) 29 | 30 | colors = "r b g goldenrod purple".split() 31 | #colors = cm.rainbow(np.linspace(0, 1, len(instances))) 32 | 33 | channel_labels = dict(train_err_ave="train", 34 | valid_errs="valid") 35 | 36 | import matplotlib 37 | matplotlib.rcParams.update({"font.size": 18}) 38 | plt.figure() 39 | for channel_name, kwargs in [ 40 | ("train_err_ave", dict(linestyle="dotted")), 41 | ("valid_errs", dict(linestyle="solid"))]: 42 | for color, instance in zip(colors, instances): 43 | label = "%s %s" % (instance["label"], channel_labels[channel_name]) 44 | x = np.asarray(instance["data"][channel_name]) 45 | minimum = min(x) 46 | print label, "min", minimum 47 | plt.plot(x, label=label, color=color, linewidth=2, **kwargs) 48 | if "valid" in channel_name: 49 | plt.axhline(y=minimum, xmin=0.98, xmax=1, linewidth=2, color=color) 50 | 51 | plt.legend(prop=dict(size=14)) 52 | plt.xlim((0, 400)) 53 | plt.ylabel("error rate") 54 | plt.xlabel("training steps (thousands)") 55 | 56 | if False: 57 | plt.figure() 58 | for channel_name, kwargs in [ 59 | ("train_cost_ave", dict(linestyle="dotted")), 60 | ("valid_costs", dict(linestyle="solid"))]: 61 | for color, instance in zip(colors, instances): 62 | label = "%s %s" % (instance["label"], channel_name) 63 | plt.plot(instance["data"][channel_name], label=label, color=color, linewidth=3, **kwargs) 64 | plt.legend(prop=dict(size=14)) 65 | plt.xlim((0, 400)) 66 | plt.ylabel("cost") 67 | plt.xlabel("training steps") 68 | 69 | plt.tight_layout() 70 | fig = plt.gcf() 71 | fig.set_size_inches(800 / fig.dpi, 600 / fig.dpi) 72 | plt.savefig("attr_full_valid.pdf", bbox_inches="tight") 73 | #plt.show() 74 | -------------------------------------------------------------------------------- /penntreebank_inspect.py: -------------------------------------------------------------------------------- 1 | import sys, cPickle 2 | import matplotlib.pyplot as plt 3 | 4 | def split_path(pathlike): 5 | i, pathlike = pathlike 6 | try: 7 | name, path = pathlike.split(":") 8 | except (ValueError, AttributeError): 9 | name, path = i, pathlike 10 | print("%s: %s" % (name, path)) 11 | return name, path 12 | 13 | def load_instance(pathlike): 14 | name, path = split_path(pathlike) 15 | with open(path, "rb") as file: 16 | thing = cPickle.load(file) 17 | return dict(name=name, path=path, **thing) 18 | 19 | # arguments: (optionally labeled) paths to pickle files generated by penntreebank_evaluate.py, in the form [label:]path 20 | paths = sys.argv[1:] 21 | instances = list(map(load_instance, enumerate(paths))) 22 | 23 | import math 24 | def natstobits(x): 25 | return x / math.log(2) 26 | 27 | colors = "blue red green cyan magenta yellow black white".split() 28 | for which_set in "train valid test".split(): 29 | plt.figure() 30 | for situation, kwargs in [("inference", dict(linestyle="solid")), 31 | ("training", dict(linestyle="dashed"))]: 32 | for color, instance in zip(colors, instances): 33 | # baseline training/inference performances will be identical 
34 | if instance["name"] == "LSTM" and situation == "training": 35 | continue 36 | 37 | label = instance["name"] 38 | if instance["name"] == "BN-LSTM": 39 | label += ", " + dict(training="batch statistics", 40 | inference="population statistics")[situation] 41 | 42 | results = instance["results"][situation][which_set] 43 | tvs = [(t, v["cross_entropy"]) for t, v in results.items()] 44 | time, value = zip(*tvs) 45 | 46 | # don't care about result of length 50 as we're training on 100 now 47 | assert time[0] == 50 48 | time = time[1:] 49 | value = value[1:] 50 | 51 | value = list(map(natstobits, value)) 52 | plt.plot(time, value, label=label, c=color, linewidth=3, **kwargs) 53 | #plt.yscale("log") 54 | #plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) 55 | plt.legend() 56 | #plt.title("performance on slices of the " + which_set + " string") 57 | plt.xlabel("sequence length") 58 | plt.ylabel("mean bits per character") 59 | 60 | for instance in instances: 61 | print "bpc on full test", instance["name"], natstobits(instance["results"]["proper_test"]["cross_entropy"]) 62 | 63 | import pdb; pdb.set_trace() 64 | plt.show() 65 | 66 | if False: 67 | for instance in instances: 68 | for variable, value in instance["new_popstats"].items(): 69 | plt.figure() 70 | plt.imshow(value, cmap="bone", aspect="auto") 71 | plt.colorbar() 72 | plt.title("%s %s" % (instance["name"], variable.name)) 73 | import pdb; pdb.set_trace() 74 | plt.show() 75 | 76 | import pdb; pdb.set_trace() 77 | -------------------------------------------------------------------------------- /results/iclr/rnn_hidden_norms/plot_rnn_gradnorm.py: -------------------------------------------------------------------------------- 1 | import sys, numpy as np 2 | from blocks.serialization import load 3 | import matplotlib.pyplot as plt 4 | import matplotlib.cm as cm 5 | from collections import OrderedDict 6 | import matplotlib 7 | matplotlib.rcParams.update({"font.size": 20}) 8 | 9 | def split_path(pathlike): 10 | i, pathlike = pathlike 11 | try: 12 | name, path = pathlike.split(":") 13 | except (ValueError, AttributeError): 14 | name, path = i, pathlike 15 | print("%s: %s" % (name, path)) 16 | return name, path 17 | 18 | def load_instance(pathlike): 19 | name, path = split_path(pathlike) 20 | return load_named_instance(name, path) 21 | 22 | def load_named_instance(name, path): 23 | with open(path, "rb") as file: 24 | hiddens = load(file) 25 | return dict(name=name, path=path, hiddens=hiddens) 26 | 27 | def load_named_instance(name, path): 28 | print(name, path) 29 | npz_path = path.replace(".pkl", ".npz") 30 | try: 31 | hiddens = np.load(npz_path) 32 | except: 33 | with open(path, "rb") as file: 34 | hiddens = load(file) 35 | np.savez_compressed(npz_path, **hiddens) 36 | return dict(name=name, path=path, hiddens=hiddens) 37 | 38 | #paths = sys.argv[1:] 39 | #instances = list(map(load_instance, enumerate(paths))) 40 | 41 | paths = OrderedDict([ 42 | ("gamma=0.10", "rnn_gamma0.10/hiddens_0.pkl"), 43 | ("gamma=0.20", "rnn_gamma0.20/hiddens_0.pkl"), 44 | ("gamma=0.30", "rnn_gamma0.30/hiddens_0.pkl"), 45 | ("gamma=0.40", "rnn_gamma0.40/hiddens_0.pkl"), 46 | ("gamma=0.50", "rnn_gamma0.50/hiddens_0.pkl"), 47 | ("gamma=0.60", "rnn_gamma0.60/hiddens_0.pkl"), 48 | ("gamma=0.70", "rnn_gamma0.70/hiddens_0.pkl"), 49 | ("gamma=0.80", "rnn_gamma0.80/hiddens_0.pkl"), 50 | ("gamma=0.90", "rnn_gamma0.90/hiddens_0.pkl"), 51 | ("gamma=1.00", "rnn_gamma1.0/hiddens_0.pkl"), 52 | ]) 53 | instances = [load_named_instance(k, v) for k, v in paths.items()] 54 | 55 | 
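# NOTE: each hiddens["h_grad"] below is assumed to be shaped (time, batch, hidden);
# the norm plotted further down is then the L2 norm over hidden units averaged over
# the batch, i.e. an estimate of E ||dL/dh_t||_2 at each time step t.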
colors = cm.viridis(np.linspace(0.1, 0.9, len(instances))) 56 | linestyles = "- - - - - - - - - - - - - - - -".split() 57 | assert len(linestyles) >= len(instances) 58 | 59 | plt.figure() 60 | allnorms = [] 61 | for instance, color, linestyle in zip(instances, colors, linestyles): 62 | # expected gradient norm over time (expectation over data) 63 | norms = np.sqrt((instance["hiddens"]["h_grad"] ** 2).sum(axis=2)).mean(axis=1) 64 | plt.plot(norms, label=instance["name"], color=color, linewidth="3", linestyle=linestyle) 65 | allnorms.extend(norms) 66 | 67 | plt.title("RNN gradient propagation") 68 | plt.yscale("log") 69 | plt.ylim(ymax=1) 70 | plt.xlabel("t") 71 | plt.ylabel("||dloss/dh_t||_2") 72 | plt.legend(loc="lower left", prop=dict(size=16)) 73 | 74 | axis = plt.gca() 75 | yticks = axis.get_yticks() 76 | # ticks are on odd powers for some reason 77 | yticks *= 10 78 | # ticks exceed range for some reason 79 | yticks = yticks[yticks <= axis.get_ylim()[1]] 80 | axis.set_yticks(yticks) 81 | 82 | #import pdb; pdb.set_trace() 83 | plt.tight_layout() 84 | fig = plt.gcf() 85 | fig.set_size_inches(800 / fig.dpi, 600 / fig.dpi) 86 | plt.savefig("rnn_grad_prop.pdf", bbox_inches="tight") 87 | #plt.show() 88 | -------------------------------------------------------------------------------- /results/iclr/attr/plot.py: -------------------------------------------------------------------------------- 1 | import sys, numpy as np 2 | import cPickle as pkl 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm 5 | 6 | paths = """ 7 | LSTM:preliminary/baseline/stats_dimworda_[240]_datamode_top4_usedqsim_1_useelug_0_validFre_1000_clip-c_[10.0]_usebidir_0_encoderq_lstm_dimproj_[240]_use-drop_[True]_optimize_adam_decay-c_[0.0]_truncate_-1_learnh0_1_default.npz.pkl 8 | BN-LSTM:preliminary/batchnorm/stats_dimworda_[240]_datamode_top4_usedqsim_1_useelug_0_validFre_1000_clip-c_[10.0]_usebidir_0_encoderq_bnlstm_dimproj_[240]_use-drop_[True]_optimize_adam_decay-c_[0.0]_truncate_-1_default.npz.pkl 9 | BN-everywhere:preliminary/batchnorm-everywhere/stats_dimworda_[240]_datamode_top4_usedqsim_1_useelug_0_validFre_1000_clip-c_[10.0]_usebidir_0_encoderq_bnlstm_dimproj_[240]_use-drop_[True]_optimize_adam_decay-c_[0.0]_default.npz.pkl 10 | Unidir:reprod/improved_240_uni_bs40_lr8e-4_use_dq_sims1_use_desc_skip_c_g1_sequensewisenorm/stats_bn-bidir-dropout.npz.pkl 11 | Bidir:reprod/bidir/240_bi_bs64_lr8e-5_use_dq_sims1_use_desc_skip_c_g1_sequensewisenorm/stats_bn-bidir-dropout.npz.pkl 12 | """.split() 13 | 14 | friendly_labels = { 15 | "Unidir": "BN-e*", 16 | "Bidir": "BN-e**", 17 | } 18 | 19 | instances = [] 20 | for i, path in enumerate(paths): 21 | try: 22 | label, path = path.split(":") 23 | except ValueError: 24 | label = i 25 | label = friendly_labels.get(label, label) 26 | print(label, path) 27 | instances.append(dict(label=label, 28 | path=path, 29 | data=pkl.load(open(path)))) 30 | 31 | colors = "r b g goldenrod purple".split() 32 | #colors = cm.rainbow(np.linspace(0, 1, len(instances))) 33 | 34 | channel_labels = dict(train_err_ave="train", 35 | valid_errs="valid") 36 | 37 | import matplotlib 38 | matplotlib.rcParams.update({"font.size": 18}) 39 | plt.figure() 40 | for channel_name, kwargs in [ 41 | ("train_err_ave", dict(linestyle="dotted")), 42 | ("valid_errs", dict(linestyle="solid"))]: 43 | for color, instance in zip(colors, instances): 44 | label = "%s %s" % (instance["label"], channel_labels[channel_name]) 45 | minimum = min(instance["data"][channel_name]) 46 | print label, "min", minimum 47 | 
plt.plot(np.asarray(instance["data"][channel_name]), label=label, color=color, linewidth=2, **kwargs) 48 | if "valid" in channel_name: 49 | plt.axhline(y=minimum, xmin=0.98, xmax=1, linewidth=2, color=color) 50 | plt.legend(prop=dict(size=12)) 51 | plt.xlim((0, 800)) 52 | plt.ylabel("error rate") 53 | plt.xlabel("training steps (thousands)") 54 | 55 | if False: 56 | plt.figure() 57 | for channel_name, kwargs in [ 58 | ("train_cost_ave", dict(linestyle="dotted")), 59 | ("valid_costs", dict(linestyle="solid"))]: 60 | for color, instance in zip(colors, instances): 61 | label = "%s %s" % (instance["label"], channel_name) 62 | plt.plot(instance["data"][channel_name], label=label, color=color, linewidth=3, **kwargs) 63 | plt.legend(prop=dict(size=12)) 64 | plt.xlim((0, 800)) 65 | plt.ylabel("cost") 66 | plt.xlabel("training steps") 67 | 68 | plt.tight_layout() 69 | fig = plt.gcf() 70 | fig.set_size_inches(800 / fig.dpi, 600 / fig.dpi) 71 | plt.savefig("attr_valid2.pdf", bbox_inches="tight") 72 | #plt.show() 73 | -------------------------------------------------------------------------------- /penntreebank_evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import theano, itertools, pprint, copy, numpy as np, theano.tensor as T, re 3 | from collections import OrderedDict 4 | from blocks.serialization import load 5 | import util 6 | 7 | # to make unpickling work :-( 8 | from penntreebank import * 9 | 10 | # argument: path to a checkpoint file 11 | main_loop = load(sys.argv[1]) 12 | print main_loop.log.current_row 13 | 14 | # extract population statistic updates 15 | updates = [update for update in main_loop.algorithm.updates 16 | # FRAGILE 17 | if re.search("_(mean|var)$", update[0].name)] 18 | print updates 19 | 20 | old_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates) 21 | 22 | # baseline doesn't need all this 23 | if updates: 24 | train_stream = get_stream(which_set="train", 25 | batch_size=100, 26 | augment=False, 27 | length=100) 28 | nbatches = len(list(train_stream.get_epoch_iterator())) 29 | 30 | # destructure moving average expression to construct a new expression 31 | new_updates = [] 32 | for popstat, value in updates: 33 | # FRAGILE 34 | assert value.owner.op.scalar_op == theano.scalar.add 35 | terms = value.owner.inputs 36 | # right multiplicand of second term is popstat 37 | assert popstat in theano.gof.graph.ancestors([terms[1].owner.inputs[1]]) 38 | # right multiplicand of first term is batchstat 39 | batchstat = terms[0].owner.inputs[1] 40 | 41 | old_popstats[popstat] = popstat.get_value() 42 | 43 | # FRAGILE: assume population statistics not used in computation of batch statistics 44 | # otherwise popstat should always have a reasonable value 45 | popstat.set_value(0 * popstat.get_value(borrow=True)) 46 | new_updates.append((popstat, popstat + batchstat / float(nbatches))) 47 | 48 | # FRAGILE: assume all the other algorithm updates are unneeded for computation of batch statistics 49 | estimate_fn = theano.function(main_loop.algorithm.inputs, [], 50 | updates=new_updates, on_unused_input="warn") 51 | for batch in train_stream.get_epoch_iterator(as_dict=True): 52 | estimate_fn(**batch) 53 | 54 | new_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates) 55 | 56 | from blocks.monitoring.evaluators import DatasetEvaluator 57 | results = dict() 58 | for situation in "training inference".split(): 59 | results[situation] = dict() 60 | outputs, = [ 61 | 
extension._evaluator.theano_variables 62 | for extension in main_loop.extensions 63 | if getattr(extension, "prefix", None) == "valid_%s" % situation] 64 | evaluator = DatasetEvaluator(outputs) 65 | for which_set in "train valid test".split(): 66 | results[situation][which_set] = OrderedDict( 67 | (length, evaluator.evaluate(get_stream( 68 | which_set=which_set, 69 | batch_size=100, 70 | augment=False, 71 | length=length))) 72 | for length in [50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]) 73 | 74 | results["proper_test"] = evaluator.evaluate( 75 | get_stream( 76 | which_set="test", 77 | batch_size=1, 78 | length=446184)) 79 | 80 | import cPickle 81 | cPickle.dump(dict(results=results, 82 | old_popstats=old_popstats, 83 | new_popstats=new_popstats), 84 | open(sys.argv[1] + "_popstat_results.pkl", "w")) 85 | -------------------------------------------------------------------------------- /sequential_mnist_evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import theano, itertools, pprint, copy, numpy as np, theano.tensor as T, re 3 | from collections import OrderedDict 4 | from blocks.serialization import load 5 | import util 6 | 7 | from sequential_mnist import get_stream 8 | 9 | # argument: path to a checkpoint file 10 | main_loop = load(sys.argv[1]) 11 | print main_loop.log.current_row 12 | 13 | 14 | # extract population statistic updates 15 | updates = [update for update in main_loop.algorithm.updates 16 | # FRAGILE 17 | if re.search("_(mean|var)$", update[0].name)] 18 | print updates 19 | 20 | old_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates) 21 | 22 | 23 | # baseline doesn't need all this 24 | if updates: 25 | which_set = "train" 26 | batch_size = 5000 # -_- 27 | #nbatches = len(list(main_loop.data_stream.get_epoch_iterator())) 28 | nbatches = len(list(get_stream(which_set=which_set, batch_size=batch_size).get_epoch_iterator())) 29 | 30 | # destructure moving average expression to construct a new expression 31 | new_updates = [] 32 | batchstat_name = [] 33 | batchstat_list = [] 34 | for popstat, value in updates: 35 | 36 | print popstat 37 | batchstat = popstat.tag.estimand 38 | batchstat_name.append(popstat.name) 39 | batchstat_list.append(batchstat) 40 | 41 | old_popstats[popstat] = popstat.get_value() 42 | 43 | # FRAGILE: assume population statistics not used in computation of batch statistics 44 | # otherwise popstat should always have a reasonable value 45 | popstat.set_value(0 * popstat.get_value(borrow=True)) 46 | new_updates.append((popstat, popstat + batchstat / float(nbatches))) 47 | #new_updates.append((popstat, batchstat)) 48 | 49 | # FRAGILE: assume all the other algorithm updates are unneeded for computation of batch statistics 50 | estimate_fn = theano.function(main_loop.algorithm.inputs, batchstat_list, 51 | updates=new_updates, on_unused_input="warn") 52 | 53 | bstats = OrderedDict() 54 | bstats_mean = OrderedDict() 55 | for n in batchstat_name: 56 | bstats[n] = [] 57 | bstats_mean[n] = 0.0 58 | for batch in get_stream(which_set=which_set, batch_size=batch_size).get_epoch_iterator(as_dict=True): 59 | cur_bstat = estimate_fn(**batch) 60 | for i in xrange(len(cur_bstat)): 61 | bstats[batchstat_name[i]].append(cur_bstat[i]) 62 | bstats_mean[batchstat_name[i]] += cur_bstat[i] 63 | for k, v in bstats_mean.items(): 64 | bstats_mean[k] = v / float(nbatches) 65 | #for popstat, value in updates: 66 | #popstat.set_value(bstats_mean[popstat.name]) 67 | 68 | 69 | new_popstats = 
dict((popstat, popstat.get_value()) for popstat, _ in updates)
70 | 
71 | 
72 | from blocks.monitoring.evaluators import DatasetEvaluator
73 | results = dict()
74 | for situation in "training inference".split():
75 |     results[situation] = dict()
76 |     outputs, = [
77 |         extension._evaluator.theano_variables
78 |         for extension in main_loop.extensions
79 |         if getattr(extension, "prefix", None) == "valid_%s" % situation]
80 |     evaluator = DatasetEvaluator(outputs)
81 |     for which_set in "train valid test".split():
82 |         results[situation][which_set] = evaluator.evaluate(
83 |             get_stream(which_set=which_set, batch_size=5000))
84 | 
85 | results["proper_test"] = evaluator.evaluate(
86 |     get_stream(
87 |         which_set="test",
88 |         batch_size=1000))
89 | print 'Results: ', results["proper_test"]
90 | import cPickle
91 | cPickle.dump(dict(results=results,
92 |                   old_popstats=old_popstats,
93 |                   new_popstats=new_popstats),
94 |              open(sys.argv[1] + "_popstat_results.pkl", "w"))
95 | 
--------------------------------------------------------------------------------
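Both evaluation scripts above (penntreebank_evaluate.py and sequential_mnist_evaluate.py) re-estimate the batch normalization population statistics after training rather than trusting the moving averages accumulated during it: each popstat is zeroed and the corresponding batch statistic, divided by the number of training batches, is accumulated into it. Stripped of the Theano plumbing, the estimator is a plain mean over batches; a minimal sketch with illustrative names:

import numpy as np

def estimate_popstats(batches, batchstat_fn):
    # batchstat_fn maps a batch to a dict of batch statistics
    # (e.g. per-timestep activation means and variances)
    totals, count = {}, 0
    for batch in batches:
        for name, value in batchstat_fn(batch).items():
            totals[name] = totals.get(name, 0.0) + np.asarray(value)
        count += 1
    # averaging over the training set replaces the exponential moving average
    return dict((name, total / count) for name, total in totals.items())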
0.50", path="ptb/batchnorm-ig0.50/log.pkl"), 63 | #dict(name="gamma 0.60", path="ptb/batchnorm-ig0.60/log.pkl"), 64 | dict(name="gamma 0.70", path="ptb/batchnorm-ig0.70/log.pkl"), 65 | #dict(name="gamma 0.80", path="ptb/batchnorm-ig0.80/log.pkl"), 66 | #dict(name="gamma 0.90", path="ptb/batchnorm-ig0.90/log.pkl"), 67 | dict(name="gamma 1.00", path="ptb/batchnorm-ig1.00/log.pkl"), 68 | ] 69 | channel_name = "%s_training_cross_entropy" % fold 70 | instances = [get(channel_name=channel_name, **instance) 71 | for instance in instances] 72 | 73 | for instance in instances: 74 | # nats to bits 75 | instance["t"] /= math.log(2) 76 | 77 | plot(instances, ax=ax) 78 | ax.set_ylabel("bits per character") 79 | ax.set_xlabel("training steps") 80 | ax.set_ylim(dict(train=[0.8, 1.1], 81 | valid=[1.0, 1.1])[fold]) 82 | ax.set_xlim(dict(train=[0, 19000], 83 | valid=[0, 19000])[fold]) 84 | ax.set_title("PTB %s" % fold) 85 | 86 | def plot(instances, ax=None): 87 | colors = cm.viridis(np.linspace(.2, .8, len(instances))) 88 | for color, instance in zip(colors, instances): 89 | ax.plot(instance["t"], instance["v"], label=instance["name"], 90 | c=color, linewidth=1) 91 | ax.legend(prop=dict(size=10)) 92 | 93 | if __name__ == "__main__": 94 | fig, axes = plt.subplots(2, 2) 95 | plot_mnist(fold="train", ax=axes[0][0]) 96 | plot_mnist(fold="valid", ax=axes[0][1]) 97 | plot_ptb(fold="train", ax=axes[1][0]) 98 | plot_ptb(fold="valid", ax=axes[1][1]) 99 | plt.tight_layout() 100 | #plt.show() 101 | fig.set_size_inches(800 / fig.dpi, 600 / fig.dpi) 102 | plt.savefig("gammas.pdf", bbox_inches="tight") 103 | -------------------------------------------------------------------------------- /extensions.py: -------------------------------------------------------------------------------- 1 | import tempfile, os.path, cPickle, zipfile, shutil, sys 2 | from cStringIO import StringIO 3 | from collections import OrderedDict 4 | import numpy as np 5 | import theano 6 | from blocks.extensions import SimpleExtension, Printing 7 | from blocks.serialization import secure_dump 8 | import blocks.config 9 | 10 | class PrintingTo(Printing): 11 | def __init__(self, path, **kwargs): 12 | super(PrintingTo, self).__init__(**kwargs) 13 | self.path = path 14 | with open(self.path, "w") as f: 15 | f.truncate(0) 16 | 17 | def do(self, *args, **kwargs): 18 | stdout, stringio = sys.stdout, StringIO() 19 | sys.stdout = stringio 20 | super(PrintingTo, self).do(*args, **kwargs) 21 | sys.stdout = stdout 22 | lines = stringio.getvalue().splitlines() 23 | with open(self.path, "a") as f: 24 | f.write("\n".join(lines)) 25 | f.write("\n") 26 | 27 | class DumpLog(SimpleExtension): 28 | def __init__(self, path, **kwargs): 29 | kwargs.setdefault("after_training", True) 30 | super(DumpLog, self).__init__(**kwargs) 31 | self.path = path 32 | 33 | def do(self, callback_name, *args): 34 | secure_dump(self.main_loop.log, self.path, use_cpickle=True) 35 | 36 | class DumpGraph(SimpleExtension): 37 | def __init__(self, path, **kwargs): 38 | kwargs["after_batch"] = True 39 | super(DumpGraph, self).__init__(**kwargs) 40 | self.path = path 41 | 42 | def do(self, which_callback, *args, **kwargs): 43 | try: 44 | self.done 45 | except AttributeError: 46 | if hasattr(self.main_loop.algorithm, "_function"): 47 | self.done = True 48 | with open(self.path, "w") as f: 49 | theano.printing.debugprint(self.main_loop.algorithm._function, file=f) 50 | 51 | class DumpBest(SimpleExtension): 52 | """dump if the `notification_name` record is present""" 53 | def __init__(self, 
notification_name, save_path, **kwargs): 54 | self.notification_name = notification_name 55 | self.save_path = save_path 56 | kwargs.setdefault("after_epoch", True) 57 | super(DumpBest, self).__init__(**kwargs) 58 | 59 | def do(self, which_callback, *args): 60 | if self.notification_name in self.main_loop.log.current_row: 61 | secure_dump(self.main_loop, self.save_path, use_cpickle=True) 62 | 63 | from blocks.algorithms import StepRule 64 | from blocks.roles import ALGORITHM_BUFFER, add_role 65 | from blocks.utils import shared_floatx 66 | from blocks.theano_expressions import l2_norm 67 | 68 | class StepMemory(StepRule): 69 | def compute_steps(self, steps): 70 | # memorize steps for one time step 71 | self.last_steps = OrderedDict() 72 | updates = [] 73 | for parameter, step in steps.items(): 74 | last_step = shared_floatx( 75 | parameter.get_value() * 0., 76 | "last_step_%s" % parameter.name) 77 | add_role(last_step, ALGORITHM_BUFFER) 78 | updates.append((last_step, step)) 79 | self.last_steps[parameter] = last_step 80 | 81 | # compare last and current step directions 82 | self.cosine = (sum((step * self.last_steps[parameter]).sum() 83 | for parameter, step in steps.items()) 84 | / l2_norm(steps.values()) 85 | / l2_norm(self.last_steps.values())) 86 | 87 | return steps, updates 88 | 89 | class DumpVariables(SimpleExtension): 90 | def __init__(self, save_path, inputs, variables, batch, **kwargs): 91 | super(DumpVariables, self).__init__(**kwargs) 92 | self.save_path = save_path 93 | self.variables = variables 94 | self.function = theano.function(inputs, variables, on_unused_input="warn") 95 | self.batch = batch 96 | self.i = 0 97 | 98 | def do(self, which_callback, *args): 99 | values = dict((variable.name, np.asarray(value)) for variable, value in 100 | zip(self.variables, self.function(**self.batch))) 101 | secure_dump(values, "%s_%i.pkl" % (self.save_path, self.i)) 102 | self.i += 1 103 | 104 | class SharedVariableModifier(SimpleExtension): 105 | def __init__(self, parameter, function, **kwargs): 106 | kwargs.setdefault("after_batch", True) 107 | super(SharedVariableModifier, self).__init__(**kwargs) 108 | self.parameter = parameter 109 | self.function = function 110 | 111 | def do(self, which_callback, *args): 112 | iterations_done = self.main_loop.log.status['iterations_done'] 113 | old_value = self.parameter.get_value() 114 | new_value = self.function(iterations_done, old_value) 115 | self.parameter.set_value(new_value) 116 | -------------------------------------------------------------------------------- /paper/nips15submit_e.sty: -------------------------------------------------------------------------------- 1 | %%%% NIPS Macros (LaTex) 2 | %%%% Style File 3 | %%%% Dec 12, 1990 Rev Aug 14, 1991; Sept, 1995; April, 1997; April, 1999 4 | 5 | % This file can be used with Latex2e whether running in main mode, or 6 | % 2.09 compatibility mode. 7 | % 8 | % If using main mode, you need to include the commands 9 | % \documentclass{article} 10 | % \usepackage{nips10submit_e,times} 11 | % as the first lines in your document. Or, if you do not have Times 12 | % Roman font available, you can just use 13 | % \documentclass{article} 14 | % \usepackage{nips10submit_e} 15 | % instead. 16 | % 17 | % If using 2.09 compatibility mode, you need to include the command 18 | % \documentstyle[nips10submit_09,times]{article} 19 | % as the first line in your document. 
Or, if you do not have Times 20 | % Roman font available, you can include the command 21 | % \documentstyle[nips10submit_09]{article} 22 | % instead. 23 | 24 | % Change the overall width of the page. If these parameters are 25 | % changed, they will require corresponding changes in the 26 | % maketitle section. 27 | % 28 | \usepackage{eso-pic} % used by \AddToShipoutPicture 29 | 30 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 31 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 32 | 33 | % Define nipsfinal, set to true if nipsfinalcopy is defined 34 | \newif\ifnipsfinal 35 | \nipsfinalfalse 36 | \def\nipsfinalcopy{\nipsfinaltrue} 37 | \font\nipstenhv = phvb at 8pt % *** IF THIS FAILS, SEE nips10submit_e.sty *** 38 | 39 | % Specify the dimensions of each page 40 | 41 | \setlength{\paperheight}{11in} 42 | \setlength{\paperwidth}{8.5in} 43 | 44 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 45 | \evensidemargin .5in 46 | \marginparwidth 0.07 true in 47 | %\marginparwidth 0.75 true in 48 | %\topmargin 0 true pt % Nominal distance from top of page to top of 49 | %\topmargin 0.125in 50 | \topmargin -0.625in 51 | \addtolength{\headsep}{0.25in} 52 | \textheight 9.0 true in % Height of text (including footnotes & figures) 53 | \textwidth 5.5 true in % Width of text line. 54 | \widowpenalty=10000 55 | \clubpenalty=10000 56 | 57 | % \thispagestyle{empty} \pagestyle{empty} 58 | \flushbottom \sloppy 59 | 60 | % We're never going to need a table of contents, so just flush it to 61 | % save space --- suggested by drstrip@sandia-2 62 | \def\addcontentsline#1#2#3{} 63 | 64 | % Title stuff, taken from deproc. 65 | \def\maketitle{\par 66 | \begingroup 67 | \def\thefootnote{\fnsymbol{footnote}} 68 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 69 | % name centering 70 | % The footnote-mark was overlapping the footnote-text, 71 | % added the following to fix this problem (MK) 72 | \long\def\@makefntext##1{\parindent 1em\noindent 73 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 74 | \@maketitle \@thanks 75 | \endgroup 76 | \setcounter{footnote}{0} 77 | \let\maketitle\relax \let\@maketitle\relax 78 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 79 | 80 | % The toptitlebar has been raised to top-justify the first page 81 | 82 | % Title (includes both anonimized and non-anonimized versions) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | \linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\bf \@title\par} \bottomtitlebar % \vskip 0.1in % minus 86 | \ifnipsfinal 87 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 88 | \begin{tabular}[t]{c}\bf\rule{\z@}{24pt}\ignorespaces}% 89 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 90 | \begin{tabular}[t]{c}\bf\rule{\z@}{24pt}\ignorespaces}% 91 | \begin{tabular}[t]{c}\bf\rule{\z@}{24pt}\@author\end{tabular}% 92 | \else 93 | \begin{tabular}[t]{c}\bf\rule{\z@}{24pt} 94 | Anonymous Author(s) \\ 95 | Affiliation \\ 96 | Address \\ 97 | \texttt{email} \\ 98 | \end{tabular}% 99 | \fi 100 | \vskip 0.3in minus 0.1in}} 101 | 102 | \renewenvironment{abstract}{\vskip.075in\centerline{\large\bf 103 | Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 104 | 105 | % sections with less space 106 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 107 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 108 | minus0.2ex}{\large\bf\raggedright}} 109 | 110 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 111 | -0.5ex minus -.2ex}{0.8ex 
plus .2ex}{\normalsize\bf\raggedright}} 112 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 113 | plus -0.5ex minus -.2ex}{0.5ex plus 114 | .2ex}{\normalsize\bf\raggedright}} 115 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 116 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 117 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 118 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 119 | \def\subsubsubsection{\vskip 120 | 5pt{\noindent\normalsize\rm\raggedright}} 121 | 122 | 123 | % Footnotes 124 | \footnotesep 6.65pt % 125 | \skip\footins 9pt plus 4pt minus 2pt 126 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 127 | \setcounter{footnote}{0} 128 | 129 | % Lists and paragraphs 130 | \parindent 0pt 131 | \topsep 4pt plus 1pt minus 2pt 132 | \partopsep 1pt plus 0.5pt minus 0.5pt 133 | \itemsep 2pt plus 1pt minus 0.5pt 134 | \parsep 2pt plus 1pt minus 0.5pt 135 | \parskip .5pc 136 | 137 | 138 | %\leftmargin2em 139 | \leftmargin3pc 140 | \leftmargini\leftmargin \leftmarginii 2em 141 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 142 | 143 | %\labelsep \labelsep 5pt 144 | 145 | \def\@listi{\leftmargin\leftmargini} 146 | \def\@listii{\leftmargin\leftmarginii 147 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 148 | \topsep 2pt plus 1pt minus 0.5pt 149 | \parsep 1pt plus 0.5pt minus 0.5pt 150 | \itemsep \parsep} 151 | \def\@listiii{\leftmargin\leftmarginiii 152 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 153 | \topsep 1pt plus 0.5pt minus 0.5pt 154 | \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt 155 | \itemsep \topsep} 156 | \def\@listiv{\leftmargin\leftmarginiv 157 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 158 | \def\@listv{\leftmargin\leftmarginv 159 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 160 | \def\@listvi{\leftmargin\leftmarginvi 161 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 162 | 163 | \abovedisplayskip 7pt plus2pt minus5pt% 164 | \belowdisplayskip \abovedisplayskip 165 | \abovedisplayshortskip 0pt plus3pt% 166 | \belowdisplayshortskip 4pt plus3pt minus3pt% 167 | 168 | % Less leading in most fonts (due to the narrow columns) 169 | % The choices were between 1-pt and 1.5-pt leading 170 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 171 | \def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 172 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 173 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 174 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 175 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 176 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 177 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 178 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 179 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 180 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 181 | 182 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 183 | 184 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 185 | .09in} % 186 | %Reduced second vskip to compensate for adding the strut in \@author 187 | 188 | % Vertical Ruler 189 | % This code is, largely, from the CVPR 2010 conference style file 190 | % ----- define vruler 191 | \makeatletter 192 | \newbox\nipsrulerbox 193 | \newcount\nipsrulercount 194 | \newdimen\nipsruleroffset 195 | \newdimen\cv@lineheight 196 | \newdimen\cv@boxheight 197 | \newbox\cv@tmpbox 198 | \newcount\cv@refno 199 | \newcount\cv@tot 200 | % NUMBER with left flushed zeros \fillzeros[] 
201 | \newcount\cv@tmpc@ \newcount\cv@tmpc 202 | \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 203 | \cv@tmpc=1 % 204 | \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 205 | \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 206 | \ifnum#2<0\advance\cv@tmpc1\relax-\fi 207 | \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 208 | \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 209 | % \makevruler[][][][][] 210 | \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 211 | \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 212 | \global\setbox\nipsrulerbox=\vbox to \textheight{% 213 | {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 214 | \cv@lineheight=#1\global\nipsrulercount=#2% 215 | \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 216 | \cv@refno1\vskip-\cv@lineheight\vskip1ex% 217 | \loop\setbox\cv@tmpbox=\hbox to0cm{{\nipstenhv\hfil\fillzeros[#4]\nipsrulercount}}% 218 | \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 219 | \advance\cv@refno1\global\advance\nipsrulercount#3\relax 220 | \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 221 | \makeatother 222 | % ----- end of vruler 223 | 224 | % \makevruler[][][][][] 225 | \def\nipsruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\nipsrulerbox}} 226 | \AddToShipoutPicture{% 227 | \ifnipsfinal\else 228 | \nipsruleroffset=\textheight 229 | \advance\nipsruleroffset by -3.7pt 230 | \color[rgb]{.7,.7,.7} 231 | \AtTextUpperLeft{% 232 | \put(\LenToUnit{-35pt},\LenToUnit{-\nipsruleroffset}){%left ruler 233 | \nipsruler{\nipsrulercount}} 234 | } 235 | \fi 236 | } 237 | -------------------------------------------------------------------------------- /paper/iclr2017_conference.sty: -------------------------------------------------------------------------------- 1 | %%%% ICLR Macros (LaTex) 2 | %%%% Adapted by Hugo Larochelle from the NIPS stylefile Macros 3 | %%%% Style File 4 | %%%% Dec 12, 1990 Rev Aug 14, 1991; Sept, 1995; April, 1997; April, 1999; October 2014 5 | 6 | % This file can be used with Latex2e whether running in main mode, or 7 | % 2.09 compatibility mode. 8 | % 9 | % If using main mode, you need to include the commands 10 | % \documentclass{article} 11 | % \usepackage{iclr14submit_e,times} 12 | % 13 | 14 | % Change the overall width of the page. If these parameters are 15 | % changed, they will require corresponding changes in the 16 | % maketitle section. 
17 | % 18 | \usepackage{eso-pic} % used by \AddToShipoutPicture 19 | \RequirePackage{fancyhdr} 20 | \RequirePackage{natbib} 21 | 22 | % modification to natbib citations 23 | \setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}} 24 | 25 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 26 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 27 | 28 | % Define iclrfinal, set to true if iclrfinalcopy is defined 29 | \newif\ificlrfinal 30 | \iclrfinalfalse 31 | \def\iclrfinalcopy{\iclrfinaltrue} 32 | \font\iclrtenhv = phvb at 8pt 33 | 34 | % Specify the dimensions of each page 35 | 36 | \setlength{\paperheight}{11in} 37 | \setlength{\paperwidth}{8.5in} 38 | 39 | 40 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 41 | \evensidemargin .5in 42 | \marginparwidth 0.07 true in 43 | %\marginparwidth 0.75 true in 44 | %\topmargin 0 true pt % Nominal distance from top of page to top of 45 | %\topmargin 0.125in 46 | \topmargin -0.625in 47 | \addtolength{\headsep}{0.25in} 48 | \textheight 9.0 true in % Height of text (including footnotes & figures) 49 | \textwidth 5.5 true in % Width of text line. 50 | \widowpenalty=10000 51 | \clubpenalty=10000 52 | 53 | % \thispagestyle{empty} \pagestyle{empty} 54 | \flushbottom \sloppy 55 | 56 | % We're never going to need a table of contents, so just flush it to 57 | % save space --- suggested by drstrip@sandia-2 58 | \def\addcontentsline#1#2#3{} 59 | 60 | % Title stuff, taken from deproc. 61 | \def\maketitle{\par 62 | \begingroup 63 | \def\thefootnote{\fnsymbol{footnote}} 64 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 65 | % name centering 66 | % The footnote-mark was overlapping the footnote-text, 67 | % added the following to fix this problem (MK) 68 | \long\def\@makefntext##1{\parindent 1em\noindent 69 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 70 | \@maketitle \@thanks 71 | \endgroup 72 | \setcounter{footnote}{0} 73 | \let\maketitle\relax \let\@maketitle\relax 74 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 75 | 76 | % The toptitlebar has been raised to top-justify the first page 77 | 78 | \usepackage{fancyhdr} 79 | \pagestyle{fancy} 80 | \fancyhead{} 81 | 82 | % Title (includes both anonimized and non-anonimized versions) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | %\linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\sc \@title\par} 86 | %\bottomtitlebar % \vskip 0.1in % minus 87 | \ificlrfinal 88 | \lhead{Published as a conference paper at ICLR 2017} 89 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 90 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 91 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 92 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 93 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 94 | \else 95 | \lhead{Under review as a conference paper at ICLR 2017} 96 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 97 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 98 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 99 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 100 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 101 | % \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces 102 | %Anonymous Author(s) \\ 103 | %Affiliation \\ 104 | %Address \\ 105 | %\texttt{email} \\ 106 | %\end{tabular}% 107 | \fi 108 | \vskip 0.3in minus 0.1in}} 109 | 110 | \renewenvironment{abstract}{\vskip.075in\centerline{\large\sc 111 | 
Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 112 | 113 | % sections with less space 114 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 115 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 116 | minus0.2ex}{\large\sc\raggedright}} 117 | 118 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 119 | -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\sc\raggedright}} 120 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 121 | plus -0.5ex minus -.2ex}{0.5ex plus 122 | .2ex}{\normalsize\sc\raggedright}} 123 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 124 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 125 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 126 | 0.5ex minus .2ex}{-1em}{\normalsize\sc}} 127 | \def\subsubsubsection{\vskip 128 | 5pt{\noindent\normalsize\rm\raggedright}} 129 | 130 | 131 | % Footnotes 132 | \footnotesep 6.65pt % 133 | \skip\footins 9pt plus 4pt minus 2pt 134 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 135 | \setcounter{footnote}{0} 136 | 137 | % Lists and paragraphs 138 | \parindent 0pt 139 | \topsep 4pt plus 1pt minus 2pt 140 | \partopsep 1pt plus 0.5pt minus 0.5pt 141 | \itemsep 2pt plus 1pt minus 0.5pt 142 | \parsep 2pt plus 1pt minus 0.5pt 143 | \parskip .5pc 144 | 145 | 146 | %\leftmargin2em 147 | \leftmargin3pc 148 | \leftmargini\leftmargin \leftmarginii 2em 149 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 150 | 151 | %\labelsep \labelsep 5pt 152 | 153 | \def\@listi{\leftmargin\leftmargini} 154 | \def\@listii{\leftmargin\leftmarginii 155 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 156 | \topsep 2pt plus 1pt minus 0.5pt 157 | \parsep 1pt plus 0.5pt minus 0.5pt 158 | \itemsep \parsep} 159 | \def\@listiii{\leftmargin\leftmarginiii 160 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 161 | \topsep 1pt plus 0.5pt minus 0.5pt 162 | \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt 163 | \itemsep \topsep} 164 | \def\@listiv{\leftmargin\leftmarginiv 165 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 166 | \def\@listv{\leftmargin\leftmarginv 167 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 168 | \def\@listvi{\leftmargin\leftmarginvi 169 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 170 | 171 | \abovedisplayskip 7pt plus2pt minus5pt% 172 | \belowdisplayskip \abovedisplayskip 173 | \abovedisplayshortskip 0pt plus3pt% 174 | \belowdisplayshortskip 4pt plus3pt minus3pt% 175 | 176 | % Less leading in most fonts (due to the narrow columns) 177 | % The choices were between 1-pt and 1.5-pt leading 178 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 179 | \def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 180 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 181 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 182 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 183 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 184 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 185 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 186 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 187 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 188 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 189 | 190 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 191 | 192 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 193 | .09in} % 194 | %Reduced second vskip to compensate for adding the strut in \@author 195 | 196 | 197 | %% % Vertical Ruler 198 | %% % This code is, 
largely, from the CVPR 2010 conference style file 199 | %% % ----- define vruler 200 | %% \makeatletter 201 | %% \newbox\iclrrulerbox 202 | %% \newcount\iclrrulercount 203 | %% \newdimen\iclrruleroffset 204 | %% \newdimen\cv@lineheight 205 | %% \newdimen\cv@boxheight 206 | %% \newbox\cv@tmpbox 207 | %% \newcount\cv@refno 208 | %% \newcount\cv@tot 209 | %% % NUMBER with left flushed zeros \fillzeros[] 210 | %% \newcount\cv@tmpc@ \newcount\cv@tmpc 211 | %% \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 212 | %% \cv@tmpc=1 % 213 | %% \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 214 | %% \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 215 | %% \ifnum#2<0\advance\cv@tmpc1\relax-\fi 216 | %% \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 217 | %% \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 218 | %% % \makevruler[][][][][] 219 | %% \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 220 | %% \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 221 | %% \global\setbox\iclrrulerbox=\vbox to \textheight{% 222 | %% {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 223 | %% \cv@lineheight=#1\global\iclrrulercount=#2% 224 | %% \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 225 | %% \cv@refno1\vskip-\cv@lineheight\vskip1ex% 226 | %% \loop\setbox\cv@tmpbox=\hbox to0cm{{\iclrtenhv\hfil\fillzeros[#4]\iclrrulercount}}% 227 | %% \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 228 | %% \advance\cv@refno1\global\advance\iclrrulercount#3\relax 229 | %% \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 230 | %% \makeatother 231 | %% % ----- end of vruler 232 | 233 | %% % \makevruler[][][][][] 234 | %% \def\iclrruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\iclrrulerbox}} 235 | %% \AddToShipoutPicture{% 236 | %% \ificlrfinal\else 237 | %% \iclrruleroffset=\textheight 238 | %% \advance\iclrruleroffset by -3.7pt 239 | %% \color[rgb]{.7,.7,.7} 240 | %% \AtTextUpperLeft{% 241 | %% \put(\LenToUnit{-35pt},\LenToUnit{-\iclrruleroffset}){%left ruler 242 | %% \iclrruler{\iclrrulercount}} 243 | %% } 244 | %% \fi 245 | %% } 246 | %%% To add a vertical bar on the side 247 | %\AddToShipoutPicture{ 248 | %\AtTextLowerLeft{ 249 | %\hspace*{-1.8cm} 250 | %\colorbox[rgb]{0.7,0.7,0.7}{\small \parbox[b][\textheight]{0.1cm}{}}} 251 | %} 252 | 253 | -------------------------------------------------------------------------------- /paper/nips_2016.sty: -------------------------------------------------------------------------------- 1 | % partial rewrite of the LaTeX2e package for submissions to the 2 | % Conference on Neural Information Processing Systems (NIPS): 3 | % 4 | % - uses more LaTeX conventions 5 | % - line numbers at submission time replaced with aligned numbers from 6 | % lineno package 7 | % - \nipsfinalcopy replaced with [final] package option 8 | % - automatically loads times package for authors 9 | % - loads natbib automatically; this can be suppressed with the 10 | % [nonatbib] package option 11 | % - adds foot line to first page identifying the conference 12 | % 13 | % Roman Garnett (garnett@wustl.edu) and the many authors of 14 | % nips15submit_e.sty, including MK and drstrip@sandia 15 | % 16 | % last revision: March 2016 17 | 18 | \NeedsTeXFormat{LaTeX2e} 19 | \ProvidesPackage{nips_2016}[2016/03/07 NIPS 2016 submission/camera-ready style file] 20 | 21 | % declare final option, which creates 
camera-ready copy 22 | \newif\if@nipsfinal\@nipsfinalfalse 23 | \DeclareOption{final}{ 24 | \@nipsfinaltrue 25 | } 26 | 27 | % declare nonatbib option, which does not load natbib in case of 28 | % package clash (users can pass options to natbib via 29 | % \PassOptionsToPackage) 30 | \newif\if@natbib\@natbibtrue 31 | \DeclareOption{nonatbib}{ 32 | \@natbibfalse 33 | } 34 | 35 | \ProcessOptions\relax 36 | 37 | % fonts 38 | \renewcommand{\rmdefault}{ptm} 39 | \renewcommand{\sfdefault}{phv} 40 | 41 | % change this every year for notice string at bottom 42 | \newcommand{\@nipsordinal}{29th} 43 | \newcommand{\@nipsyear}{2016} 44 | \newcommand{\@nipslocation}{Barcelona, Spain} 45 | 46 | % handle tweaks for camera-ready copy vs. submission copy 47 | \if@nipsfinal 48 | \newcommand{\@noticestring}{% 49 | \@nipsordinal\/ Conference on Neural Information Processing Systems 50 | (NIPS \@nipsyear), \@nipslocation.% 51 | } 52 | \else 53 | \newcommand{\@noticestring}{% 54 | Submitted to \@nipsordinal\/ Conference on Neural Information 55 | Processing Systems (NIPS \@nipsyear). Do not distribute.% 56 | } 57 | 58 | % line numbers for submission 59 | \RequirePackage{lineno} 60 | \linenumbers 61 | 62 | % fix incompatibilities between lineno and amsmath, if required, by 63 | % transparently wrapping linenomath environments around amsmath 64 | % environments 65 | \AtBeginDocument{% 66 | \@ifpackageloaded{amsmath}{% 67 | \newcommand*\patchAmsMathEnvironmentForLineno[1]{% 68 | \expandafter\let\csname old#1\expandafter\endcsname\csname #1\endcsname 69 | \expandafter\let\csname oldend#1\expandafter\endcsname\csname end#1\endcsname 70 | \renewenvironment{#1}% 71 | {\linenomath\csname old#1\endcsname}% 72 | {\csname oldend#1\endcsname\endlinenomath}% 73 | }% 74 | \newcommand*\patchBothAmsMathEnvironmentsForLineno[1]{% 75 | \patchAmsMathEnvironmentForLineno{#1}% 76 | \patchAmsMathEnvironmentForLineno{#1*}% 77 | }% 78 | \patchBothAmsMathEnvironmentsForLineno{equation}% 79 | \patchBothAmsMathEnvironmentsForLineno{align}% 80 | \patchBothAmsMathEnvironmentsForLineno{flalign}% 81 | \patchBothAmsMathEnvironmentsForLineno{alignat}% 82 | \patchBothAmsMathEnvironmentsForLineno{gather}% 83 | \patchBothAmsMathEnvironmentsForLineno{multline}% 84 | }{} 85 | } 86 | \fi 87 | 88 | % load natbib unless told otherwise 89 | \if@natbib 90 | \RequirePackage{natbib} 91 | \fi 92 | 93 | % set page geometry 94 | \usepackage[ 95 | letterpaper, 96 | textheight=9in, 97 | textwidth=5.5in, 98 | top=1in 99 | ]{geometry} 100 | 101 | \widowpenalty=10000 102 | \clubpenalty=10000 103 | \flushbottom 104 | \sloppy 105 | 106 | % font sizes with reduced leading 107 | \renewcommand{\normalsize}{% 108 | \@setfontsize\normalsize\@xpt\@xipt 109 | \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@ 110 | \abovedisplayshortskip \z@ \@plus 3\p@ 111 | \belowdisplayskip \abovedisplayskip 112 | \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@ 113 | } 114 | \normalsize 115 | \renewcommand{\small}{% 116 | \@setfontsize\small\@ixpt\@xpt 117 | \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@ 118 | \abovedisplayshortskip \z@ \@plus 2\p@ 119 | \belowdisplayskip \abovedisplayskip 120 | \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@ 121 | } 122 | \renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt} 123 | \renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt} 124 | \renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt} 125 | \renewcommand{\large}{\@setfontsize\large\@xiipt{14}} 126 | \renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}} 127 
| \renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}} 128 | \renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}} 129 | \renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}} 130 | 131 | % sections with less space 132 | \providecommand{\section}{} 133 | \renewcommand{\section}{% 134 | \@startsection{section}{1}{\z@}% 135 | {-2.0ex \@plus -0.5ex \@minus -0.2ex}% 136 | { 1.5ex \@plus 0.3ex \@minus 0.2ex}% 137 | {\large\bf\raggedright}% 138 | } 139 | \providecommand{\subsection}{} 140 | \renewcommand{\subsection}{% 141 | \@startsection{subsection}{2}{\z@}% 142 | {-1.8ex \@plus -0.5ex \@minus -0.2ex}% 143 | { 0.8ex \@plus 0.2ex}% 144 | {\normalsize\bf\raggedright}% 145 | } 146 | \providecommand{\subsubsection}{} 147 | \renewcommand{\subsubsection}{% 148 | \@startsection{subsubsection}{3}{\z@}% 149 | {-1.5ex \@plus -0.5ex \@minus -0.2ex}% 150 | { 0.5ex \@plus 0.2ex}% 151 | {\normalsize\bf\raggedright}% 152 | } 153 | \providecommand{\paragraph}{} 154 | \renewcommand{\paragraph}{% 155 | \@startsection{paragraph}{4}{\z@}% 156 | {1.5ex \@plus 0.5ex \@minus 0.2ex}% 157 | {-1em}% 158 | {\normalsize\bf}% 159 | } 160 | \providecommand{\subparagraph}{} 161 | \renewcommand{\subparagraph}{% 162 | \@startsection{subparagraph}{5}{\z@}% 163 | {1.5ex \@plus 0.5ex \@minus 0.2ex}% 164 | {-1em}% 165 | {\normalsize\bf}% 166 | } 167 | \providecommand{\subsubsubsection}{} 168 | \renewcommand{\subsubsubsection}{% 169 | \vskip5pt{\noindent\normalsize\rm\raggedright}% 170 | } 171 | 172 | % float placement 173 | \renewcommand{\topfraction }{0.85} 174 | \renewcommand{\bottomfraction }{0.4} 175 | \renewcommand{\textfraction }{0.1} 176 | \renewcommand{\floatpagefraction}{0.7} 177 | 178 | \newlength{\@nipsabovecaptionskip}\setlength{\@nipsabovecaptionskip}{7\p@} 179 | \newlength{\@nipsbelowcaptionskip}\setlength{\@nipsbelowcaptionskip}{\z@} 180 | 181 | \setlength{\abovecaptionskip}{\@nipsabovecaptionskip} 182 | \setlength{\belowcaptionskip}{\@nipsbelowcaptionskip} 183 | 184 | % swap above/belowcaptionskip lengths for tables 185 | \renewenvironment{table} 186 | {\setlength{\abovecaptionskip}{\@nipsbelowcaptionskip}% 187 | \setlength{\belowcaptionskip}{\@nipsabovecaptionskip}% 188 | \@float{table}} 189 | {\end@float} 190 | 191 | % footnote formatting 192 | \setlength{\footnotesep }{6.65\p@} 193 | \setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@} 194 | \renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@} 195 | \setcounter{footnote}{0} 196 | 197 | % paragraph formatting 198 | \setlength{\parindent}{\z@} 199 | \setlength{\parskip }{5.5\p@} 200 | 201 | % list formatting 202 | \setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@} 203 | \setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@} 204 | \setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@} 205 | \setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@} 206 | \setlength{\leftmargin }{3pc} 207 | \setlength{\leftmargini }{\leftmargin} 208 | \setlength{\leftmarginii }{2em} 209 | \setlength{\leftmarginiii}{1.5em} 210 | \setlength{\leftmarginiv }{1.0em} 211 | \setlength{\leftmarginv }{0.5em} 212 | \def\@listi {\leftmargin\leftmargini} 213 | \def\@listii {\leftmargin\leftmarginii 214 | \labelwidth\leftmarginii 215 | \advance\labelwidth-\labelsep 216 | \topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@ 217 | \parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@ 218 | \itemsep \parsep} 219 | \def\@listiii{\leftmargin\leftmarginiii 220 | \labelwidth\leftmarginiii 221 | \advance\labelwidth-\labelsep 222 | \topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@ 223 | \parsep \z@ 224 | \partopsep 
0.5\p@ \@plus 0\p@ \@minus 0.5\p@ 225 | \itemsep \topsep} 226 | \def\@listiv {\leftmargin\leftmarginiv 227 | \labelwidth\leftmarginiv 228 | \advance\labelwidth-\labelsep} 229 | \def\@listv {\leftmargin\leftmarginv 230 | \labelwidth\leftmarginv 231 | \advance\labelwidth-\labelsep} 232 | \def\@listvi {\leftmargin\leftmarginvi 233 | \labelwidth\leftmarginvi 234 | \advance\labelwidth-\labelsep} 235 | 236 | % create title 237 | \providecommand{\maketitle}{} 238 | \renewcommand{\maketitle}{% 239 | \par 240 | \begingroup 241 | \renewcommand{\thefootnote}{\fnsymbol{footnote}} 242 | % for perfect author name centering 243 | \renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}} 244 | % The footnote-mark was overlapping the footnote-text, 245 | % added the following to fix this problem (MK) 246 | \long\def\@makefntext##1{% 247 | \parindent 1em\noindent 248 | \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1 249 | } 250 | \thispagestyle{empty} 251 | \@maketitle 252 | \@thanks 253 | \@notice 254 | \endgroup 255 | \let\maketitle\relax 256 | \let\thanks\relax 257 | } 258 | 259 | % rules for title box at top of first page 260 | \newcommand{\@toptitlebar}{ 261 | \hrule height 4\p@ 262 | \vskip 0.25in 263 | \vskip -\parskip% 264 | } 265 | \newcommand{\@bottomtitlebar}{ 266 | \vskip 0.29in 267 | \vskip -\parskip 268 | \hrule height 1\p@ 269 | \vskip 0.09in% 270 | } 271 | 272 | % create title (includes both anonymized and non-anonymized versions) 273 | \providecommand{\@maketitle}{} 274 | \renewcommand{\@maketitle}{% 275 | \vbox{% 276 | \hsize\textwidth 277 | \linewidth\hsize 278 | \vskip 0.1in 279 | \@toptitlebar 280 | \centering 281 | {\LARGE\bf \@title\par} 282 | \@bottomtitlebar 283 | \if@nipsfinal 284 | \def\And{% 285 | \end{tabular}\hfil\linebreak[0]\hfil% 286 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces% 287 | } 288 | \def\AND{% 289 | \end{tabular}\hfil\linebreak[4]\hfil% 290 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces% 291 | } 292 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}% 293 | \else 294 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@} 295 | Anonymous Author(s) \\ 296 | Affiliation \\ 297 | Address \\ 298 | \texttt{email} \\ 299 | \end{tabular}% 300 | \fi 301 | \vskip 0.3in \@minus 0.1in 302 | } 303 | } 304 | 305 | % add conference notice to bottom of first page 306 | \newcommand{\ftype@noticebox}{8} 307 | \newcommand{\@notice}{% 308 | % give a bit of extra room back to authors on first page 309 | \enlargethispage{2\baselineskip}% 310 | \@float{noticebox}[b]% 311 | \footnotesize\@noticestring% 312 | \end@float% 313 | } 314 | 315 | % abstract styling 316 | \renewenvironment{abstract}% 317 | {% 318 | \vskip 0.075in% 319 | \centerline% 320 | {\large\bf Abstract}% 321 | \vspace{0.5ex}% 322 | \begin{quote}% 323 | } 324 | { 325 | \par% 326 | \end{quote}% 327 | \vskip 1ex% 328 | } 329 | 330 | \endinput 331 | -------------------------------------------------------------------------------- /paper/index.bib: -------------------------------------------------------------------------------- 1 | @article{batchnorm, 2 | author = {Sergey Ioffe and 3 | Christian Szegedy}, 4 | title = {Batch Normalization: Accelerating Deep Network Training by Reducing 5 | Internal Covariate Shift}, 6 | volume = {abs/1502.03167}, 7 | year = {2015}, 8 | timestamp = {Mon, 02 Mar 2015 14:17:34 +0100}, 9 | biburl = {http://dblp.uni-trier.de/rec/bib/journals/corr/IoffeS15}, 10 | bibsource = {dblp computer science bibliography, http://dblp.org} 11 | } 12 | 13 | @inproceedings{hessianfree, 14 
| title={Learning recurrent neural networks with hessian-free optimization}, 15 | author={Martens, J. and Sutskever, I.}, 16 | booktitle={ICML}, 17 | year={2011} 18 | } 19 | 20 | @article{ollivier, 21 | author = {Yann Ollivier}, 22 | title = {Persistent Contextual Neural Networks for learning symbolic data sequences}, 23 | journal = {CoRR}, 24 | volume = {abs/1306.0514}, 25 | year = {2013}, 26 | timestamp = {Mon, 01 Jul 2013 20:31:24 +0200}, 27 | biburl = {http://dblp.uni-trier.de/rec/bib/journals/corr/Ollivier13}, 28 | bibsource = {dblp computer science bibliography, http://dblp.org} 29 | } 30 | 31 | @article{KFAC, 32 | author = {James Martens and 33 | Roger B. Grosse}, 34 | title = {Optimizing Neural Networks with Kronecker-factored Approximate Curvature}, 35 | journal = {CoRR}, 36 | volume = {abs/1503.05671}, 37 | year = {2015}, 38 | url = {http://arxiv.org/abs/1503.05671}, 39 | timestamp = {Thu, 09 Apr 2015 11:33:20 +0200}, 40 | biburl = {http://dblp.uni-trier.de/rec/bib/journals/corr/MartensG15}, 41 | bibsource = {dblp computer science bibliography, http://dblp.org} 42 | } 43 | 44 | @incollection{efficientbackprop, 45 | title={Efficient backprop}, 46 | author={LeCun, Yann A and Bottou, L{\'e}on and Orr, Genevieve B and M{\"u}ller, Klaus-Robert}, 47 | booktitle={Neural networks: Tricks of the trade}, 48 | pages={9--48}, 49 | year={2012}, 50 | publisher={Springer} 51 | } 52 | 53 | @inproceedings{raiko, 54 | title={Deep learning made easier by linear transformations in perceptrons}, 55 | author={Raiko, Tapani and Valpola, Harri and LeCun, Yann}, 56 | booktitle={International Conference on Artificial Intelligence and Statistics}, 57 | pages={924--932}, 58 | year={2012} 59 | } 60 | 61 | @inproceedings{naturalneuralnetworks, 62 | title={Natural neural networks}, 63 | author={Desjardins, Guillaume and Simonyan, Karen and Pascanu, Razvan and others}, 64 | booktitle={Advances in Neural Information Processing Systems}, 65 | pages={2062--2070}, 66 | year={2015} 67 | } 68 | 69 | @article{baidu, 70 | title={Deep Speech 2: End-to-End Speech Recognition in English and Mandarin}, 71 | author={Amodei, D. and others}, 72 | journal={arXiv:1512.02595}, 73 | year={2015} 74 | } 75 | 76 | @article{cesar, 77 | title={Batch Normalized Recurrent Neural Networks}, 78 | author={Laurent, C. and Pereyra, G. and Brakel, P. and Zhang, Y. and Bengio, Y.}, 79 | journal={ICASSP}, 80 | year={2016} 81 | } 82 | 83 | @article{weightnorm, 84 | title={Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks}, 85 | author={Salimans, Tim and Kingma, Diederik P}, 86 | journal={arXiv:1602.07868}, 87 | year={2016} 88 | } 89 | 90 | @article{lstm, 91 | title={Long short-term memory}, 92 | author={Hochreiter, S. and Schmidhuber, J}, 93 | journal={Neural computation}, 94 | year={1997}, 95 | publisher={MIT Press} 96 | } 97 | 98 | @article{urnn, 99 | title={Unitary Evolution Recurrent Neural Networks}, 100 | author={Arjovsky, M. and Shah, A. 
and Bengio, Y.}, 101 | journal={arXiv:1511.06464}, 102 | year={2015} 103 | } 104 | 105 | @article{amari, 106 | title={Natural gradient works efficiently in learning}, 107 | author={Amari, Shun-Ichi}, 108 | journal={Neural computation}, 109 | volume={10}, 110 | number={2}, 111 | pages={251--276}, 112 | year={1998}, 113 | publisher={MIT Press} 114 | } 115 | 116 | @article{pascanudifficulty, 117 | title={On the difficulty of training recurrent neural networks}, 118 | author={Pascanu, Razvan and Mikolov, Tomas and Bengio, Yoshua}, 119 | journal={arXiv:1211.5063}, 120 | year={2012} 121 | } 122 | 123 | @article{krueger, 124 | title={Regularizing RNNs by Stabilizing Activations}, 125 | author={Krueger, D and Memisevic, R.}, 126 | journal={ICLR}, 127 | year={2016} 128 | } 129 | 130 | @misc{rmsprop, 131 | title={{Lecture 6.5---RmsProp: Divide the gradient by a running average of its recent magnitude}}, 132 | author={Tieleman, T. and Hinton, G.}, 133 | howpublished={COURSERA: Neural Networks for Machine Learning}, 134 | year={2012} 135 | } 136 | 137 | @article{penntreebank, 138 | author = {Marcus, M. P. and Marcinkiewicz, M. and Santorini, B.}, 139 | title = {Building a Large Annotated Corpus of English: The Penn Treebank}, 140 | journal = {Comput. Linguist.}, 141 | year = {1993}, 142 | publisher = {MIT Press}, 143 | address = {Cambridge, MA, USA}, 144 | } 145 | 146 | @article{mahoney2009large, 147 | title={Large text compression benchmark}, 148 | author={Mahoney, M.}, 149 | year={2009} 150 | } 151 | 152 | @inproceedings{attentivereader, 153 | title={Teaching machines to read and comprehend}, 154 | author={Hermann, K. M. and Kocisky, T. and Grefenstette, E. and Espeholt, L. and Kay, W. and Suleyman, M. and Blunsom, P.}, 155 | booktitle={NIPS}, 156 | year={2015} 157 | } 158 | 159 | @article{blocks, 160 | author = {Bart van Merri{\"{e}}nboer and 161 | Dzmitry Bahdanau and 162 | Vincent Dumoulin and 163 | Dmitriy Serdyuk and 164 | David Warde{-}Farley and 165 | Jan Chorowski and 166 | Yoshua Bengio}, 167 | title = {Blocks and Fuel: Frameworks for deep learning}, 168 | journal = {CoRR}, 169 | volume = {abs/1506.00619}, 170 | year = {2015}, 171 | url = {http://arxiv.org/abs/1506.00619}, 172 | timestamp = {Wed, 01 Jul 2015 15:10:24 +0200}, 173 | biburl = {http://dblp.uni-trier.de/rec/bib/journals/corr/MerrienboerBDSW15}, 174 | bibsource = {dblp computer science bibliography, http://dblp.org} 175 | } 176 | 177 | @article{le2015simple, 178 | title={A Simple Way to Initialize Recurrent Networks of Rectified Linear Units}, 179 | author={Le, Quoc V and Jaitly, N. and Hinton, G.}, 180 | journal={arXiv:1504.00941}, 181 | year={2015} 182 | } 183 | 184 | @article{zhang2016architectural, 185 | title={Architectural Complexity Measures of Recurrent Neural Networks}, 186 | author={Zhang, S. and Wu, Y. and Che, T. and Lin, Z. and Memisevic, R. and Salakhutdinov, R. and Bengio, Y.}, 187 | journal={arXiv:1602.08210}, 188 | year={2016} 189 | } 190 | 191 | 192 | @article{bahdanau2014neural, 193 | title={Neural machine translation by jointly learning to align and translate}, 194 | author={Bahdanau, D. and Cho, K. and Bengio, Y.}, 195 | journal={ICLR}, 196 | year={2015} 197 | } 198 | 199 | 200 | @article{xu2015show, 201 | title={Show, attend and tell: Neural image caption generation with visual attention}, 202 | author={Xu, K. and Ba, J. and Kiros, R. and Courville, A. and Salakhutdinov, R. and Zemel, R. 
and Bengio, Y.}, 203 | journal={arXiv:1502.03044}, 204 | year={2015} 205 | } 206 | 207 | @inproceedings{yao2015describing, 208 | title={Describing videos by exploiting temporal structure}, 209 | author={Yao, L. and Torabi, A. and Cho, K. and Ballas, N. and Pal, C. and Larochelle, H. and Courville, A.}, 210 | booktitle={ICCV}, 211 | year={2015} 212 | } 213 | 214 | @article{bengio1994learning, 215 | title={Learning long-term dependencies with gradient descent is difficult}, 216 | author={Bengio, Y. and Simard, P. and Frasconi, P.}, 217 | journal={Neural Networks, IEEE Transactions on}, 218 | year={1994}, 219 | publisher={IEEE} 220 | } 221 | 222 | 223 | @article{hochreiter1991untersuchungen, 224 | title={Untersuchungen zu dynamischen neuronalen Netzen}, 225 | author={Hochreiter, S.}, 226 | journal={Master's thesis}, 227 | year={1991} 228 | } 229 | 230 | @article{cho2014learning, 231 | title={Learning phrase representations using rnn encoder-decoder for statistical machine translation}, 232 | author={Cho, K. and Van Merri{\"e}nboer, B. and Gulcehre, C. and Bahdanau, D. and Bougares, F. and Schwenk, H. and Bengio, Y.}, 233 | journal={arXiv:1406.1078}, 234 | year={2014} 235 | } 236 | 237 | @article{rumelhart1988learning, 238 | title={Learning representations by back-propagating errors}, 239 | author={Rumelhart, David E and Hinton, Geoffrey E and Williams, Ronald J}, 240 | journal={Cognitive modeling}, 241 | year={1988} 242 | } 243 | 244 | @article{shimodaira2000improving, 245 | title={Improving predictive inference under covariate shift by weighting the log-likelihood function}, 246 | author={Shimodaira, H.}, 247 | journal={Journal of statistical planning and inference}, 248 | year={2000}, 249 | publisher={Elsevier} 250 | } 251 | 252 | @article{mikolov2012subword, 253 | title={Subword language modeling with neural networks}, 254 | author={Mikolov, T. and Sutskever, I. and Deoras, A. and Le, H. and Kombrink, S. and Cernocky, J.}, 255 | journal={preprint}, 256 | year={2012} 257 | } 258 | 259 | @article{graves2013generating, 260 | title={Generating sequences with recurrent neural networks}, 261 | author={Graves, A.}, 262 | journal={arXiv:1308.0850}, 263 | year={2013} 264 | } 265 | 266 | @article{pachitariu2013regularization, 267 | title={Regularization and nonlinearities for neural language models: when are they needed?}, 268 | author={Pachitariu, Marius and Sahani, Maneesh}, 269 | journal={arXiv:1301.5650}, 270 | year={2013} 271 | } 272 | 273 | @article{kingma2014adam, 274 | title={Adam: A method for stochastic optimization}, 275 | author={Kingma, D. 
and Ba, J.}, 276 | journal={arXiv:1412.6980}, 277 | year={2014} 278 | } 279 | 280 | @ARTICLE{theano, 281 | author = { Team, The Theano Development and others }, 282 | collaboration = {Theano Development Team}, 283 | title = "{Theano: A {Python} framework for fast computation of mathematical expressions}", 284 | journal = {arXiv e-prints}, 285 | volume = {abs/1605.02688}, 286 | primaryClass = "cs.SC", 287 | keywords = {Computer Science - Symbolic Computation, Computer Science - Learning, Computer Science - Mathematical Software}, 288 | year = 2016, 289 | month = may, 290 | } 291 | 292 | @article{krueger2016zoneout, 293 | title={Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations}, 294 | author={Krueger, David and Maharaj, Tegan and Kram{\'a}r, J{\'a}nos and Pezeshki, Mohammad and Ballas, Nicolas and Ke, Nan Rosemary and Goyal, Anirudh and Bengio, Yoshua and Larochelle, Hugo and Courville, Aaron}, 295 | journal={arXiv:1606.01305}, 296 | year={2016} 297 | } 298 | 299 | @article{chung2016hierarchical, 300 | title={Hierarchical Multiscale Recurrent Neural Networks}, 301 | author={Chung, Junyoung and Ahn, Sungjin and Bengio, Yoshua}, 302 | journal={arXiv:1609.01704}, 303 | year={2016} 304 | } 305 | 306 | @article{ha2016hypernetworks, 307 | title={Hypernetworks}, 308 | author={Ha, David and Dai, Andrew and Le, Quoc V}, 309 | journal={arXiv:1609.09106}, 310 | year={2016} 311 | } 312 | 313 | @article{ba2016layer, 314 | title={Layer normalization}, 315 | author={Ba, Jimmy Lei and Kiros, Jamie Ryan and Hinton, Geoffrey E}, 316 | journal={arXiv:1607.06450}, 317 | year={2016} 318 | } 319 | 320 | @article{liao2016bridging, 321 | title={Bridging the Gaps Between Residual Learning, Recurrent Neural Networks and Visual Cortex}, 322 | author={Liao, Qianli and Poggio, Tomaso}, 323 | journal={arXiv:1604.03640}, 324 | year={2016} 325 | } 326 | 327 | -------------------------------------------------------------------------------- /sequential_mnist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | from collections import OrderedDict 4 | import numpy as np 5 | import theano, theano.tensor as T 6 | from theano.sandbox.rng_mrg import MRG_RandomStreams 7 | import blocks.config 8 | import fuel.datasets, fuel.streams, fuel.transformers, fuel.schemes 9 | 10 | ### optimization algorithm definition 11 | from blocks.graph import ComputationGraph 12 | from blocks.algorithms import GradientDescent, RMSProp, StepClipping, CompositeRule, Momentum 13 | from blocks.model import Model 14 | from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing 15 | from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring 16 | from blocks.extensions.stopping import FinishIfNoImprovementAfter 17 | from blocks.extensions.training import TrackTheBest 18 | from blocks.extensions.saveload import Checkpoint 19 | from extensions import DumpLog, DumpBest, PrintingTo, DumpVariables 20 | from blocks.main_loop import MainLoop 21 | from blocks.utils import shared_floatx_zeros 22 | from blocks.roles import add_role, PARAMETER 23 | 24 | import util 25 | 26 | 27 | logging.basicConfig() 28 | logger = logging.getLogger(__name__) 29 | 30 | def zeros(shape): 31 | return np.zeros(shape, dtype=theano.config.floatX) 32 | 33 | def ones(shape): 34 | return np.ones(shape, dtype=theano.config.floatX) 35 | 36 | def glorot(shape): 37 | d = np.sqrt(6. 
| / sum(shape))
38 | return np.random.uniform(-d, +d, size=shape).astype(theano.config.floatX)
39 |
40 | def orthogonal(shape):
41 | # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
42 | """ benanne lasagne ortho init (faster than qr approach)"""
43 | flat_shape = (shape[0], np.prod(shape[1:]))
44 | a = np.random.normal(0.0, 1.0, flat_shape)
45 | u, _, v = np.linalg.svd(a, full_matrices=False)
46 | q = u if u.shape == flat_shape else v # pick the one with the correct shape
47 | q = q.reshape(shape)
48 | return q[:shape[0], :shape[1]].astype(theano.config.floatX)
49 |
50 | _datasets = None
51 | def get_dataset(which_set):
52 | global _datasets
53 | if not _datasets:
54 | MNIST = fuel.datasets.MNIST
55 | # jump through hoops to instantiate only once and only if needed
56 | _datasets = dict(
57 | train=MNIST(which_sets=["train"], subset=slice(None, 50000)),
58 | valid=MNIST(which_sets=["train"], subset=slice(50000, None)),
59 | test=MNIST(which_sets=["test"]))
60 | return _datasets[which_set]
61 |
62 | def get_stream(which_set, batch_size, num_examples=None):
63 | dataset = get_dataset(which_set)
64 | if num_examples is None or num_examples > dataset.num_examples:
65 | num_examples = dataset.num_examples
66 | stream = fuel.streams.DataStream.default_stream(
67 | dataset,
68 | iteration_scheme=fuel.schemes.ShuffledScheme(num_examples, batch_size))
69 | return stream
70 |
71 |
72 | def bn(x, gammas, betas, mean, var, args):
73 | assert mean.ndim == 1
74 | assert var.ndim == 1
75 | assert x.ndim == 2
76 | if not args.use_population_statistics:
77 | mean = x.mean(axis=0)
78 | var = x.var(axis=0)
79 | #var = T.maximum(var, args.epsilon)
80 | #var = var + args.epsilon
81 |
82 | if args.baseline:
83 | y = x + betas
84 | else:
85 | var_corrected = var + args.epsilon
86 |
87 | y = theano.tensor.nnet.bn.batch_normalization(
88 | inputs=x, gamma=gammas, beta=betas,
89 | mean=T.shape_padleft(mean), std=T.shape_padleft(T.sqrt(var_corrected)),
90 | mode="high_mem")
91 | assert mean.ndim == 1
92 | assert var.ndim == 1
93 | return y, mean, var
94 |
95 | activations = dict(
96 | tanh=T.tanh,
97 | identity=lambda x: x,
98 | relu=lambda x: T.maximum(0, x))
99 |
100 |
101 | class Empty(object):
102 | pass
103 |
104 | class LSTM(object):
105 | def __init__(self, args, nclasses):
106 | self.nclasses = nclasses
107 | self.activation = activations[args.activation]
108 |
109 | def allocate_parameters(self, args):
110 | if hasattr(self, "parameters"):
111 | return self.parameters
112 |
113 | self.parameters = Empty()
114 |
115 | h0 = theano.shared(zeros((args.num_hidden,)), name="h0")
116 | c0 = theano.shared(zeros((args.num_hidden,)), name="c0")
117 | if args.init == "id":
118 | Wa = theano.shared(np.concatenate([
119 | np.eye(args.num_hidden),
120 | orthogonal((args.num_hidden,
121 | 3 * args.num_hidden)),], axis=1).astype(theano.config.floatX), name="Wa")
122 | else:
123 | Wa = theano.shared(orthogonal((args.num_hidden, 4 * args.num_hidden)), name="Wa")
124 | Wx = theano.shared(orthogonal((1, 4 * args.num_hidden)), name="Wx")
125 | a_gammas = theano.shared(args.initial_gamma * ones((4 * args.num_hidden,)), name="a_gammas")
126 | b_gammas = theano.shared(args.initial_gamma * ones((4 * args.num_hidden,)), name="b_gammas")
127 | ab_betas = theano.shared(args.initial_beta * ones((4 * args.num_hidden,)), name="ab_betas")
128 |
129 | # forget gate bias initialization
130 | forget_bias = ab_betas.get_value()
131 | forget_bias[args.num_hidden:2*args.num_hidden] = 1.
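# In the [g, f, i, o] gate ordering used by stepfn below, the slice
# [num_hidden:2*num_hidden] of ab_betas corresponds to the forget gate.
# Starting the forget-gate bias at 1 keeps the gate open early in training,
# a common LSTM trick that helps gradients propagate across long sequences.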
132 | ab_betas.set_value(forget_bias)
133 |
134 | c_gammas = theano.shared(args.initial_gamma * ones((args.num_hidden,)), name="c_gammas")
135 | c_betas = theano.shared(args.initial_beta * ones((args.num_hidden,)), name="c_betas")
136 |
137 | if not args.baseline:
138 | parameters_list = [h0, c0, Wa, Wx, a_gammas, b_gammas, ab_betas, c_gammas, c_betas]
139 | else:
140 | parameters_list = [h0, c0, Wa, Wx, ab_betas, c_betas]
141 | for parameter in parameters_list:
142 | print parameter.name
143 | add_role(parameter, PARAMETER)
144 | setattr(self.parameters, parameter.name, parameter)
145 |
146 | return self.parameters
147 |
148 |
149 | def construct_graph_ref(self, args, x, length, popstats=None):
150 |
151 | p = self.allocate_parameters(args)
152 |
153 | if args.baseline:
154 | def bn(x, gammas, betas):
155 | return x + betas
156 | else:
157 | def bn(x, gammas, betas):
158 | mean, var = x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True)
159 | # mark the batch statistics so they can be identified in the graph
160 | mean.tag.batchstat, var.tag.batchstat = True, True
161 | #var = T.maximum(var, args.epsilon)
162 | var = var + args.epsilon
163 | return (x - mean) / T.sqrt(var) * gammas + betas
164 |
165 | def stepfn(x, dummy_h, dummy_c, h, c):
166 | # a_mean, b_mean, c_mean,
167 | # a_var, b_var, c_var):
168 |
169 | a_mean, b_mean, c_mean = 0, 0, 0
170 | a_var, b_var, c_var = 0, 0, 0
171 |
172 | atilde = T.dot(h, p.Wa)
173 | btilde = x
174 | a_normal = bn(atilde, p.a_gammas, p.ab_betas)
175 | b_normal = bn(btilde, p.b_gammas, 0)
176 | ab = a_normal + b_normal
177 | g, f, i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden])
178 | for j, fn in enumerate([self.activation] + 3 * [T.nnet.sigmoid])]
179 | c = dummy_c + f * c + i * g
180 | c_normal = bn(c, p.c_gammas, p.c_betas)
181 | h = dummy_h + o * self.activation(c_normal)
182 | return h, c, atilde, btilde, c_normal
183 |
184 |
185 |
186 | xtilde = T.dot(x, p.Wx)
187 |
188 | if args.noise:
189 | # prime h with white noise
190 | Trng = MRG_RandomStreams()
191 | h_prime = Trng.normal((xtilde.shape[1], args.num_hidden), std=args.noise)
192 | elif args.summarize:
193 | # prime h with mean of example
194 | h_prime = x.mean(axis=[0, 2])[:, None]
195 | else:
196 | h_prime = 0
197 |
198 | dummy_states = dict(h=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)),
199 | c=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)))
200 |
201 | [h, c, atilde, btilde, htilde], _ = theano.scan(
202 | stepfn,
203 | sequences=[xtilde, dummy_states["h"], dummy_states["c"]],
204 | outputs_info=[T.repeat(p.h0[None, :], xtilde.shape[1], axis=0) + h_prime,
205 | T.repeat(p.c0[None, :], xtilde.shape[1], axis=0),
206 | None, None, None])
207 | return dict(h=h, c=c,
208 | atilde=atilde, btilde=btilde, htilde=htilde), [], dummy_states, popstats
209 |
210 | def construct_graph_popstats(self, args, x, length, popstats=None):
211 | p = self.allocate_parameters(args)
212 |
213 |
214 | def stepfn(x, dummy_h, dummy_c,
215 | pop_means_a, pop_means_b, pop_means_c,
216 | pop_vars_a, pop_vars_b, pop_vars_c,
217 | h, c):
218 |
219 | atilde = T.dot(h, p.Wa)
220 | btilde = x
221 | if args.baseline:
222 | a_normal, a_mean, a_var = bn(atilde, 1.0, p.ab_betas, pop_means_a, pop_vars_a, args)
223 | b_normal, b_mean, b_var = bn(btilde, 1.0, 0, pop_means_b, pop_vars_b, args)
224 | else:
225 | a_normal, a_mean, a_var = bn(atilde, p.a_gammas, p.ab_betas, pop_means_a, pop_vars_a, args)
226 | b_normal, b_mean, b_var = bn(btilde, p.b_gammas, 0, pop_means_b, pop_vars_b, args)
227 | ab = a_normal + b_normal
228 | g, f,
i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden]) 229 | for j, fn in enumerate([self.activation] + 3 * [T.nnet.sigmoid])] 230 | c = dummy_c + f * c + i * g 231 | if args.baseline: 232 | c_normal, c_mean, c_var = bn(c, 1.0, p.c_betas, pop_means_c, pop_vars_c, args) 233 | else: 234 | c_normal, c_mean, c_var = bn(c, p.c_gammas, p.c_betas, pop_means_c, pop_vars_c, args) 235 | h = dummy_h + o * self.activation(c_normal) 236 | return (h, c, atilde, btilde, c_normal, 237 | a_mean, b_mean, c_mean, 238 | a_var, b_var, c_var) 239 | 240 | 241 | xtilde = T.dot(x, p.Wx) 242 | if args.noise: 243 | # prime h with white noise 244 | Trng = MRG_RandomStreams() 245 | h_prime = Trng.normal((xtilde.shape[1], args.num_hidden), std=args.noise) 246 | elif args.summarize: 247 | # prime h with mean of example 248 | h_prime = x.mean(axis=[0, 2])[:, None] 249 | else: 250 | h_prime = 0 251 | 252 | dummy_states = dict(h=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)), 253 | c=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden))) 254 | 255 | if popstats is None: 256 | popstats = OrderedDict() 257 | for key, size in zip("abc", [4*args.num_hidden, 4*args.num_hidden, args.num_hidden]): 258 | for stat, init in zip("mean var".split(), [0, 1]): 259 | name = "%s_%s" % (key, stat) 260 | popstats[name] = theano.shared( 261 | init + np.zeros((length, size,), dtype=theano.config.floatX), 262 | name=name) 263 | popstats_seq = [popstats['a_mean'], popstats['b_mean'], popstats['c_mean'], 264 | popstats['a_var'], popstats['b_var'], popstats['c_var']] 265 | 266 | [h, c, atilde, btilde, htilde, 267 | batch_mean_a, batch_mean_b, batch_mean_c, 268 | batch_var_a, batch_var_b, batch_var_c ], _ = theano.scan( 269 | stepfn, 270 | sequences=[xtilde, dummy_states["h"], dummy_states["c"]] + popstats_seq, 271 | outputs_info=[T.repeat(p.h0[None, :], xtilde.shape[1], axis=0) + h_prime, 272 | T.repeat(p.c0[None, :], xtilde.shape[1], axis=0), 273 | None, None, None, 274 | None, None, None, 275 | None, None, None]) 276 | 277 | batchstats = OrderedDict() 278 | batchstats['a_mean'] = batch_mean_a 279 | batchstats['b_mean'] = batch_mean_b 280 | batchstats['c_mean'] = batch_mean_c 281 | batchstats['a_var'] = batch_var_a 282 | batchstats['b_var'] = batch_var_b 283 | batchstats['c_var'] = batch_var_c 284 | 285 | updates = OrderedDict() 286 | if not args.use_population_statistics: 287 | alpha = 1e-2 288 | for key in "abc": 289 | for stat, init in zip("mean var".split(), [0, 1]): 290 | name = "%s_%s" % (key, stat) 291 | popstats[name].tag.estimand = batchstats[name] 292 | updates[popstats[name]] = (alpha * batchstats[name] + 293 | (1 - alpha) * popstats[name]) 294 | return dict(h=h, c=c, 295 | atilde=atilde, btilde=btilde, htilde=htilde), updates, dummy_states, popstats 296 | 297 | 298 | def construct_common_graph(situation, args, outputs, dummy_states, Wy, by, y): 299 | ytilde = T.dot(outputs["h"][-1], Wy) + by 300 | yhat = T.nnet.softmax(ytilde) 301 | 302 | errors = T.neq(y, T.argmax(yhat, axis=1)) 303 | cross_entropies = T.nnet.categorical_crossentropy(yhat, y) 304 | 305 | error_rate = errors.mean().copy(name="error_rate") 306 | cross_entropy = cross_entropies.mean().copy(name="cross_entropy") 307 | cost = cross_entropy.copy(name="cost") 308 | graph = ComputationGraph([cost, cross_entropy, error_rate]) 309 | 310 | state_grads = dict((k, T.grad(cost, v)) for k, v in dummy_states.items()) 311 | 312 | extensions = [] 313 | # extensions = [ 314 | # DumpVariables("%s_hiddens" % situation, graph.inputs, 315 | # 
[v.copy(name="%s%s" % (k, suffix))
316 | # for suffix, things in [("", outputs), ("_grad", state_grads)]
317 | # for k, v in things.items()],
318 | # batch=next(get_stream(which_set="train",
319 | # batch_size=args.batch_size,
320 | # num_examples=args.batch_size)
321 | # .get_epoch_iterator(as_dict=True)),
322 | # before_training=True, every_n_epochs=10)]
323 |
324 | return graph, extensions
325 |
326 | def construct_graphs(args, nclasses, length):
327 | constructor = LSTM if args.lstm else RNN # NB: only the LSTM variant is defined in this script
328 |
329 | if args.permuted:
330 | permutation = np.random.permutation(length) # a fixed random ordering of the pixels
331 |
332 | Wy = theano.shared(orthogonal((args.num_hidden, nclasses)), name="Wy")
333 | by = theano.shared(np.zeros((nclasses,), dtype=theano.config.floatX), name="by")
334 |
335 | ### graph construction
336 | inputs = dict(features=T.tensor4("features"), targets=T.imatrix("targets"))
337 | x, y = inputs["features"], inputs["targets"]
338 |
339 | theano.config.compute_test_value = "warn"
340 | batch = next(get_stream(which_set="train", batch_size=args.batch_size).get_epoch_iterator())
341 | x.tag.test_value = batch[0]
342 | y.tag.test_value = batch[1]
343 |
344 | x = x.reshape((x.shape[0], length, 1))
345 | y = y.flatten(ndim=1)
346 | x = x.dimshuffle(1, 0, 2)
347 | x = x[0:, :, :]
348 |
349 | if args.permuted:
350 | x = x[permutation]
351 |
352 | args.use_population_statistics = False
353 | turd = constructor(args, nclasses)
354 | (outputs, training_updates, dummy_states, popstats) = turd.construct_graph_popstats(args, x, length)
355 | training_graph, training_extensions = construct_common_graph("training", args, outputs, dummy_states, Wy, by, y)
356 |
357 | args.use_population_statistics = True
358 | (inf_outputs, inference_updates, dummy_states, _) = turd.construct_graph_popstats(args, x, length, popstats=popstats)
359 | inference_graph, inference_extensions = construct_common_graph("inference", args, inf_outputs, dummy_states, Wy, by, y)
360 |
361 | add_role(Wy, PARAMETER)
362 | add_role(by, PARAMETER)
363 | args.use_population_statistics = False
364 | return (dict(training=training_graph, inference=inference_graph),
365 | dict(training=training_extensions, inference=inference_extensions),
366 | dict(training=training_updates, inference=inference_updates))
367 |
368 | if __name__ == "__main__":
369 | sequence_length = 784
370 | nclasses = 10
371 |
372 | import argparse
373 | parser = argparse.ArgumentParser()
374 | parser.add_argument("--seed", type=int, default=1)
375 | parser.add_argument("--num-epochs", type=int, default=100)
376 | parser.add_argument("--learning-rate", type=float, default=1e-4)
377 | parser.add_argument("--epsilon", type=float, default=1e-5)
378 | parser.add_argument("--batch-size", type=int, default=100)
379 | parser.add_argument("--noise", type=float, default=None)
380 | parser.add_argument("--summarize", action="store_true")
381 | parser.add_argument("--num-hidden", type=int, default=100)
382 | parser.add_argument("--baseline", action="store_true")
383 | parser.add_argument("--lstm", action="store_true")
384 | parser.add_argument("--initial-gamma", type=float, default=0.1)
385 | parser.add_argument("--initial-beta", type=float, default=0)
386 | parser.add_argument("--cluster", action="store_true")
387 | parser.add_argument("--activation", choices=list(activations.keys()), default="tanh")
388 | parser.add_argument("--init", type=str, default="ortho")
389 | parser.add_argument("--continue-from")
390 | parser.add_argument("--permuted", action="store_true")
391 | args =
parser.parse_args() 392 | 393 | #assert not (args.noise and args.summarize) 394 | np.random.seed(args.seed) 395 | blocks.config.config.default_seed = args.seed 396 | 397 | 398 | if args.continue_from: 399 | from blocks.serialization import load 400 | main_loop = load(args.continue_from) 401 | main_loop.run() 402 | sys.exit(0) 403 | 404 | graphs, extensions, updates = construct_graphs(args, nclasses, sequence_length) 405 | 406 | ### optimization algorithm definition 407 | step_rule = CompositeRule([ 408 | StepClipping(1.), 409 | #Momentum(learning_rate=args.learning_rate, momentum=0.9), 410 | RMSProp(learning_rate=args.learning_rate, decay_rate=0.5), 411 | ]) 412 | 413 | algorithm = GradientDescent(cost=graphs["training"].outputs[0], 414 | parameters=graphs["training"].parameters, 415 | step_rule=step_rule) 416 | algorithm.add_updates(updates["training"]) 417 | model = Model(graphs["training"].outputs[0]) 418 | extensions = extensions["training"] + extensions["inference"] 419 | 420 | 421 | # step monitor (after epoch to limit the log size) 422 | step_channels = [] 423 | step_channels.extend([ 424 | algorithm.steps[param].norm(2).copy(name="step_norm:%s" % name) 425 | for name, param in model.get_parameter_dict().items()]) 426 | step_channels.append(algorithm.total_step_norm.copy(name="total_step_norm")) 427 | step_channels.append(algorithm.total_gradient_norm.copy(name="total_gradient_norm")) 428 | step_channels.extend(graphs["training"].outputs) 429 | logger.warning("constructing training data monitor") 430 | extensions.append(TrainingDataMonitoring( 431 | step_channels, prefix="iteration", after_batch=False)) 432 | 433 | # parameter monitor 434 | extensions.append(DataStreamMonitoring( 435 | [param.norm(2).copy(name="parameter.norm:%s" % name) 436 | for name, param in model.get_parameter_dict().items()], 437 | data_stream=None, after_epoch=True)) 438 | 439 | # performance monitor 440 | for situation in "training".split(): # add inference 441 | for which_set in "train valid test".split(): 442 | logger.warning("constructing %s %s monitor" % (which_set, situation)) 443 | channels = list(graphs[situation].outputs) 444 | extensions.append(DataStreamMonitoring( 445 | channels, 446 | prefix="%s_%s" % (which_set, situation), after_epoch=True, 447 | data_stream=get_stream(which_set=which_set, batch_size=args.batch_size)))#, num_examples=1000))) 448 | for situation in "inference".split(): # add inference 449 | for which_set in "valid test".split(): 450 | logger.warning("constructing %s %s monitor" % (which_set, situation)) 451 | channels = list(graphs[situation].outputs) 452 | extensions.append(DataStreamMonitoring( 453 | channels, 454 | prefix="%s_%s" % (which_set, situation), after_epoch=True, 455 | data_stream=get_stream(which_set=which_set, batch_size=args.batch_size)))#, num_examples=1000))) 456 | 457 | extensions.extend([ 458 | TrackTheBest("valid_training_error_rate", "best_valid_training_error_rate"), 459 | DumpBest("best_valid_training_error_rate", "best.zip"), 460 | FinishAfter(after_n_epochs=args.num_epochs), 461 | #FinishIfNoImprovementAfter("best_valid_error_rate", epochs=50), 462 | Checkpoint("checkpoint.zip", on_interrupt=False, every_n_epochs=1, use_cpickle=True), 463 | DumpLog("log.pkl", after_epoch=True)]) 464 | 465 | if not args.cluster: 466 | extensions.append(ProgressBar()) 467 | 468 | extensions.extend([ 469 | Timing(), 470 | Printing(), 471 | PrintingTo("log"), 472 | ]) 473 | main_loop = MainLoop( 474 | data_stream=get_stream(which_set="train", batch_size=args.batch_size), 
475 | algorithm=algorithm, extensions=extensions, model=model) 476 | main_loop.run() 477 | -------------------------------------------------------------------------------- /paper/fancyhdr.sty: -------------------------------------------------------------------------------- 1 | % fancyhdr.sty version 3.2 2 | % Fancy headers and footers for LaTeX. 3 | % Piet van Oostrum, 4 | % Dept of Computer and Information Sciences, University of Utrecht, 5 | % Padualaan 14, P.O. Box 80.089, 3508 TB Utrecht, The Netherlands 6 | % Telephone: +31 30 2532180. Email: piet@cs.uu.nl 7 | % ======================================================================== 8 | % LICENCE: 9 | % This file may be distributed under the terms of the LaTeX Project Public 10 | % License, as described in lppl.txt in the base LaTeX distribution. 11 | % Either version 1 or, at your option, any later version. 12 | % ======================================================================== 13 | % MODIFICATION HISTORY: 14 | % Sep 16, 1994 15 | % version 1.4: Correction for use with \reversemargin 16 | % Sep 29, 1994: 17 | % version 1.5: Added the \iftopfloat, \ifbotfloat and \iffloatpage commands 18 | % Oct 4, 1994: 19 | % version 1.6: Reset single spacing in headers/footers for use with 20 | % setspace.sty or doublespace.sty 21 | % Oct 4, 1994: 22 | % version 1.7: changed \let\@mkboth\markboth to 23 | % \def\@mkboth{\protect\markboth} to make it more robust 24 | % Dec 5, 1994: 25 | % version 1.8: corrections for amsbook/amsart: define \@chapapp and (more 26 | % importantly) use the \chapter/sectionmark definitions from ps@headings if 27 | % they exist (which should be true for all standard classes). 28 | % May 31, 1995: 29 | % version 1.9: The proposed \renewcommand{\headrulewidth}{\iffloatpage... 30 | % construction in the doc did not work properly with the fancyplain style. 31 | % June 1, 1995: 32 | % version 1.91: The definition of \@mkboth wasn't restored on subsequent 33 | % \pagestyle{fancy}'s. 34 | % June 1, 1995: 35 | % version 1.92: The sequence \pagestyle{fancyplain} \pagestyle{plain} 36 | % \pagestyle{fancy} would erroneously select the plain version. 37 | % June 1, 1995: 38 | % version 1.93: \fancypagestyle command added. 39 | % Dec 11, 1995: 40 | % version 1.94: suggested by Conrad Hughes 41 | % CJCH, Dec 11, 1995: added \footruleskip to allow control over footrule 42 | % position (old hardcoded value of .3\normalbaselineskip is far too high 43 | % when used with very small footer fonts). 44 | % Jan 31, 1996: 45 | % version 1.95: call \@normalsize in the reset code if that is defined, 46 | % otherwise \normalsize. 47 | % this is to solve a problem with ucthesis.cls, as this doesn't 48 | % define \@currsize. Unfortunately for latex209 calling \normalsize doesn't 49 | % work as this is optimized to do very little, so there \@normalsize should 50 | % be called. Hopefully this code works for all versions of LaTeX known to 51 | % mankind. 52 | % April 25, 1996: 53 | % version 1.96: initialize \headwidth to a magic (negative) value to catch 54 | % most common cases that people change it before calling \pagestyle{fancy}. 55 | % Note it can't be initialized when reading in this file, because 56 | % \textwidth could be changed afterwards. This is quite probable. 57 | % We also switch to \MakeUppercase rather than \uppercase and introduce a 58 | % \nouppercase command for use in headers. and footers. 59 | % May 3, 1996: 60 | % version 1.97: Two changes: 61 | % 1. 
Undo the change in version 1.8 (using the pagestyle{headings} defaults 62 | % for the chapter and section marks. The current version of amsbook and 63 | % amsart classes don't seem to need them anymore. Moreover the standard 64 | % latex classes don't use \markboth if twoside isn't selected, and this is 65 | % confusing as \leftmark doesn't work as expected. 66 | % 2. include a call to \ps@empty in ps@@fancy. This is to solve a problem 67 | % in the amsbook and amsart classes, that make global changes to \topskip, 68 | % which are reset in \ps@empty. Hopefully this doesn't break other things. 69 | % May 7, 1996: 70 | % version 1.98: 71 | % Added % after the line \def\nouppercase 72 | % May 7, 1996: 73 | % version 1.99: This is the alpha version of fancyhdr 2.0 74 | % Introduced the new commands \fancyhead, \fancyfoot, and \fancyhf. 75 | % Changed \headrulewidth, \footrulewidth, \footruleskip to 76 | % macros rather than length parameters, In this way they can be 77 | % conditionalized and they don't consume length registers. There is no need 78 | % to have them as length registers unless you want to do calculations with 79 | % them, which is unlikely. Note that this may make some uses of them 80 | % incompatible (i.e. if you have a file that uses \setlength or \xxxx=) 81 | % May 10, 1996: 82 | % version 1.99a: 83 | % Added a few more % signs 84 | % May 10, 1996: 85 | % version 1.99b: 86 | % Changed the syntax of \f@nfor to be resistent to catcode changes of := 87 | % Removed the [1] from the defs of \lhead etc. because the parameter is 88 | % consumed by the \@[xy]lhead etc. macros. 89 | % June 24, 1997: 90 | % version 1.99c: 91 | % corrected \nouppercase to also include the protected form of \MakeUppercase 92 | % \global added to manipulation of \headwidth. 93 | % \iffootnote command added. 94 | % Some comments added about \@fancyhead and \@fancyfoot. 95 | % Aug 24, 1998 96 | % version 1.99d 97 | % Changed the default \ps@empty to \ps@@empty in order to allow 98 | % \fancypagestyle{empty} redefinition. 99 | % Oct 11, 2000 100 | % version 2.0 101 | % Added LPPL license clause. 102 | % 103 | % A check for \headheight is added. An errormessage is given (once) if the 104 | % header is too large. Empty headers don't generate the error even if 105 | % \headheight is very small or even 0pt. 106 | % Warning added for the use of 'E' option when twoside option is not used. 107 | % In this case the 'E' fields will never be used. 108 | % 109 | % Mar 10, 2002 110 | % version 2.1beta 111 | % New command: \fancyhfoffset[place]{length} 112 | % defines offsets to be applied to the header/footer to let it stick into 113 | % the margins (if length > 0). 114 | % place is like in fancyhead, except that only E,O,L,R can be used. 115 | % This replaces the old calculation based on \headwidth and the marginpar 116 | % area. 117 | % \headwidth will be dynamically calculated in the headers/footers when 118 | % this is used. 119 | % 120 | % Mar 26, 2002 121 | % version 2.1beta2 122 | % \fancyhfoffset now also takes h,f as possible letters in the argument to 123 | % allow the header and footer widths to be different. 124 | % New commands \fancyheadoffset and \fancyfootoffset added comparable to 125 | % \fancyhead and \fancyfoot. 126 | % Errormessages and warnings have been made more informative. 127 | % 128 | % Dec 9, 2002 129 | % version 2.1 130 | % The defaults for \footrulewidth, \plainheadrulewidth and 131 | % \plainfootrulewidth are changed from \z@skip to 0pt. 
In this way when 132 | % someone inadvertantly uses \setlength to change any of these, the value 133 | % of \z@skip will not be changed, rather an errormessage will be given. 134 | 135 | % March 3, 2004 136 | % Release of version 3.0 137 | 138 | % Oct 7, 2004 139 | % version 3.1 140 | % Added '\endlinechar=13' to \fancy@reset to prevent problems with 141 | % includegraphics in header when verbatiminput is active. 142 | 143 | % March 22, 2005 144 | % version 3.2 145 | % reset \everypar (the real one) in \fancy@reset because spanish.ldf does 146 | % strange things with \everypar between << and >>. 147 | 148 | \def\ifancy@mpty#1{\def\temp@a{#1}\ifx\temp@a\@empty} 149 | 150 | \def\fancy@def#1#2{\ifancy@mpty{#2}\fancy@gbl\def#1{\leavevmode}\else 151 | \fancy@gbl\def#1{#2\strut}\fi} 152 | 153 | \let\fancy@gbl\global 154 | 155 | \def\@fancyerrmsg#1{% 156 | \ifx\PackageError\undefined 157 | \errmessage{#1}\else 158 | \PackageError{Fancyhdr}{#1}{}\fi} 159 | \def\@fancywarning#1{% 160 | \ifx\PackageWarning\undefined 161 | \errmessage{#1}\else 162 | \PackageWarning{Fancyhdr}{#1}{}\fi} 163 | 164 | % Usage: \@forc \var{charstring}{command to be executed for each char} 165 | % This is similar to LaTeX's \@tfor, but expands the charstring. 166 | 167 | \def\@forc#1#2#3{\expandafter\f@rc\expandafter#1\expandafter{#2}{#3}} 168 | \def\f@rc#1#2#3{\def\temp@ty{#2}\ifx\@empty\temp@ty\else 169 | \f@@rc#1#2\f@@rc{#3}\fi} 170 | \def\f@@rc#1#2#3\f@@rc#4{\def#1{#2}#4\f@rc#1{#3}{#4}} 171 | 172 | % Usage: \f@nfor\name:=list\do{body} 173 | % Like LaTeX's \@for but an empty list is treated as a list with an empty 174 | % element 175 | 176 | \newcommand{\f@nfor}[3]{\edef\@fortmp{#2}% 177 | \expandafter\@forloop#2,\@nil,\@nil\@@#1{#3}} 178 | 179 | % Usage: \def@ult \cs{defaults}{argument} 180 | % sets \cs to the characters from defaults appearing in argument 181 | % or defaults if it would be empty. All characters are lowercased. 182 | 183 | \newcommand\def@ult[3]{% 184 | \edef\temp@a{\lowercase{\edef\noexpand\temp@a{#3}}}\temp@a 185 | \def#1{}% 186 | \@forc\tmpf@ra{#2}% 187 | {\expandafter\if@in\tmpf@ra\temp@a{\edef#1{#1\tmpf@ra}}{}}% 188 | \ifx\@empty#1\def#1{#2}\fi} 189 | % 190 | % \if@in 191 | % 192 | \newcommand{\if@in}[4]{% 193 | \edef\temp@a{#2}\def\temp@b##1#1##2\temp@b{\def\temp@b{##1}}% 194 | \expandafter\temp@b#2#1\temp@b\ifx\temp@a\temp@b #4\else #3\fi} 195 | 196 | \newcommand{\fancyhead}{\@ifnextchar[{\f@ncyhf\fancyhead h}% 197 | {\f@ncyhf\fancyhead h[]}} 198 | \newcommand{\fancyfoot}{\@ifnextchar[{\f@ncyhf\fancyfoot f}% 199 | {\f@ncyhf\fancyfoot f[]}} 200 | \newcommand{\fancyhf}{\@ifnextchar[{\f@ncyhf\fancyhf{}}% 201 | {\f@ncyhf\fancyhf{}[]}} 202 | 203 | % New commands for offsets added 204 | 205 | \newcommand{\fancyheadoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyheadoffset h}% 206 | {\f@ncyhfoffs\fancyheadoffset h[]}} 207 | \newcommand{\fancyfootoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyfootoffset f}% 208 | {\f@ncyhfoffs\fancyfootoffset f[]}} 209 | \newcommand{\fancyhfoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyhfoffset{}}% 210 | {\f@ncyhfoffs\fancyhfoffset{}[]}} 211 | 212 | % The header and footer fields are stored in command sequences with 213 | % names of the form: \f@ncy with for [eo], from [lcr] 214 | % and from [hf]. 
215 | 216 | \def\f@ncyhf#1#2[#3]#4{% 217 | \def\temp@c{}% 218 | \@forc\tmpf@ra{#3}% 219 | {\expandafter\if@in\tmpf@ra{eolcrhf,EOLCRHF}% 220 | {}{\edef\temp@c{\temp@c\tmpf@ra}}}% 221 | \ifx\@empty\temp@c\else 222 | \@fancyerrmsg{Illegal char `\temp@c' in \string#1 argument: 223 | [#3]}% 224 | \fi 225 | \f@nfor\temp@c{#3}% 226 | {\def@ult\f@@@eo{eo}\temp@c 227 | \if@twoside\else 228 | \if\f@@@eo e\@fancywarning 229 | {\string#1's `E' option without twoside option is useless}\fi\fi 230 | \def@ult\f@@@lcr{lcr}\temp@c 231 | \def@ult\f@@@hf{hf}{#2\temp@c}% 232 | \@forc\f@@eo\f@@@eo 233 | {\@forc\f@@lcr\f@@@lcr 234 | {\@forc\f@@hf\f@@@hf 235 | {\expandafter\fancy@def\csname 236 | f@ncy\f@@eo\f@@lcr\f@@hf\endcsname 237 | {#4}}}}}} 238 | 239 | \def\f@ncyhfoffs#1#2[#3]#4{% 240 | \def\temp@c{}% 241 | \@forc\tmpf@ra{#3}% 242 | {\expandafter\if@in\tmpf@ra{eolrhf,EOLRHF}% 243 | {}{\edef\temp@c{\temp@c\tmpf@ra}}}% 244 | \ifx\@empty\temp@c\else 245 | \@fancyerrmsg{Illegal char `\temp@c' in \string#1 argument: 246 | [#3]}% 247 | \fi 248 | \f@nfor\temp@c{#3}% 249 | {\def@ult\f@@@eo{eo}\temp@c 250 | \if@twoside\else 251 | \if\f@@@eo e\@fancywarning 252 | {\string#1's `E' option without twoside option is useless}\fi\fi 253 | \def@ult\f@@@lcr{lr}\temp@c 254 | \def@ult\f@@@hf{hf}{#2\temp@c}% 255 | \@forc\f@@eo\f@@@eo 256 | {\@forc\f@@lcr\f@@@lcr 257 | {\@forc\f@@hf\f@@@hf 258 | {\expandafter\setlength\csname 259 | f@ncyO@\f@@eo\f@@lcr\f@@hf\endcsname 260 | {#4}}}}}% 261 | \fancy@setoffs} 262 | 263 | % Fancyheadings version 1 commands. These are more or less deprecated, 264 | % but they continue to work. 265 | 266 | \newcommand{\lhead}{\@ifnextchar[{\@xlhead}{\@ylhead}} 267 | \def\@xlhead[#1]#2{\fancy@def\f@ncyelh{#1}\fancy@def\f@ncyolh{#2}} 268 | \def\@ylhead#1{\fancy@def\f@ncyelh{#1}\fancy@def\f@ncyolh{#1}} 269 | 270 | \newcommand{\chead}{\@ifnextchar[{\@xchead}{\@ychead}} 271 | \def\@xchead[#1]#2{\fancy@def\f@ncyech{#1}\fancy@def\f@ncyoch{#2}} 272 | \def\@ychead#1{\fancy@def\f@ncyech{#1}\fancy@def\f@ncyoch{#1}} 273 | 274 | \newcommand{\rhead}{\@ifnextchar[{\@xrhead}{\@yrhead}} 275 | \def\@xrhead[#1]#2{\fancy@def\f@ncyerh{#1}\fancy@def\f@ncyorh{#2}} 276 | \def\@yrhead#1{\fancy@def\f@ncyerh{#1}\fancy@def\f@ncyorh{#1}} 277 | 278 | \newcommand{\lfoot}{\@ifnextchar[{\@xlfoot}{\@ylfoot}} 279 | \def\@xlfoot[#1]#2{\fancy@def\f@ncyelf{#1}\fancy@def\f@ncyolf{#2}} 280 | \def\@ylfoot#1{\fancy@def\f@ncyelf{#1}\fancy@def\f@ncyolf{#1}} 281 | 282 | \newcommand{\cfoot}{\@ifnextchar[{\@xcfoot}{\@ycfoot}} 283 | \def\@xcfoot[#1]#2{\fancy@def\f@ncyecf{#1}\fancy@def\f@ncyocf{#2}} 284 | \def\@ycfoot#1{\fancy@def\f@ncyecf{#1}\fancy@def\f@ncyocf{#1}} 285 | 286 | \newcommand{\rfoot}{\@ifnextchar[{\@xrfoot}{\@yrfoot}} 287 | \def\@xrfoot[#1]#2{\fancy@def\f@ncyerf{#1}\fancy@def\f@ncyorf{#2}} 288 | \def\@yrfoot#1{\fancy@def\f@ncyerf{#1}\fancy@def\f@ncyorf{#1}} 289 | 290 | \newlength{\fancy@headwidth} 291 | \let\headwidth\fancy@headwidth 292 | \newlength{\f@ncyO@elh} 293 | \newlength{\f@ncyO@erh} 294 | \newlength{\f@ncyO@olh} 295 | \newlength{\f@ncyO@orh} 296 | \newlength{\f@ncyO@elf} 297 | \newlength{\f@ncyO@erf} 298 | \newlength{\f@ncyO@olf} 299 | \newlength{\f@ncyO@orf} 300 | \newcommand{\headrulewidth}{0.4pt} 301 | \newcommand{\footrulewidth}{0pt} 302 | \newcommand{\footruleskip}{.3\normalbaselineskip} 303 | 304 | % Fancyplain stuff shouldn't be used anymore (rather 305 | % \fancypagestyle{plain} should be used), but it must be present for 306 | % compatibility reasons. 
307 | 308 | \newcommand{\plainheadrulewidth}{0pt} 309 | \newcommand{\plainfootrulewidth}{0pt} 310 | \newif\if@fancyplain \@fancyplainfalse 311 | \def\fancyplain#1#2{\if@fancyplain#1\else#2\fi} 312 | 313 | \headwidth=-123456789sp %magic constant 314 | 315 | % Command to reset various things in the headers: 316 | % a.o. single spacing (taken from setspace.sty) 317 | % and the catcode of ^^M (so that epsf files in the header work if a 318 | % verbatim crosses a page boundary) 319 | % It also defines a \nouppercase command that disables \uppercase and 320 | % \Makeuppercase. It can only be used in the headers and footers. 321 | \let\fnch@everypar\everypar% save real \everypar because of spanish.ldf 322 | \def\fancy@reset{\fnch@everypar{}\restorecr\endlinechar=13 323 | \def\baselinestretch{1}% 324 | \def\nouppercase##1{{\let\uppercase\relax\let\MakeUppercase\relax 325 | \expandafter\let\csname MakeUppercase \endcsname\relax##1}}% 326 | \ifx\undefined\@newbaseline% NFSS not present; 2.09 or 2e 327 | \ifx\@normalsize\undefined \normalsize % for ucthesis.cls 328 | \else \@normalsize \fi 329 | \else% NFSS (2.09) present 330 | \@newbaseline% 331 | \fi} 332 | 333 | % Initialization of the head and foot text. 334 | 335 | % The default values still contain \fancyplain for compatibility. 336 | \fancyhf{} % clear all 337 | % lefthead empty on ``plain'' pages, \rightmark on even, \leftmark on odd pages 338 | % evenhead empty on ``plain'' pages, \leftmark on even, \rightmark on odd pages 339 | \if@twoside 340 | \fancyhead[el,or]{\fancyplain{}{\sl\rightmark}} 341 | \fancyhead[er,ol]{\fancyplain{}{\sl\leftmark}} 342 | \else 343 | \fancyhead[l]{\fancyplain{}{\sl\rightmark}} 344 | \fancyhead[r]{\fancyplain{}{\sl\leftmark}} 345 | \fi 346 | \fancyfoot[c]{\rm\thepage} % page number 347 | 348 | % Use box 0 as a temp box and dimen 0 as temp dimen. 349 | % This can be done, because this code will always 350 | % be used inside another box, and therefore the changes are local. 351 | 352 | \def\@fancyvbox#1#2{\setbox0\vbox{#2}\ifdim\ht0>#1\@fancywarning 353 | {\string#1 is too small (\the#1): ^^J Make it at least \the\ht0.^^J 354 | We now make it that large for the rest of the document.^^J 355 | This may cause the page layout to be inconsistent, however\@gobble}% 356 | \dimen0=#1\global\setlength{#1}{\ht0}\ht0=\dimen0\fi 357 | \box0} 358 | 359 | % Put together a header or footer given the left, center and 360 | % right text, fillers at left and right and a rule. 361 | % The \lap commands put the text into an hbox of zero size, 362 | % so overlapping text does not generate an errormessage. 363 | % These macros have 5 parameters: 364 | % 1. LEFTSIDE BEARING % This determines at which side the header will stick 365 | % out. When \fancyhfoffset is used this calculates \headwidth, otherwise 366 | % it is \hss or \relax (after expansion). 367 | % 2. \f@ncyolh, \f@ncyelh, \f@ncyolf or \f@ncyelf. This is the left component. 368 | % 3. \f@ncyoch, \f@ncyech, \f@ncyocf or \f@ncyecf. This is the middle comp. 369 | % 4. \f@ncyorh, \f@ncyerh, \f@ncyorf or \f@ncyerf. This is the right component. 370 | % 5. RIGHTSIDE BEARING. This is always \relax or \hss (after expansion). 
371 | 372 | \def\@fancyhead#1#2#3#4#5{#1\hbox to\headwidth{\fancy@reset 373 | \@fancyvbox\headheight{\hbox 374 | {\rlap{\parbox[b]{\headwidth}{\raggedright#2}}\hfill 375 | \parbox[b]{\headwidth}{\centering#3}\hfill 376 | \llap{\parbox[b]{\headwidth}{\raggedleft#4}}}\headrule}}#5} 377 | 378 | \def\@fancyfoot#1#2#3#4#5{#1\hbox to\headwidth{\fancy@reset 379 | \@fancyvbox\footskip{\footrule 380 | \hbox{\rlap{\parbox[t]{\headwidth}{\raggedright#2}}\hfill 381 | \parbox[t]{\headwidth}{\centering#3}\hfill 382 | \llap{\parbox[t]{\headwidth}{\raggedleft#4}}}}}#5} 383 | 384 | \def\headrule{{\if@fancyplain\let\headrulewidth\plainheadrulewidth\fi 385 | \hrule\@height\headrulewidth\@width\headwidth \vskip-\headrulewidth}} 386 | 387 | \def\footrule{{\if@fancyplain\let\footrulewidth\plainfootrulewidth\fi 388 | \vskip-\footruleskip\vskip-\footrulewidth 389 | \hrule\@width\headwidth\@height\footrulewidth\vskip\footruleskip}} 390 | 391 | \def\ps@fancy{% 392 | \@ifundefined{@chapapp}{\let\@chapapp\chaptername}{}%for amsbook 393 | % 394 | % Define \MakeUppercase for old LaTeXen. 395 | % Note: we used \def rather than \let, so that \let\uppercase\relax (from 396 | % the version 1 documentation) will still work. 397 | % 398 | \@ifundefined{MakeUppercase}{\def\MakeUppercase{\uppercase}}{}% 399 | \@ifundefined{chapter}{\def\sectionmark##1{\markboth 400 | {\MakeUppercase{\ifnum \c@secnumdepth>\z@ 401 | \thesection\hskip 1em\relax \fi ##1}}{}}% 402 | \def\subsectionmark##1{\markright {\ifnum \c@secnumdepth >\@ne 403 | \thesubsection\hskip 1em\relax \fi ##1}}}% 404 | {\def\chaptermark##1{\markboth {\MakeUppercase{\ifnum \c@secnumdepth>\m@ne 405 | \@chapapp\ \thechapter. \ \fi ##1}}{}}% 406 | \def\sectionmark##1{\markright{\MakeUppercase{\ifnum \c@secnumdepth >\z@ 407 | \thesection. \ \fi ##1}}}}% 408 | %\csname ps@headings\endcsname % use \ps@headings defaults if they exist 409 | \ps@@fancy 410 | \gdef\ps@fancy{\@fancyplainfalse\ps@@fancy}% 411 | % Initialize \headwidth if the user didn't 412 | % 413 | \ifdim\headwidth<0sp 414 | % 415 | % This catches the case that \headwidth hasn't been initialized and the 416 | % case that the user added something to \headwidth in the expectation that 417 | % it was initialized to \textwidth. We compensate this now. This loses if 418 | % the user intended to multiply it by a factor. But that case is more 419 | % likely done by saying something like \headwidth=1.2\textwidth. 420 | % The doc says you have to change \headwidth after the first call to 421 | % \pagestyle{fancy}. This code is just to catch the most common cases were 422 | % that requirement is violated. 
423 | % 424 | \global\advance\headwidth123456789sp\global\advance\headwidth\textwidth 425 | \fi} 426 | \def\ps@fancyplain{\ps@fancy \let\ps@plain\ps@plain@fancy} 427 | \def\ps@plain@fancy{\@fancyplaintrue\ps@@fancy} 428 | \let\ps@@empty\ps@empty 429 | \def\ps@@fancy{% 430 | \ps@@empty % This is for amsbook/amsart, which do strange things with \topskip 431 | \def\@mkboth{\protect\markboth}% 432 | \def\@oddhead{\@fancyhead\fancy@Oolh\f@ncyolh\f@ncyoch\f@ncyorh\fancy@Oorh}% 433 | \def\@oddfoot{\@fancyfoot\fancy@Oolf\f@ncyolf\f@ncyocf\f@ncyorf\fancy@Oorf}% 434 | \def\@evenhead{\@fancyhead\fancy@Oelh\f@ncyelh\f@ncyech\f@ncyerh\fancy@Oerh}% 435 | \def\@evenfoot{\@fancyfoot\fancy@Oelf\f@ncyelf\f@ncyecf\f@ncyerf\fancy@Oerf}% 436 | } 437 | % Default definitions for compatibility mode: 438 | % These cause the header/footer to take the defined \headwidth as width 439 | % And to shift in the direction of the marginpar area 440 | 441 | \def\fancy@Oolh{\if@reversemargin\hss\else\relax\fi} 442 | \def\fancy@Oorh{\if@reversemargin\relax\else\hss\fi} 443 | \let\fancy@Oelh\fancy@Oorh 444 | \let\fancy@Oerh\fancy@Oolh 445 | 446 | \let\fancy@Oolf\fancy@Oolh 447 | \let\fancy@Oorf\fancy@Oorh 448 | \let\fancy@Oelf\fancy@Oelh 449 | \let\fancy@Oerf\fancy@Oerh 450 | 451 | % New definitions for the use of \fancyhfoffset 452 | % These calculate the \headwidth from \textwidth and the specified offsets. 453 | 454 | \def\fancy@offsolh{\headwidth=\textwidth\advance\headwidth\f@ncyO@olh 455 | \advance\headwidth\f@ncyO@orh\hskip-\f@ncyO@olh} 456 | \def\fancy@offselh{\headwidth=\textwidth\advance\headwidth\f@ncyO@elh 457 | \advance\headwidth\f@ncyO@erh\hskip-\f@ncyO@elh} 458 | 459 | \def\fancy@offsolf{\headwidth=\textwidth\advance\headwidth\f@ncyO@olf 460 | \advance\headwidth\f@ncyO@orf\hskip-\f@ncyO@olf} 461 | \def\fancy@offself{\headwidth=\textwidth\advance\headwidth\f@ncyO@elf 462 | \advance\headwidth\f@ncyO@erf\hskip-\f@ncyO@elf} 463 | 464 | \def\fancy@setoffs{% 465 | % Just in case \let\headwidth\textwidth was used 466 | \fancy@gbl\let\headwidth\fancy@headwidth 467 | \fancy@gbl\let\fancy@Oolh\fancy@offsolh 468 | \fancy@gbl\let\fancy@Oelh\fancy@offselh 469 | \fancy@gbl\let\fancy@Oorh\hss 470 | \fancy@gbl\let\fancy@Oerh\hss 471 | \fancy@gbl\let\fancy@Oolf\fancy@offsolf 472 | \fancy@gbl\let\fancy@Oelf\fancy@offself 473 | \fancy@gbl\let\fancy@Oorf\hss 474 | \fancy@gbl\let\fancy@Oerf\hss} 475 | 476 | \newif\iffootnote 477 | \let\latex@makecol\@makecol 478 | \def\@makecol{\ifvoid\footins\footnotetrue\else\footnotefalse\fi 479 | \let\topfloat\@toplist\let\botfloat\@botlist\latex@makecol} 480 | \def\iftopfloat#1#2{\ifx\topfloat\empty #2\else #1\fi} 481 | \def\ifbotfloat#1#2{\ifx\botfloat\empty #2\else #1\fi} 482 | \def\iffloatpage#1#2{\if@fcolmade #1\else #2\fi} 483 | 484 | \newcommand{\fancypagestyle}[2]{% 485 | \@namedef{ps@#1}{\let\fancy@gbl\relax#2\relax\ps@fancy}} 486 | -------------------------------------------------------------------------------- /penntreebank.py: -------------------------------------------------------------------------------- 1 | import sys, os, util, functools 2 | import logging 3 | from collections import OrderedDict 4 | import numpy as np 5 | import theano, theano.tensor as T 6 | from theano.sandbox.rng_mrg import MRG_RandomStreams 7 | import blocks.config 8 | import fuel.datasets, fuel.streams, fuel.transformers, fuel.schemes 9 | 10 | from blocks.graph import ComputationGraph 11 | from blocks.algorithms import GradientDescent, RMSProp, StepClipping, CompositeRule, Momentum, Adam 12 | from 
blocks.model import Model 13 | from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing 14 | from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring 15 | from blocks.extensions.stopping import FinishIfNoImprovementAfter 16 | from blocks.extensions.training import TrackTheBest 17 | from blocks.extensions.saveload import Checkpoint 18 | from extensions import DumpLog, DumpBest, PrintingTo, DumpVariables, SharedVariableModifier 19 | from blocks.main_loop import MainLoop 20 | from blocks.utils import shared_floatx_zeros 21 | from blocks.roles import add_role, PARAMETER 22 | 23 | logging.basicConfig() 24 | logger = logging.getLogger(__name__) 25 | 26 | def learning_rate_decayer(decay_rate, i, learning_rate): 27 | return ((1. - decay_rate) * learning_rate).astype(theano.config.floatX) 28 | 29 | def zeros(shape): 30 | return np.zeros(shape, dtype=theano.config.floatX) 31 | 32 | def ones(shape): 33 | return np.ones(shape, dtype=theano.config.floatX) 34 | 35 | def glorot(shape): 36 | d = np.sqrt(6. / sum(shape)) 37 | return np.random.uniform(-d, +d, size=shape).astype(theano.config.floatX) 38 | 39 | def orthogonal(shape): 40 | # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a 41 | """ benanne lasagne ortho init (faster than qr approach)""" 42 | flat_shape = (shape[0], np.prod(shape[1:])) 43 | a = np.random.normal(0.0, 1.0, flat_shape) 44 | u, _, v = np.linalg.svd(a, full_matrices=False) 45 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 46 | q = q.reshape(shape) 47 | return q[:shape[0], :shape[1]].astype(theano.config.floatX) 48 | 49 | def uniform(shape, scale): 50 | return np.random.uniform(-scale, +scale, size=shape).astype(theano.config.floatX) 51 | 52 | def softmax_lastaxis(x): 53 | # for sequence of distributions 54 | return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape) 55 | 56 | def crossentropy_lastaxes(yhat, y): 57 | # for sequence of distributions/targets 58 | return -(y * T.log(yhat)).sum(axis=yhat.ndim - 1) 59 | 60 | _data_cache = dict() 61 | def get_data(which_set): 62 | if which_set not in _data_cache: 63 | path = os.environ["CHAR_LEVEL_PENNTREE_NPZ"] 64 | data = np.load(path) 65 | # put the entire thing on GPU in one-hot (takes 66 | # len(self.vocab) * len(self.data) * sizeof(floatX) bytes 67 | # which is about 1G for the training set and less for the 68 | # other sets) 69 | CudaNdarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray.CudaNdarray 70 | # (doing it in numpy first because cudandarray doesn't accept 71 | # lists of indices) 72 | one_hot_data = np.eye(len(data["vocab"]), dtype=theano.config.floatX)[data[which_set]] 73 | _data_cache[which_set] = CudaNdarray(one_hot_data) 74 | return _data_cache[which_set] 75 | 76 | class PTB(fuel.datasets.Dataset): 77 | provides_sources = ('features',) 78 | example_iteration_scheme = None 79 | 80 | def __init__(self, which_set, length, augment=False): 81 | self.which_set = which_set 82 | self.length = length 83 | self.augment = augment 84 | self.data = get_data(which_set) 85 | self.num_examples = int(len(self.data) / self.length) 86 | if self.augment: 87 | # -1 so we have one self.length worth of room for augmentation 88 | self.num_examples -= 1 89 | super(PTB, self).__init__() 90 | 91 | def open(self): 92 | offset = 0 93 | if self.augment: 94 | # choose an offset to get some data augmentation by not always chopping 95 | # the examples at the same point. 
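# (illustrative aside, added in editing: with length=100 and a drawn offset of
# 37, example i covers characters [37 + 100*i, 37 + 100*(i+1)); fuel's default
# Dataset.reset closes and re-opens the dataset between epochs, so when
# augment=True a fresh offset is drawn and the example boundaries shift
# every epoch.)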
96 | offset = np.random.randint(self.length)
97 | # none of this should copy
98 | data = self.data[offset:]
99 | # reshape to nonoverlapping examples
100 | data = (data[:self.num_examples * self.length]
101 | .reshape((self.num_examples, self.length, self.data.shape[1])))
102 | # return the data so we will get it as the "state" argument to get_data
103 | return data
104 | 
105 | def get_data(self, state, request):
106 | if isinstance(request, (tuple, list)):
107 | request = np.array(request, dtype=np.int64)
108 | return (state.take(request, 0),)
109 | return (state[request],)
110 | 
111 | def get_stream(which_set, batch_size, length, num_examples=None, augment=False):
112 | dataset = PTB(which_set, length=length, augment=augment)
113 | if num_examples is None or num_examples > dataset.num_examples:
114 | num_examples = dataset.num_examples
115 | stream = fuel.streams.DataStream.default_stream(
116 | dataset,
117 | iteration_scheme=fuel.schemes.ShuffledScheme(num_examples, batch_size))
118 | return stream
119 | 
120 | activations = dict(
121 | tanh=T.tanh,
122 | identity=lambda x: x,
123 | relu=lambda x: T.maximum(0, x))  # elementwise; T.max(0, x) is a reduction and would misinterpret x as the axis
124 | 
125 | class Parameters(object):
126 | pass
127 | 
128 | class BatchNormalization(object):
129 | def __init__(self, shape, initial_gamma=1, initial_beta=0, name=None, use_bias=True, epsilon=1e-5):
130 | self.shape = shape
131 | self.initial_gamma = initial_gamma
132 | self.initial_beta = initial_beta
133 | self.name = name
134 | self.use_bias = use_bias
135 | self.epsilon = epsilon
136 | 
137 | @property
138 | def parameters(self):
139 | if not hasattr(self, "_parameters"):
140 | self._parameters = self.allocate_parameters()
141 | return self._parameters
142 | 
143 | def allocate_parameters(self):
144 | parameters = Parameters()
145 | for parameter in [
146 | theano.shared(self.initial_gamma * ones(self.shape), name="gammas"),
147 | theano.shared(self.initial_beta * ones(self.shape), name="betas")]:
148 | add_role(parameter, PARAMETER)
149 | setattr(parameters, parameter.name, parameter)
150 | if self.name:
151 | parameter.name = "%s.%s" % (self.name, parameter.name)
152 | return parameters
153 | 
154 | def construct_graph(self, x, baseline=False, mean=None, var=None):
155 | p = self.parameters
156 | assert x.ndim == 2
157 | mean = x.mean(axis=0) if mean is None else mean
158 | var = x.var (axis=0) if var is None else var
159 | assert mean.ndim == 1
160 | assert var.ndim == 1
161 | betas = p.betas if self.use_bias else 0
162 | if baseline:
163 | y = x + betas
164 | else:
165 | y = theano.tensor.nnet.bn.batch_normalization(
166 | inputs=x,
167 | gamma=p.gammas, beta=betas,
168 | mean=T.shape_padleft(mean),
169 | std=T.shape_padleft(T.sqrt(var + self.epsilon)))
170 | return y, mean, var
171 | 
172 | class LSTM(object):
173 | def __init__(self, args, nclasses):
174 | self.num_hidden = args.num_hidden
175 | self.initializer = args.initializer
176 | self.identity_hh = args.initialization == "identity"
177 | self.nclasses = nclasses
178 | self.activation = activations[args.activation]
179 | 
180 | self.bn_a = BatchNormalization((4 * args.num_hidden,), initial_gamma=args.initial_gamma, name="bn_a", epsilon=args.epsilon)
181 | self.bn_b = BatchNormalization((4 * args.num_hidden,), initial_gamma=args.initial_gamma, name="bn_b", epsilon=args.epsilon, use_bias=False)
182 | self.bn_c = BatchNormalization(( args.num_hidden,), initial_gamma=args.initial_gamma, name="bn_c", epsilon=args.epsilon)
183 | 
184 | @property
185 | def parameters(self):
186 | if not hasattr(self, "_parameters"):
187 |
self._parameters = self.allocate_parameters() 188 | return self._parameters 189 | 190 | def allocate_parameters(self): 191 | parameters = Parameters() 192 | Wa = self.initializer((self.num_hidden, 4 * self.num_hidden)) 193 | 194 | if self.identity_hh: 195 | Wa[:self.num_hidden, :self.num_hidden] = np.eye(self.num_hidden) 196 | 197 | for parameter in [ 198 | theano.shared(zeros((self.num_hidden,)), name="h0"), 199 | theano.shared(zeros((self.num_hidden,)), name="c0"), 200 | theano.shared(Wa, name="Wa"), 201 | theano.shared(self.initializer((self.nclasses, 4 * self.num_hidden)), name="Wx")]: 202 | add_role(parameter, PARAMETER) 203 | setattr(parameters, parameter.name, parameter) 204 | 205 | # forget gate bias initialization 206 | ab_betas = self.bn_a.parameters.betas 207 | pffft = ab_betas.get_value() 208 | pffft[self.num_hidden:2*self.num_hidden] = 1. 209 | ab_betas.set_value(pffft) 210 | 211 | return parameters 212 | 213 | def construct_graph(self, args, x, length, popstats=None): 214 | p = self.parameters 215 | 216 | # use `symlength` where we need to be able to adapt to longer sequences 217 | # than the ones we trained on 218 | symlength = x.shape[0] 219 | t = T.cast(T.arange(symlength), "int16") 220 | long_sequence_is_long = T.ge(T.cast(T.arange(symlength), theano.config.floatX), length) 221 | batch_size = x.shape[1] 222 | dummy_states = dict(h=T.zeros((symlength, batch_size, args.num_hidden)), 223 | c=T.zeros((symlength, batch_size, args.num_hidden))) 224 | 225 | output_names = "h c atilde btilde".split() 226 | for key in "abc": 227 | for stat in "mean var".split(): 228 | output_names.append("%s_%s" % (key, stat)) 229 | 230 | def stepfn(t, long_sequence_is_long, x, dummy_h, dummy_c, h, c): 231 | # population statistics are sequences, but we use them 232 | # like a non-sequence and index it ourselves. this allows 233 | # us to generalize to longer sequences, in which case we 234 | # repeat the last element. 235 | popstats_by_key = dict() 236 | for key in "abc": 237 | popstats_by_key[key] = dict() 238 | for stat in "mean var".split(): 239 | if not args.baseline and args.use_population_statistics: 240 | popstat = popstats["%s_%s" % (key, stat)] 241 | # pluck the appropriate population statistic for this 242 | # time step out of the sequence, or take the last 243 | # element if we've gone beyond the training length. 244 | # if `long_sequence_is_long` then `t` may be unreliable 245 | # as it will overflow for looong sequences. 
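# (clarifying sketch, added in editing: with training length L this picks
#     popstat[min(t, L - 1)]
# e.g. L=50 gives popstat[3] at t=3 and popstat[49] at t=120. the ifelse
# below tests the precomputed float flag rather than `t` itself, because
# the int16 `t` would wrap around beyond 32767 steps.)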
246 | popstat = theano.ifelse.ifelse( 247 | long_sequence_is_long, popstat[-1], popstat[t]) 248 | else: 249 | popstat = None 250 | popstats_by_key[key][stat] = popstat 251 | 252 | atilde, btilde = T.dot(h, p.Wa), T.dot(x, p.Wx) 253 | if args.no_normalize_terms_separately: 254 | a_normal, a_mean, a_var = self.bn_a.construct_graph(atilde + btilde, baseline=args.baseline, **popstats_by_key["a"]) 255 | # making sure all names are still available 256 | b_normal, b_mean, b_var = a_normal, a_mean, a_var 257 | ab = a_normal 258 | else: 259 | a_normal, a_mean, a_var = self.bn_a.construct_graph(atilde, baseline=args.baseline, **popstats_by_key["a"]) 260 | b_normal, b_mean, b_var = self.bn_b.construct_graph(btilde, baseline=args.baseline, **popstats_by_key["b"]) 261 | ab = a_normal + b_normal 262 | g, f, i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden]) 263 | for j, fn in enumerate([self.activation] + 3 * [T.nnet.sigmoid])] 264 | c = dummy_c + f * c + i * g 265 | c_normal, c_mean, c_var = self.bn_c.construct_graph(c, baseline=args.baseline, **popstats_by_key["c"]) 266 | 267 | c_output = c if args.no_normalize_output else c_normal 268 | h = dummy_h + o * self.activation(c_output) 269 | 270 | if args.normalize_cell: 271 | c = c_normal 272 | 273 | return [locals()[name] for name in output_names] 274 | 275 | sequences = [t, long_sequence_is_long, x, dummy_states["h"], dummy_states["c"]] 276 | outputs_info = [ 277 | T.repeat(p.h0[None, :], batch_size, axis=0), 278 | T.repeat(p.c0[None, :], batch_size, axis=0), 279 | ] 280 | outputs_info.extend([None] * (len(output_names) - len(outputs_info))) 281 | 282 | outputs, updates = theano.scan( 283 | stepfn, 284 | sequences=sequences, 285 | outputs_info=outputs_info) 286 | outputs = dict(zip(output_names, outputs)) 287 | 288 | if not args.baseline and not args.use_population_statistics: 289 | # prepare population statistic estimation 290 | popstats = dict() 291 | alpha = 0.05 292 | for key, size in zip("abc", [4*args.num_hidden, 4*args.num_hidden, args.num_hidden]): 293 | for stat, init in zip("mean var".split(), [0, 1]): 294 | name = "%s_%s" % (key, stat) 295 | popstats[name] = theano.shared( 296 | init + np.zeros((length, size,), 297 | dtype=theano.config.floatX), 298 | name=name) 299 | popstats[name].tag.estimand = outputs[name] 300 | updates[popstats[name]] = (alpha * outputs[name] + 301 | (1 - alpha) * popstats[name]) 302 | 303 | return outputs, updates, dummy_states, popstats 304 | 305 | def construct_common_graph(situation, args, outputs, dummy_states, Wy, by, y): 306 | ytilde = T.dot(outputs["h"], Wy) + by 307 | yhat = softmax_lastaxis(ytilde) 308 | 309 | errors = T.neq(T.argmax(y, axis=y.ndim - 1), 310 | T.argmax(yhat, axis=yhat.ndim - 1)) 311 | cross_entropies = crossentropy_lastaxes(yhat, y) 312 | 313 | error_rate = errors.mean().copy(name="error_rate") 314 | cross_entropy = cross_entropies.mean().copy(name="cross_entropy") 315 | cost = cross_entropy.copy(name="cost") 316 | 317 | graph = ComputationGraph([cost, cross_entropy, error_rate]) 318 | 319 | state_grads = dict((k, T.grad(cost, v)) 320 | for k, v in dummy_states.items()) 321 | extensions = [] 322 | if False: 323 | # all these graphs be taking too much gpu memory? 
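# (editor's note: this block is deliberately disabled; enabling it would
# dump h, c and their gradients every 10 epochs, like the --dump-hiddens
# option in text8.py, at the cost of building T.grad graphs for every
# hidden state.)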
324 | extensions.append( 325 | DumpVariables("%s_hiddens" % situation, graph.inputs, 326 | [v.copy(name="%s%s" % (k, suffix)) 327 | for suffix, things in [("", outputs), ("_grad", state_grads)] 328 | for k, v in things.items()], 329 | batch=next(get_stream(which_set="train", 330 | batch_size=args.batch_size, 331 | num_examples=args.batch_size, 332 | length=args.length) 333 | .get_epoch_iterator(as_dict=True)), 334 | before_training=True, every_n_epochs=10)) 335 | 336 | return graph, extensions 337 | 338 | def construct_graphs(args, nclasses): 339 | if args.initialization in "identity orthogonal".split(): 340 | args.initializer = orthogonal 341 | elif args.initialization == "uniform": 342 | args.initializer = lambda shape: uniform(shape, 0.02) 343 | elif args.initialization == "glorot": 344 | args.initializer = glorot 345 | 346 | Wy = theano.shared(args.initializer((args.num_hidden, nclasses)), name="Wy") 347 | by = theano.shared(np.zeros((nclasses,), dtype=theano.config.floatX), name="by") 348 | for parameter in [Wy, by]: 349 | add_role(parameter, PARAMETER) 350 | 351 | x = T.tensor3("features") 352 | 353 | #theano.config.compute_test_value = "warn" 354 | #x.tag.test_value = np.random.random((7, args.length, nclasses)).astype(theano.config.floatX) 355 | 356 | # move time axis forward 357 | x = x.dimshuffle(1, 0, 2) 358 | # task is to predict next character 359 | x, y = x[:-1], x[1:] 360 | length = args.length - 1 361 | 362 | args.use_population_statistics = False 363 | lstm = LSTM(args, nclasses) 364 | (outputs, training_updates, dummy_states, popstats) = lstm.construct_graph( 365 | args, x, length) 366 | training_graph, training_extensions = construct_common_graph("training", args, outputs, dummy_states, Wy, by, y) 367 | args.use_population_statistics = True 368 | (outputs, inference_updates, dummy_states, _) = lstm.construct_graph( 369 | args, x, length, 370 | # use popstats from previous invocation 371 | popstats=popstats) 372 | inference_graph, inference_extensions = construct_common_graph("inference", args, outputs, dummy_states, Wy, by, y) 373 | args.use_population_statistics = False 374 | 375 | # pfft 376 | return (dict(training=training_graph, inference=inference_graph), 377 | dict(training=training_extensions, inference=inference_extensions), 378 | dict(training=training_updates, inference=inference_updates)) 379 | 380 | if __name__ == "__main__": 381 | nclasses = 50 382 | 383 | import argparse 384 | parser = argparse.ArgumentParser() 385 | parser.add_argument("--seed", type=int, default=1) 386 | parser.add_argument("--length", type=int, default=50) 387 | parser.add_argument("--num-epochs", type=int, default=100) 388 | parser.add_argument("--batch-size", type=int, default=64) 389 | parser.add_argument("--learning-rate", type=float, default=1e-3) 390 | parser.add_argument("--epsilon", type=float, default=1e-5) 391 | parser.add_argument("--num-hidden", type=int, default=1000) 392 | parser.add_argument("--baseline", action="store_true") 393 | parser.add_argument("--initialization", choices="identity glorot orthogonal uniform".split(), default="identity") 394 | parser.add_argument("--initial-gamma", type=float, default=1e-1) 395 | parser.add_argument("--initial-beta", type=float, default=0) 396 | parser.add_argument("--cluster", action="store_true") 397 | parser.add_argument("--activation", choices=list(activations.keys()), default="tanh") 398 | parser.add_argument("--optimizer", choices="sgdmomentum rmsprop adam".split(), default="rmsprop") 399 | 
parser.add_argument("--learning-rate-decay", type=float, default=0.0) 400 | parser.add_argument("--no-normalize-terms-separately", action="store_true", help="Normalize recurrent and input terms separately") 401 | parser.add_argument("--normalize-cell", action="store_true", help="Pass normalized cell on to next step rather than just use it for output") 402 | parser.add_argument("--no-normalize-output", action="store_true", help="Normalize cell before using it for output") 403 | parser.add_argument("--continue-from") 404 | args = parser.parse_args() 405 | 406 | np.random.seed(args.seed) 407 | blocks.config.config.default_seed = args.seed 408 | 409 | if args.continue_from: 410 | from blocks.serialization import load 411 | main_loop = load(args.continue_from) 412 | main_loop.run() 413 | sys.exit(0) 414 | 415 | graphs, extensions, updates = construct_graphs(args, nclasses) 416 | 417 | ### optimization algorithm definition 418 | if args.optimizer == "adam": 419 | optimizer = Adam(learning_rate=args.learning_rate) 420 | # zzzz 421 | optimizer.learning_rate = theano.shared(np.asarray(optimizer.learning_rate, dtype=theano.config.floatX)) 422 | elif args.optimizer == "rmsprop": 423 | optimizer = RMSProp(learning_rate=args.learning_rate, decay_rate=0.9) 424 | elif args.optimizer == "sgdmomentum": 425 | optimizer = Momentum(learning_rate=args.learning_rate, momentum=0.99) 426 | step_rule = CompositeRule([ 427 | StepClipping(1.), 428 | optimizer, 429 | ]) 430 | algorithm = GradientDescent(cost=graphs["training"].outputs[0], 431 | parameters=graphs["training"].parameters, 432 | step_rule=step_rule) 433 | algorithm.add_updates(updates["training"]) 434 | model = Model(graphs["training"].outputs[0]) 435 | extensions = extensions["training"] + extensions["inference"] 436 | 437 | extensions.append(SharedVariableModifier( 438 | optimizer.learning_rate, 439 | functools.partial(learning_rate_decayer, args.learning_rate_decay))) 440 | 441 | # step monitor 442 | step_channels = [] 443 | step_channels.extend([ 444 | algorithm.steps[param].norm(2).copy(name="step_norm:%s" % name) 445 | for name, param in model.get_parameter_dict().items()]) 446 | step_channels.append(algorithm.total_step_norm.copy(name="total_step_norm")) 447 | step_channels.append(algorithm.total_gradient_norm.copy(name="total_gradient_norm")) 448 | step_channels.extend(graphs["training"].outputs) 449 | logger.warning("constructing training data monitor") 450 | extensions.append(TrainingDataMonitoring( 451 | step_channels, prefix="iteration", after_batch=True)) 452 | 453 | # parameter monitor 454 | extensions.append(DataStreamMonitoring( 455 | ([param.norm(2).copy(name="parameter.norm:%s" % name) 456 | for name, param in model.get_parameter_dict().items()] 457 | + [optimizer.learning_rate.copy(name="learning_rate")]), 458 | data_stream=None, after_epoch=True)) 459 | 460 | # performance monitor 461 | for situation in "training inference".split(): 462 | for which_set in "train valid test".split(): 463 | logger.warning("constructing %s %s monitor" % (which_set, situation)) 464 | channels = list(graphs[situation].outputs) 465 | extensions.append(DataStreamMonitoring( 466 | channels, 467 | prefix="%s_%s" % (which_set, situation), after_epoch=True, 468 | data_stream=get_stream(which_set=which_set, batch_size=args.batch_size, 469 | num_examples=50000, length=args.length))) 470 | 471 | extensions.extend([ 472 | TrackTheBest("valid_training_error_rate", "best_valid_training_error_rate"), 473 | DumpBest("best_valid_training_error_rate", "best.zip"), 474 | 
FinishAfter(after_n_epochs=args.num_epochs), 475 | #FinishIfNoImprovementAfter("best_valid_error_rate", epochs=50), 476 | Checkpoint("checkpoint.zip", on_interrupt=False, every_n_epochs=1, use_cpickle=True), 477 | DumpLog("log.pkl", after_epoch=True)]) 478 | 479 | if not args.cluster: 480 | extensions.append(ProgressBar()) 481 | 482 | extensions.extend([ 483 | Timing(after_batch=True), 484 | Printing(), 485 | PrintingTo("log"), 486 | ]) 487 | main_loop = MainLoop( 488 | data_stream=get_stream(which_set="train", batch_size=args.batch_size, length=args.length, augment=True), 489 | algorithm=algorithm, extensions=extensions, model=model) 490 | main_loop.run() 491 | -------------------------------------------------------------------------------- /text8.py: -------------------------------------------------------------------------------- 1 | import util 2 | import sys, os, util, itertools, copy, re, pprint 3 | import logging 4 | from collections import OrderedDict 5 | import numpy as np 6 | import theano, theano.tensor as T 7 | from theano.sandbox.rng_mrg import MRG_RandomStreams 8 | import blocks.config 9 | import fuel.datasets, fuel.streams, fuel.transformers, fuel.schemes 10 | 11 | from blocks.graph import ComputationGraph 12 | from blocks.algorithms import GradientDescent, Adam, RMSProp, StepClipping, CompositeRule, Momentum 13 | from blocks.model import Model 14 | from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing 15 | from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring 16 | from blocks.extensions.stopping import FinishIfNoImprovementAfter 17 | from blocks.extensions.training import TrackTheBest 18 | from blocks.extensions.saveload import Checkpoint 19 | from extensions import DumpLog, DumpBest, PrintingTo, DumpVariables 20 | from blocks.main_loop import MainLoop 21 | from blocks.utils import shared_floatx_zeros 22 | from blocks.roles import add_role, PARAMETER 23 | from blocks.serialization import load 24 | 25 | logging.basicConfig() 26 | logger = logging.getLogger(__name__) 27 | 28 | def zeros(shape): 29 | return np.zeros(shape, dtype=theano.config.floatX) 30 | 31 | def ones(shape): 32 | return np.ones(shape, dtype=theano.config.floatX) 33 | 34 | def glorot(shape): 35 | d = np.sqrt(6. 
/ sum(shape)) 36 | return np.random.uniform(-d, +d, size=shape).astype(theano.config.floatX) 37 | 38 | def orthogonal(shape): 39 | # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a 40 | """ benanne lasagne ortho init (faster than qr approach)""" 41 | flat_shape = (shape[0], np.prod(shape[1:])) 42 | a = np.random.normal(0.0, 1.0, flat_shape) 43 | u, _, v = np.linalg.svd(a, full_matrices=False) 44 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 45 | q = q.reshape(shape) 46 | return q[:shape[0], :shape[1]].astype(theano.config.floatX) 47 | 48 | def uniform(shape, scale): 49 | return np.random.uniform(-scale, +scale, size=shape).astype(theano.config.floatX) 50 | 51 | def softmax_lastaxis(x): 52 | # for sequence of distributions 53 | return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape) 54 | 55 | def crossentropy_lastaxes(yhat, y): 56 | # for sequence of distributions/targets 57 | return -(y * T.log(yhat)).sum(axis=yhat.ndim - 1) 58 | 59 | class Text8(fuel.datasets.Dataset): 60 | provides_sources = ('features',) 61 | example_iteration_scheme = None 62 | 63 | def __init__(self, which_set, length, augment=False): 64 | self.which_set = which_set 65 | self.length = length 66 | self.augment = augment 67 | data = np.load(os.environ["CHAR_LEVEL_TEXT8_NPZ"]) 68 | self.data = data[which_set] 69 | self.vocab = data["vocab"] 70 | self.num_examples = int(len(self.data) / self.length) 71 | if self.augment: 72 | # -1 so we have one self.length worth of room for augmentation 73 | self.num_examples -= 1 74 | super(Text8, self).__init__() 75 | 76 | def open(self): 77 | data = self.data 78 | if self.augment: 79 | # choose an offset to get some data augmentation by not always chopping 80 | # the examples at the same point. 
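# (see the analogous editing note in penntreebank.py: re-opening the dataset
# each epoch draws a fresh offset, which is the whole augmentation.)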
81 | offset = np.random.randint(self.length)
82 | data = data[offset:]
83 | # reshape to nonoverlapping examples
84 | data = (data[:self.num_examples * self.length]
85 | .reshape((self.num_examples, self.length)))
86 | # return the data so we will get it as the "state" argument to get_data
87 | return data
88 | 
89 | def get_data(self, state, request):
90 | one_hot_batch = np.eye(len(self.vocab), dtype=theano.config.floatX)[state[request]]
91 | return (one_hot_batch,)
92 | 
93 | def get_stream(which_set, batch_size, length, num_examples=None, augment=False):
94 | dataset = Text8(which_set, length=length, augment=augment)
95 | if num_examples is None or num_examples > dataset.num_examples:
96 | num_examples = dataset.num_examples
97 | stream = fuel.streams.DataStream.default_stream(
98 | dataset,
99 | iteration_scheme=fuel.schemes.ShuffledScheme(num_examples, batch_size))
100 | return stream
101 | 
102 | activations = dict(
103 | tanh=T.tanh,
104 | identity=lambda x: x,
105 | relu=lambda x: T.maximum(0, x))  # elementwise; T.max(0, x) is a reduction and would misinterpret x as the axis
106 | 
107 | class Parameters(object):
108 | pass
109 | 
110 | class BatchNormalization(object):
111 | def __init__(self, shape, initial_gamma=1, initial_beta=0, name=None, use_bias=True, epsilon=1e-5):
112 | self.shape = shape
113 | self.initial_gamma = initial_gamma
114 | self.initial_beta = initial_beta
115 | self.name = name
116 | self.use_bias = use_bias
117 | self.epsilon = epsilon
118 | 
119 | @property
120 | def parameters(self):
121 | if not hasattr(self, "_parameters"):
122 | self._parameters = self.allocate_parameters()
123 | return self._parameters
124 | 
125 | def allocate_parameters(self):
126 | parameters = Parameters()
127 | for parameter in [
128 | theano.shared(self.initial_gamma * ones(self.shape), name="gammas"),
129 | theano.shared(self.initial_beta * ones(self.shape), name="betas")]:
130 | add_role(parameter, PARAMETER)
131 | setattr(parameters, parameter.name, parameter)
132 | if self.name:
133 | parameter.name = "%s.%s" % (self.name, parameter.name)
134 | return parameters
135 | 
136 | def construct_graph(self, x, baseline=False, mean=None, var=None):
137 | p = self.parameters
138 | assert x.ndim == 2
139 | mean = x.mean(axis=0) if mean is None else mean
140 | var = x.var (axis=0) if var is None else var
141 | assert mean.ndim == 1
142 | assert var.ndim == 1
143 | betas = p.betas if self.use_bias else 0
144 | if baseline:
145 | y = x + betas
146 | else:
147 | y = theano.tensor.nnet.bn.batch_normalization(
148 | inputs=x,
149 | gamma=p.gammas, beta=betas,
150 | mean=T.shape_padleft(mean),
151 | std=T.shape_padleft(T.sqrt(var + self.epsilon)))
152 | return y, mean, var
153 | 
154 | class LSTM(object):
155 | def __init__(self, args, nclasses):
156 | self.num_hidden = args.num_hidden
157 | self.initializer = args.initializer
158 | self.identity_hh = args.initialization == "identity"
159 | self.nclasses = nclasses
160 | self.activation = activations[args.activation]
161 | 
162 | self.bn_a = BatchNormalization((4 * args.num_hidden,), initial_gamma=args.initial_gamma, name="bn_a", epsilon=args.epsilon)
163 | self.bn_b = BatchNormalization((4 * args.num_hidden,), initial_gamma=args.initial_gamma, name="bn_b", epsilon=args.epsilon, use_bias=False)
164 | self.bn_c = BatchNormalization(( args.num_hidden,), initial_gamma=args.initial_gamma, name="bn_c", epsilon=args.epsilon)
165 | 
166 | @property
167 | def parameters(self):
168 | if not hasattr(self, "_parameters"):
169 | self._parameters = self.allocate_parameters()
170 | return self._parameters
171 | 
172 | def allocate_parameters(self):
173 | parameters = Parameters() 174 | Wa = self.initializer((self.num_hidden, 4 * self.num_hidden)) 175 | 176 | if self.identity_hh: 177 | Wa[:self.num_hidden, :self.num_hidden] = np.eye(self.num_hidden) 178 | 179 | for parameter in [ 180 | theano.shared(zeros((self.num_hidden,)), name="h0"), 181 | theano.shared(zeros((self.num_hidden,)), name="c0"), 182 | theano.shared(Wa, name="Wa"), 183 | theano.shared(self.initializer((self.nclasses, 4 * self.num_hidden)), name="Wx")]: 184 | add_role(parameter, PARAMETER) 185 | setattr(parameters, parameter.name, parameter) 186 | 187 | # forget gate bias initialization 188 | ab_betas = self.bn_a.parameters.betas 189 | pffft = ab_betas.get_value() 190 | pffft[self.num_hidden:2*self.num_hidden] = 1. 191 | ab_betas.set_value(pffft) 192 | 193 | return parameters 194 | 195 | def construct_graph(self, args, x, length, popstats=None): 196 | p = self.parameters 197 | 198 | # use `symlength` where we need to be able to adapt to longer sequences 199 | # than the ones we trained on 200 | symlength = x.shape[0] 201 | t = T.cast(T.arange(symlength), "int16") 202 | long_sequence_is_long = T.ge(T.cast(T.arange(symlength), theano.config.floatX), length) 203 | batch_size = x.shape[1] 204 | dummy_states = dict(h=T.zeros((symlength, batch_size, args.num_hidden)), 205 | c=T.zeros((symlength, batch_size, args.num_hidden))) 206 | 207 | output_names = "h c atilde btilde".split() 208 | for key in "abc": 209 | for stat in "mean var".split(): 210 | output_names.append("%s_%s" % (key, stat)) 211 | 212 | def stepfn(t, long_sequence_is_long, x, dummy_h, dummy_c, h, c): 213 | # population statistics are sequences, but we use them 214 | # like a non-sequence and index it ourselves. this allows 215 | # us to generalize to longer sequences, in which case we 216 | # repeat the last element. 217 | popstats_by_key = dict() 218 | for key in "abc": 219 | popstats_by_key[key] = dict() 220 | for stat in "mean var".split(): 221 | if not args.baseline and args.use_population_statistics: 222 | popstat = popstats["%s_%s" % (key, stat)] 223 | # pluck the appropriate population statistic for this 224 | # time step out of the sequence, or take the last 225 | # element if we've gone beyond the training length. 226 | # if `long_sequence_is_long` then `t` may be unreliable 227 | # as it will overflow for looong sequences. 
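# (as in penntreebank.py: effectively popstat[min(t, length - 1)], with the
# precomputed float flag standing in for min() because the int16 `t` wraps
# beyond 32767 steps.)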
228 | popstat = theano.ifelse.ifelse(
229 | long_sequence_is_long, popstat[-1], popstat[t])
230 | else:
231 | popstat = None
232 | popstats_by_key[key][stat] = popstat
233 | 
234 | atilde, btilde = T.dot(h, p.Wa), T.dot(x, p.Wx)
235 | a_normal, a_mean, a_var = self.bn_a.construct_graph(atilde, baseline=args.baseline, **popstats_by_key["a"])
236 | b_normal, b_mean, b_var = self.bn_b.construct_graph(btilde, baseline=args.baseline, **popstats_by_key["b"])
237 | ab = a_normal + b_normal
238 | 
239 | g, f, i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden])
240 | for j, fn in enumerate([self.activation] + 3 * [T.nnet.sigmoid])]
241 | 
242 | c = dummy_c + f * c + i * g
243 | 
244 | c_normal, c_mean, c_var = self.bn_c.construct_graph(c, baseline=args.baseline, **popstats_by_key["c"])
245 | 
246 | h = dummy_h + o * self.activation(c_normal)
247 | 
248 | return [locals()[name] for name in output_names]
249 | 
250 | sequences = [t, long_sequence_is_long, x, dummy_states["h"], dummy_states["c"]]
251 | outputs_info = [
252 | T.repeat(p.h0[None, :], batch_size, axis=0),
253 | T.repeat(p.c0[None, :], batch_size, axis=0),
254 | ]
255 | outputs_info.extend([None] * (len(output_names) - len(outputs_info)))
256 | 
257 | outputs, updates = theano.scan(
258 | stepfn,
259 | sequences=sequences,
260 | outputs_info=outputs_info)
261 | outputs = dict(zip(output_names, outputs))
262 | 
263 | if not args.baseline and not args.use_population_statistics:
264 | # prepare population statistic estimation
265 | popstats = dict()
266 | alpha = 0.05
267 | for key, size in zip("abc", [4*args.num_hidden, 4*args.num_hidden, args.num_hidden]):  # one (mean, var) sequence each for a, b, c
268 | for stat, init in zip("mean var".split(), [0, 1]):
269 | name = "%s_%s" % (key, stat)
270 | popstats[name] = theano.shared(
271 | init + np.zeros((length, size,),
272 | dtype=theano.config.floatX),
273 | name=name)
274 | popstats[name].tag.estimand = outputs[name]
275 | updates[popstats[name]] = (alpha * outputs[name] +
276 | (1 - alpha) * popstats[name])
277 | 
278 | return outputs, updates, dummy_states, popstats
279 | 
280 | def construct_common_graph(situation, args, outputs, dummy_states, Wy, by, y):
281 | ytilde = T.dot(outputs["h"], Wy) + by
282 | yhat = softmax_lastaxis(ytilde)
283 | 
284 | errors = T.neq(T.argmax(y, axis=y.ndim - 1),
285 | T.argmax(yhat, axis=yhat.ndim - 1))
286 | cross_entropies = crossentropy_lastaxes(yhat, y)
287 | 
288 | error_rate = errors.mean().copy(name="error_rate")
289 | cross_entropy = cross_entropies.mean().copy(name="cross_entropy")
290 | bpc = (cross_entropy / np.log(2)).copy(name="bpc")
291 | cost = cross_entropy.copy(name="cost")
292 | 
293 | graph = ComputationGraph([cost, cross_entropy, error_rate, bpc])
294 | 
295 | state_grads = dict((k, T.grad(cost, v))
296 | for k, v in dummy_states.items())
297 | extensions = []
298 | if args.dump_hiddens:
299 | extensions.append(
300 | DumpVariables("%s_hiddens" % situation, graph.inputs,
301 | [v.copy(name="%s%s" % (k, suffix))
302 | for suffix, things in [("", outputs), ("_grad", state_grads)]
303 | for k, v in things.items()],
304 | batch=next(get_stream(which_set="train",
305 | batch_size=args.batch_size,
306 | num_examples=args.batch_size,
307 | length=args.length)
308 | .get_epoch_iterator(as_dict=True)),
309 | before_training=True, every_n_epochs=10))
310 | 
311 | return graph, extensions
312 | 
313 | def construct_graphs(args, nclasses):
314 | if args.initialization in "identity orthogonal".split():
315 | args.initializer = orthogonal
316 | elif args.initialization ==
"uniform": 317 | args.initializer = lambda shape: uniform(shape, 0.01) 318 | elif args.initialization == "glorot": 319 | args.initializer = glorot 320 | 321 | Wy = theano.shared(args.initializer((args.num_hidden, nclasses)), name="Wy") 322 | by = theano.shared(np.zeros((nclasses,), dtype=theano.config.floatX), name="by") 323 | for parameter in [Wy, by]: 324 | add_role(parameter, PARAMETER) 325 | 326 | x = T.tensor3("features") 327 | 328 | #theano.config.compute_test_value = "warn" 329 | #x.tag.test_value = np.random.random((7, args.length, nclasses)).astype(theano.config.floatX) 330 | 331 | # move time axis forward 332 | x = x.dimshuffle(1, 0, 2) 333 | # task is to predict next character 334 | x, y = x[:-1], x[1:] 335 | length = args.length - 1 336 | 337 | args.use_population_statistics = False 338 | lstm = LSTM(args, nclasses) 339 | (outputs, training_updates, dummy_states, popstats) = lstm.construct_graph( 340 | args, x, length) 341 | training_graph, training_extensions = construct_common_graph("training", args, outputs, dummy_states, Wy, by, y) 342 | args.use_population_statistics = True 343 | (outputs, inference_updates, dummy_states, _) = lstm.construct_graph( 344 | args, x, length, 345 | # use popstats from previous invocation 346 | popstats=popstats) 347 | inference_graph, inference_extensions = construct_common_graph("inference", args, outputs, dummy_states, Wy, by, y) 348 | args.use_population_statistics = False 349 | 350 | return (dict(training=training_graph, inference=inference_graph), 351 | dict(training=training_extensions, inference=inference_extensions), 352 | dict(training=training_updates, inference=inference_updates)) 353 | 354 | def main(): 355 | nclasses = 27 356 | 357 | import argparse 358 | parser = argparse.ArgumentParser() 359 | parser.add_argument("--seed", type=int, default=1) 360 | parser.add_argument("--length", type=int, default=180) 361 | parser.add_argument("--num-epochs", type=int, default=100) 362 | parser.add_argument("--batch-size", type=int, default=64) 363 | parser.add_argument("--learning-rate", type=float, default=1e-3) 364 | parser.add_argument("--epsilon", type=float, default=1e-5) 365 | parser.add_argument("--num-hidden", type=int, default=1000) 366 | parser.add_argument("--baseline", action="store_true") 367 | parser.add_argument("--initialization", choices="identity glorot orthogonal uniform".split(), default="identity") 368 | parser.add_argument("--initial-gamma", type=float, default=1e-1) 369 | parser.add_argument("--initial-beta", type=float, default=0) 370 | parser.add_argument("--cluster", action="store_true") 371 | parser.add_argument("--activation", choices=list(activations.keys()), default="tanh") 372 | parser.add_argument("--optimizer", choices="sgdmomentum adam rmsprop", default="rmsprop") 373 | parser.add_argument("--continue-from") 374 | parser.add_argument("--evaluate") 375 | parser.add_argument("--dump-hiddens") 376 | args = parser.parse_args() 377 | 378 | np.random.seed(args.seed) 379 | blocks.config.config.default_seed = args.seed 380 | 381 | if args.continue_from: 382 | from blocks.serialization import load 383 | main_loop = load(args.continue_from) 384 | main_loop.run() 385 | sys.exit(0) 386 | 387 | graphs, extensions, updates = construct_graphs(args, nclasses) 388 | 389 | ### optimization algorithm definition 390 | if args.optimizer == "adam": 391 | optimizer = Adam(learning_rate=args.learning_rate) 392 | elif args.optimizer == "rmsprop": 393 | optimizer = RMSProp(learning_rate=args.learning_rate, decay_rate=0.9) 394 | elif 
args.optimizer == "sgdmomentum": 395 | optimizer = Momentum(learning_rate=args.learning_rate, momentum=0.99) 396 | step_rule = CompositeRule([ 397 | StepClipping(1.), 398 | optimizer, 399 | ]) 400 | algorithm = GradientDescent(cost=graphs["training"].outputs[0], 401 | parameters=graphs["training"].parameters, 402 | step_rule=step_rule) 403 | algorithm.add_updates(updates["training"]) 404 | model = Model(graphs["training"].outputs[0]) 405 | extensions = extensions["training"] + extensions["inference"] 406 | 407 | # step monitor 408 | step_channels = [] 409 | step_channels.extend([ 410 | algorithm.steps[param].norm(2).copy(name="step_norm:%s" % name) 411 | for name, param in model.get_parameter_dict().items()]) 412 | step_channels.append(algorithm.total_step_norm.copy(name="total_step_norm")) 413 | step_channels.append(algorithm.total_gradient_norm.copy(name="total_gradient_norm")) 414 | step_channels.extend(graphs["training"].outputs) 415 | logger.warning("constructing training data monitor") 416 | extensions.append(TrainingDataMonitoring( 417 | step_channels, prefix="iteration", after_batch=True)) 418 | 419 | # parameter monitor 420 | extensions.append(DataStreamMonitoring( 421 | [param.norm(2).copy(name="parameter.norm:%s" % name) 422 | for name, param in model.get_parameter_dict().items()], 423 | data_stream=None, after_epoch=True)) 424 | 425 | validation_interval = 500 426 | # performance monitor 427 | for situation in "training inference".split(): 428 | if situation == "inference" and not args.evaluate: 429 | # save time when we don't need the inference graph 430 | continue 431 | 432 | for which_set in "train valid test".split(): 433 | logger.warning("constructing %s %s monitor" % (which_set, situation)) 434 | channels = list(graphs[situation].outputs) 435 | extensions.append(DataStreamMonitoring( 436 | channels, 437 | prefix="%s_%s" % (which_set, situation), every_n_batches=validation_interval, 438 | data_stream=get_stream(which_set=which_set, batch_size=args.batch_size, 439 | num_examples=10000, length=args.length))) 440 | 441 | extensions.extend([ 442 | TrackTheBest("valid_training_error_rate", "best_valid_training_error_rate"), 443 | DumpBest("best_valid_training_error_rate", "best.zip"), 444 | FinishAfter(after_n_epochs=args.num_epochs), 445 | #FinishIfNoImprovementAfter("best_valid_error_rate", epochs=50), 446 | Checkpoint("checkpoint.zip", on_interrupt=False, every_n_epochs=1, use_cpickle=True), 447 | DumpLog("log.pkl", after_epoch=True)]) 448 | 449 | if not args.cluster: 450 | extensions.append(ProgressBar()) 451 | 452 | extensions.extend([ 453 | Timing(), 454 | Printing(every_n_batches=validation_interval), 455 | PrintingTo("log"), 456 | ]) 457 | main_loop = MainLoop( 458 | data_stream=get_stream(which_set="train", batch_size=args.batch_size, length=args.length, augment=True), 459 | algorithm=algorithm, extensions=extensions, model=model) 460 | 461 | if args.dump_hiddens: 462 | dump_hiddens(args, main_loop) 463 | return 464 | 465 | if args.evaluate: 466 | evaluate(args, main_loop) 467 | return 468 | 469 | main_loop.run() 470 | 471 | def transfer_parameters(src_main_loop, dest_main_loop): 472 | src_parameters = dict((parameter.name, parameter) for parameter in src_main_loop.algorithm.parameters) 473 | dest_parameters = dict((parameter.name, parameter) for parameter in dest_main_loop.algorithm.parameters) 474 | 475 | # assert sets of parameters equal 476 | assert not (set(src_parameters) - set(dest_parameters)) 477 | assert not (set(dest_parameters) - set(src_parameters)) 478 | 
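# with the name sets known to coincide, copy values pairwise; the shape
# assertion below guards against transferring between different architectures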
479 | for name, src_parameter in src_parameters.items(): 480 | dest_parameter = dest_parameters[name] 481 | assert dest_parameter.get_value().shape == src_parameter.get_value().shape 482 | dest_parameter.set_value(src_parameter.get_value()) 483 | 484 | def dump_hiddens(args, main_loop): 485 | # load parameters of trained model 486 | trained_main_loop = load(args.dump_hiddens) 487 | transfer_parameters(trained_main_loop, main_loop) 488 | del trained_main_loop 489 | 490 | for extension in main_loop.extensions: 491 | if isinstance(extension, DumpVariables): 492 | extension.do("after_training") 493 | 494 | def evaluate(args, main_loop): 495 | # load parameters of trained model 496 | trained_main_loop = load(args.evaluate) 497 | transfer_parameters(trained_main_loop, main_loop) 498 | del trained_main_loop 499 | 500 | # extract population statistic updates 501 | updates = [update for update in main_loop.algorithm.updates 502 | # FRAGILE 503 | if re.search("_(mean|var)$", update[0].name)] 504 | print updates 505 | 506 | old_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates) 507 | 508 | # baseline doesn't need all this 509 | if updates: 510 | train_stream = get_stream(which_set="train", 511 | batch_size=1000, 512 | length=args.length) 513 | nbatches = len(list(train_stream.get_epoch_iterator())) 514 | 515 | # destructure moving average expression to construct a new expression 516 | new_updates = [] 517 | for popstat, value in updates: 518 | # FRAGILE 519 | assert value.owner.op.scalar_op == theano.scalar.add 520 | terms = value.owner.inputs 521 | # right multiplicand of second term is popstat 522 | assert popstat in theano.gof.graph.ancestors([terms[1].owner.inputs[1]]) 523 | # right multiplicand of first term is batchstat 524 | batchstat = terms[0].owner.inputs[1] 525 | 526 | old_popstats[popstat] = popstat.get_value() 527 | 528 | # FRAGILE: assume population statistics not used in computation of batch statistics 529 | # otherwise popstat should always have a reasonable value 530 | popstat.set_value(0 * popstat.get_value(borrow=True)) 531 | new_updates.append((popstat, popstat + batchstat / float(nbatches))) 532 | 533 | # FRAGILE: assume all the other algorithm updates are unneeded for computation of batch statistics 534 | estimate_fn = theano.function(main_loop.algorithm.inputs, [], 535 | updates=new_updates, on_unused_input="warn") 536 | print("averaging batch statistics over", nbatches, "batches") 537 | for batch in train_stream.get_epoch_iterator(as_dict=True): 538 | estimate_fn(**batch) 539 | sys.stdout.write(".") 540 | sys.stdout.flush() 541 | print 542 | 543 | new_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates) 544 | 545 | from blocks.monitoring.evaluators import DatasetEvaluator 546 | results = dict() 547 | for situation in "training inference".split(): 548 | results[situation] = dict() 549 | outputs, = [ 550 | extension._evaluator.theano_variables 551 | for extension in main_loop.extensions 552 | if getattr(extension, "prefix", None) == "valid_%s" % situation] 553 | evaluator = DatasetEvaluator(outputs) 554 | for which_set in "valid test".split(): 555 | print(situation, which_set) 556 | results[situation][which_set] = OrderedDict( 557 | (length, evaluator.evaluate(get_stream( 558 | which_set=which_set, 559 | batch_size=100, 560 | length=length))) 561 | for length in [1000]) 562 | 563 | try: 564 | results["proper_test"] = evaluator.evaluate( 565 | get_stream( 566 | which_set="test", 567 | batch_size=1, 568 | length=5*10**6)) 569 | except: 
570 | # that will probably run out of memory 571 | pass 572 | 573 | import cPickle 574 | cPickle.dump(dict(results=results, 575 | old_popstats=old_popstats, 576 | new_popstats=new_popstats), 577 | open(sys.argv[1] + "_popstat_results.pkl", "w")) 578 | 579 | if __name__ == "__main__": 580 | main() 581 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | from collections import OrderedDict 4 | import numpy as np 5 | import theano, theano.tensor as T 6 | from theano.sandbox.rng_mrg import MRG_RandomStreams 7 | import blocks.config 8 | import fuel.datasets, fuel.streams, fuel.transformers, fuel.schemes 9 | 10 | ### optimization algorithm definition 11 | from blocks.graph import ComputationGraph 12 | from blocks.algorithms import GradientDescent, RMSProp, StepClipping, CompositeRule, Momentum 13 | from blocks.model import Model 14 | from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing 15 | from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring 16 | from blocks.extensions.stopping import FinishIfNoImprovementAfter 17 | from blocks.extensions.training import TrackTheBest 18 | from blocks.extensions.saveload import Checkpoint 19 | from extensions import DumpLog, DumpBest, PrintingTo, DumpVariables 20 | from blocks.main_loop import MainLoop 21 | from blocks.utils import shared_floatx_zeros 22 | from blocks.roles import add_role, PARAMETER 23 | 24 | 25 | logging.basicConfig() 26 | logger = logging.getLogger(__name__) 27 | 28 | def zeros(shape): 29 | return np.zeros(shape, dtype=theano.config.floatX) 30 | 31 | def ones(shape): 32 | return np.ones(shape, dtype=theano.config.floatX) 33 | 34 | def glorot(shape): 35 | d = np.sqrt(6. / sum(shape)) 36 | return np.random.uniform(-d, +d, size=shape).astype(theano.config.floatX) 37 | 38 | def orthogonal(shape): 39 | # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a 40 | """ benanne lasagne ortho init (faster than qr approach)""" 41 | flat_shape = (shape[0], np.prod(shape[1:])) 42 | a = np.random.normal(0.0, 1.0, flat_shape) 43 | u, _, v = np.linalg.svd(a, full_matrices=False) 44 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 45 | q = q.reshape(shape) 46 | return q[:shape[0], :shape[1]].astype(theano.config.floatX) 47 | 48 | 49 | class Memory(fuel.datasets.Dataset): 50 | provides_sources = ('x', 'y') 51 | example_iteration_scheme = None 52 | 53 | def __init__(self, min_interval, max_interval, seed): 54 | self.min_interval = min_interval 55 | self.max_interval = max_interval 56 | self.seed = seed 57 | super(Memory, self).__init__() 58 | 59 | def open(self): 60 | return np.random.RandomState(self.seed) 61 | 62 | def get_data(self, state=None, request=None): 63 | if request is not None: 64 | raise ValueError 65 | 66 | # the sequence to be remembered 67 | targets = state.randint(8, size=10) 68 | 69 | interval = state.random_integers(self.min_interval, self.max_interval) 70 | 71 | x = np.zeros((interval + 20, 10), dtype=np.float32) 72 | x[range(10), targets] = 1. 73 | x[10:-11, 8] = 1. 74 | x[-11, 9] = 1. 75 | x[-10:, 8] = 1. 76 | 77 | y = np.zeros((interval + 20, 10), dtype=np.float32) 78 | y[:-10, 8] = 1. 79 | y[range(-10, -0), targets] = 1. 
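# encoding: channels 0-7 are the data symbols, channel 8 is the blank filler, channel 9 is the recall trigger; the target stays blank until the last ten steps, where the remembered sequence must be emitted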
80 | 81 | return (x, y) 82 | 83 | def get_stream(min_interval, max_interval, which_set, batch_size, num_examples=10000): 84 | seed = dict(train=1, valid=2, test=3)[which_set] 85 | dataset = Memory(min_interval, max_interval, seed) 86 | stream = fuel.streams.DataStream.default_stream(dataset) 87 | stream = fuel.transformers.Padding( 88 | fuel.transformers.Batch( 89 | stream, 90 | fuel.schemes.ConstantScheme(batch_size, num_examples))) 91 | return stream 92 | 93 | 94 | 95 | 96 | 97 | def construct_rnn(args, x, activation): 98 | parameters = [] 99 | 100 | h0 = theano.shared(zeros((args.num_hidden,)), name="h0") 101 | Wh = theano.shared((0.99 if args.baseline else 1) * np.eye(args.num_hidden, dtype=theano.config.floatX), name="Wh") 102 | Wx = theano.shared(orthogonal((10, args.num_hidden)), name="Wx") 103 | 104 | parameters.extend([h0, Wh, Wx]) 105 | 106 | gammas = theano.shared(args.initial_gamma * ones((args.num_hidden,)), name="gammas") 107 | betas = theano.shared(args.initial_beta * ones((args.num_hidden,)), name="betas") 108 | 109 | if args.baseline: 110 | parameters.extend([betas]) 111 | def bn(x, gammas, betas): 112 | return x + betas 113 | else: 114 | parameters.extend([gammas, betas]) 115 | def bn(x, gammas, betas): 116 | mean, var = x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True) 117 | # if only 118 | mean.tag.batchstat, var.tag.batchstat = True, True 119 | #var = T.maximum(var, args.epsilon) 120 | var = var + args.epsilon 121 | return (x - mean) / T.sqrt(var) * gammas + betas 122 | 123 | xtilde = T.dot(x, Wx) 124 | 125 | if args.noise: 126 | # prime h with white noise 127 | Trng = MRG_RandomStreams() 128 | h_prime = Trng.normal((xtilde.shape[1], args.num_hidden), std=args.noise) 129 | elif args.summarize: 130 | # prime h with mean of example 131 | h_prime = x.mean(axis=[0, 2])[:, None] 132 | else: 133 | h_prime = 0 134 | 135 | dummy_states = dict(h =T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)), 136 | htilde=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden))) 137 | 138 | def stepfn(xtilde, dummy_h, dummy_htilde, h): 139 | htilde = dummy_htilde + T.dot(h, Wh) + xtilde 140 | h = dummy_h + activation(bn(htilde, gammas, betas)) 141 | return h, htilde 142 | 143 | [h, htilde], _ = theano.scan(stepfn, 144 | sequences=[xtilde, dummy_states["h"], dummy_states["htilde"]], 145 | outputs_info=[T.repeat(h0[None, :], xtilde.shape[1], axis=0) + h_prime, 146 | None]) 147 | 148 | return dict(h=h, htilde=htilde), dummy_states, parameters 149 | 150 | def construct_lstm(args, x, activation): 151 | parameters = [] 152 | 153 | h0 = theano.shared(zeros((args.num_hidden,)), name="h0") 154 | c0 = theano.shared(zeros((args.num_hidden,)), name="c0") 155 | Wa = theano.shared(np.concatenate([ 156 | np.eye(args.num_hidden), 157 | orthogonal((args.num_hidden, 3 * args.num_hidden)), 158 | ], axis=1).astype(theano.config.floatX), name="Wa") 159 | Wx = theano.shared(orthogonal((10, 4 * args.num_hidden)), name="Wx") 160 | 161 | parameters.extend([h0, c0, Wa, Wx]) 162 | 163 | a_gammas = theano.shared(args.initial_gamma * ones((4 * args.num_hidden,)), name="a_gammas") 164 | b_gammas = theano.shared(args.initial_gamma * ones((4 * args.num_hidden,)), name="b_gammas") 165 | ab_betas = theano.shared(args.initial_beta * ones((4 * args.num_hidden,)), name="ab_betas") 166 | h_gammas = theano.shared(args.initial_gamma * ones((args.num_hidden,)), name="h_gammas") 167 | h_betas = theano.shared(args.initial_beta * ones((args.num_hidden,)), name="h_betas") 168 | 169 | # forget gate bias 
initialization 170 | forget_bias = ab_betas.get_value() 171 | forget_bias[args.num_hidden:2*args.num_hidden] = 1. 172 | ab_betas.set_value(forget_bias) 173 | 174 | if args.baseline: 175 | parameters.extend([ab_betas, h_betas]) 176 | def bn(x, gammas, betas): 177 | return x + betas, x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True) 178 | else: 179 | parameters.extend([a_gammas, b_gammas, h_gammas, ab_betas, h_betas]) 180 | def bn(x, gammas, betas): 181 | mean, var = x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True) 182 | # if only 183 | mean.tag.batchstat, var.tag.batchstat = True, True 184 | #var = T.maximum(var, args.epsilon) 185 | vareps = var + args.epsilon 186 | return (x - mean) / T.sqrt(vareps) * gammas + betas, mean, var 187 | 188 | xtilde = T.dot(x, Wx) 189 | 190 | if args.noise: 191 | # prime h with white noise 192 | Trng = MRG_RandomStreams() 193 | h_prime = Trng.normal((xtilde.shape[1], args.num_hidden), std=args.noise) 194 | elif args.summarize: 195 | # prime h with mean of example 196 | h_prime = x.mean(axis=[0, 2])[:, None] 197 | else: 198 | h_prime = 0 199 | 200 | dummy_states = dict(h=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)), 201 | c=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden))) 202 | 203 | def stepfn(xtilde, dummy_h, dummy_c, h, c): 204 | atilde = T.dot(h, Wa) 205 | btilde = xtilde 206 | a, mean, var = bn(atilde, a_gammas, ab_betas) 207 | #b = bn(btilde, b_gammas, 0) 208 | b = btilde 209 | ab = a + b 210 | g, f, i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden]) 211 | for j, fn in enumerate([activation] + 3 * [T.nnet.sigmoid])] 212 | c = dummy_c + f * c + i * g 213 | htilde = c 214 | bn_, tmp1, tmp2 = bn(htilde, h_gammas, h_betas) 215 | h = dummy_h + o * activation(bn_) 216 | return h, c, atilde, btilde, htilde, g, f, i, o, mean, var 217 | 218 | [h, c, 219 | atilde, btilde, htilde, 220 | g, f, i, o, 221 | mean, var], _ = theano.scan( 222 | stepfn, 223 | sequences=[xtilde, dummy_states["h"], dummy_states["c"]], 224 | outputs_info=[T.repeat(h0[None, :], xtilde.shape[1], axis=0) + h_prime, 225 | T.repeat(c0[None, :], xtilde.shape[1], axis=0), 226 | None, None, None, 227 | None, None, None, None, 228 | None, None]) 229 | return dict(h=h, c=c, atilde=atilde, btilde=btilde, htilde=htilde, 230 | g=g, f=f, i=i, o=o, 231 | mean=mean, var=var), dummy_states, parameters 232 | 233 | 234 | def bn(x, gammas, betas, mean, var, args): 235 | assert mean.ndim == 1 236 | assert var.ndim == 1 237 | assert x.ndim == 2 238 | if not args.use_population_statistics: 239 | mean = x.mean(axis=0) 240 | var = x.var(axis=0) 241 | #var = T.maximum(var, args.epsilon) 242 | #var = var + args.epsilon 243 | 244 | if args.baseline: 245 | y = x + betas 246 | else: 247 | #var_corrected = var.zeros_like() + 1.0 248 | if args.clipvar: 249 | var_corrected = theano.tensor.switch(theano.tensor.eq(var, 0.), 1.0, var + args.epsilon) 250 | else: 251 | var_corrected = var + args.epsilon 252 | 253 | y = theano.tensor.nnet.bn.batch_normalization( 254 | inputs=x, gamma=gammas, beta=betas, 255 | mean=T.shape_padleft(mean), std=T.shape_padleft(T.sqrt(var_corrected)), 256 | mode="high_mem") 257 | assert mean.ndim == 1 258 | assert var.ndim == 1 259 | return y, mean, var 260 | 261 | activations = dict( 262 | tanh=T.tanh, 263 | identity=lambda x: x, 264 | relu=lambda x: T.maximum(0, x)) 265 | 266 | 267 | class Empty(object): 268 | pass 269 | 270 | class LSTM(object): 271 | def __init__(self, args, nclasses): 272 | self.nclasses = nclasses 273 | self.activation = activations[args.activation] 274 | 275 | def allocate_parameters(self, args):
276 | if hasattr(self, "parameters"): # parameters are allocated once and cached so the training and inference graphs share them 277 | return self.parameters 278 | 279 | self.parameters = Empty() 280 | 281 | h0 = theano.shared(zeros((args.num_hidden,)), name="h0") 282 | c0 = theano.shared(zeros((args.num_hidden,)), name="c0") 283 | if args.init == "id": 284 | Wa = theano.shared(np.concatenate([ 285 | np.eye(args.num_hidden), 286 | orthogonal((args.num_hidden, 287 | 3 * args.num_hidden)),], axis=1).astype(theano.config.floatX), name="Wa") 288 | else: 289 | Wa = theano.shared(orthogonal((args.num_hidden, 4 * args.num_hidden)), name="Wa") 290 | Wx = theano.shared(orthogonal((10, 4 * args.num_hidden)), name="Wx") 291 | a_gammas = theano.shared(args.initial_gamma * ones((4 * args.num_hidden,)), name="a_gammas") 292 | b_gammas = theano.shared(args.initial_gamma * ones((4 * args.num_hidden,)), name="b_gammas") 293 | ab_betas = theano.shared(args.initial_beta * ones((4 * args.num_hidden,)), name="ab_betas") 294 | 295 | # forget gate bias initialization 296 | forget_bias = ab_betas.get_value() 297 | forget_bias[args.num_hidden:2*args.num_hidden] = 1. 298 | ab_betas.set_value(forget_bias) 299 | 300 | c_gammas = theano.shared(args.initial_gamma * ones((args.num_hidden,)), name="c_gammas") 301 | c_betas = theano.shared(args.initial_beta * ones((args.num_hidden,)), name="c_betas") 302 | 303 | if not args.baseline: 304 | parameters_list = [h0, c0, Wa, Wx, a_gammas, b_gammas, ab_betas, c_gammas, c_betas] 305 | else: 306 | parameters_list = [h0, c0, Wa, Wx, ab_betas, c_betas] 307 | for parameter in parameters_list: 308 | print parameter.name 309 | add_role(parameter, PARAMETER) 310 | setattr(self.parameters, parameter.name, parameter) 311 | 312 | return self.parameters 313 | 314 | 315 | def construct_graph_ref(self, args, x, length, popstats=None): 316 | 317 | p = self.allocate_parameters(args) 318 | 319 | if args.baseline: 320 | def bn(x, gammas, betas): 321 | return x + betas 322 | else: 323 | def bn(x, gammas, betas): 324 | mean, var = x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True) 325 | # if only 326 | mean.tag.batchstat, var.tag.batchstat = True, True 327 | #var = T.maximum(var, args.epsilon) 328 | var = var + args.epsilon 329 | return (x - mean) / T.sqrt(var) * gammas + betas 330 | 331 | def stepfn(x, dummy_h, dummy_c, h, c): 332 | # a_mean, b_mean, c_mean, 333 | # a_var, b_var, c_var): 334 | 335 | a_mean, b_mean, c_mean = 0, 0, 0 336 | a_var, b_var, c_var = 0, 0, 0 337 | 338 | atilde = T.dot(h, p.Wa) 339 | btilde = x 340 | a_normal = bn(atilde, p.a_gammas, p.ab_betas) 341 | b_normal = bn(btilde, p.b_gammas, 0) 342 | ab = a_normal + b_normal 343 | g, f, i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden]) 344 | for j, fn in enumerate([self.activation] + 3 * [T.nnet.sigmoid])] 345 | c = dummy_c + f * c + i * g 346 | c_normal = bn(c, p.c_gammas, p.c_betas) 347 | h = dummy_h + o * self.activation(c_normal) 348 | return h, c, atilde, btilde, c_normal 349 | 350 | 351 | 352 | xtilde = T.dot(x, p.Wx) 353 | 354 | if args.noise: 355 | # prime h with white noise 356 | Trng = MRG_RandomStreams() 357 | h_prime = Trng.normal((xtilde.shape[1], args.num_hidden), std=args.noise) 358 | elif args.summarize: 359 | # prime h with mean of example 360 | h_prime = x.mean(axis=[0, 2])[:, None] 361 | else: 362 | h_prime = 0 363 | 364 | dummy_states = dict(h=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)), 365 | c=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden))) 366 | 367 | [h, c, atilde, btilde, htilde], _ = theano.scan( 368 | stepfn, 369 | sequences=[xtilde,
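# the all-zero dummy states are threaded through scan so that T.grad can later read per-timestep gradients of the cost with respect to h and c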
dummy_states["h"], dummy_states["c"]], 370 | outputs_info=[T.repeat(p.h0[None, :], xtilde.shape[1], axis=0) + h_prime, 371 | T.repeat(p.c0[None, :], xtilde.shape[1], axis=0), 372 | None, None, None]) 373 | return dict(h=h, c=c, 374 | atilde=atilde, btilde=btilde, htilde=htilde), [], dummy_states, popstats 375 | 376 | def construct_graph_popstats(self, args, x, length, popstats=None): 377 | p = self.allocate_parameters(args) 378 | 379 | 380 | def stepfn(x, dummy_h, dummy_c, 381 | pop_means_a, pop_means_b, pop_means_c, 382 | pop_vars_a, pop_vars_b, pop_vars_c, 383 | h, c): 384 | 385 | atilde = T.dot(h, p.Wa) 386 | btilde = x 387 | if args.baseline: 388 | a_normal, a_mean, a_var = bn(atilde, 1.0, p.ab_betas, pop_means_a, pop_vars_a, args) 389 | b_normal, b_mean, b_var = bn(btilde, 1.0, 0, pop_means_b, pop_vars_b, args) 390 | else: 391 | a_normal, a_mean, a_var = bn(atilde, p.a_gammas, p.ab_betas, pop_means_a, pop_vars_a, args) 392 | b_normal, b_mean, b_var = bn(btilde, p.b_gammas, 0, pop_means_b, pop_vars_b, args) 393 | ab = a_normal + b_normal 394 | g, f, i, o = [fn(ab[:, j * args.num_hidden:(j + 1) * args.num_hidden]) 395 | for j, fn in enumerate([self.activation] + 3 * [T.nnet.sigmoid])] 396 | c = dummy_c + f * c + i * g 397 | if args.baseline: 398 | c_normal, c_mean, c_var = bn(c, 1.0, p.c_betas, pop_means_c, pop_vars_c, args) 399 | else: 400 | c_normal, c_mean, c_var = bn(c, p.c_gammas, p.c_betas, pop_means_c, pop_vars_c, args) 401 | h = dummy_h + o * self.activation(c_normal) 402 | return (h, c, atilde, btilde, c_normal, 403 | a_mean, b_mean, c_mean, 404 | a_var, b_var, c_var) 405 | 406 | 407 | xtilde = T.dot(x, p.Wx) 408 | if args.noise: 409 | # prime h with white noise 410 | Trng = MRG_RandomStreams() 411 | h_prime = Trng.normal((xtilde.shape[1], args.num_hidden), std=args.noise) 412 | elif args.summarize: 413 | # prime h with mean of example 414 | h_prime = x.mean(axis=[0, 2])[:, None] 415 | else: 416 | h_prime = 0 417 | 418 | dummy_states = dict(h=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden)), 419 | c=T.zeros((xtilde.shape[0], xtilde.shape[1], args.num_hidden))) 420 | 421 | if popstats is None: 422 | popstats = OrderedDict() 423 | for key, size in zip("abc", [4*args.num_hidden, 4*args.num_hidden, args.num_hidden]): 424 | for stat, init in zip("mean var".split(), [0, 1]): 425 | name = "%s_%s" % (key, stat) 426 | popstats[name] = theano.shared( 427 | init + np.zeros((length, size,), dtype=theano.config.floatX), 428 | name=name) 429 | popstats_seq = [popstats['a_mean'], popstats['b_mean'], popstats['c_mean'], 430 | popstats['a_var'], popstats['b_var'], popstats['c_var']] 431 | 432 | [h, c, atilde, btilde, htilde, 433 | batch_mean_a, batch_mean_b, batch_mean_c, 434 | batch_var_a, batch_var_b, batch_var_c ], _ = theano.scan( 435 | stepfn, 436 | sequences=[xtilde, dummy_states["h"], dummy_states["c"]] + popstats_seq, 437 | outputs_info=[T.repeat(p.h0[None, :], xtilde.shape[1], axis=0) + h_prime, 438 | T.repeat(p.c0[None, :], xtilde.shape[1], axis=0), 439 | None, None, None, 440 | None, None, None, 441 | None, None, None]) 442 | 443 | batchstats = OrderedDict() 444 | batchstats['a_mean'] = batch_mean_a 445 | batchstats['b_mean'] = batch_mean_b 446 | batchstats['c_mean'] = batch_mean_c 447 | batchstats['a_var'] = batch_var_a 448 | batchstats['b_var'] = batch_var_b 449 | batchstats['c_var'] = batch_var_c 450 | 451 | updates = OrderedDict() 452 | if not args.use_population_statistics: 453 | alpha = 1e-2 454 | for key in "abc": 455 | for stat, init in zip("mean var".split(), [0, 
1]): 456 | name = "%s_%s" % (key, stat) 457 | popstats[name].tag.estimand = batchstats[name] 458 | updates[popstats[name]] = (alpha * batchstats[name] + 459 | (1 - alpha) * popstats[name]) 460 | return dict(h=h, c=c, 461 | atilde=atilde, btilde=btilde, htilde=htilde), updates, dummy_states, popstats 462 | 463 | 464 | def construct_common_graph(situation, args, outputs, mask, dummy_states, Wy, by, y): 465 | 466 | ytilde = T.dot(outputs["h"], Wy) + by 467 | ytilde_reshape = ytilde.reshape((ytilde.shape[0] * ytilde.shape[1], ytilde.shape[2])) 468 | yhat = T.nnet.softmax(ytilde_reshape).reshape((ytilde.shape[0], ytilde.shape[1], ytilde.shape[2])) 469 | 470 | errors = T.neq(T.argmax(y, axis=2), T.argmax(yhat, axis=2)).reshape(mask.shape) 471 | flat_y = y.reshape((-1, nclasses)) # nclasses is bound at module level under __main__ 472 | flat_yhat = yhat.reshape((-1, nclasses)) 473 | cross_entropies = T.nnet.categorical_crossentropy(flat_yhat, flat_y).reshape(mask.shape) 474 | 475 | ## masking 476 | errors = mask * errors 477 | cross_entropies = mask * cross_entropies 478 | 479 | # following Amar & Martin we take the mean across time (rather than require all steps to be correct) 480 | error_rate = errors.mean().copy(name="error_rate") 481 | cross_entropy = cross_entropies.mean().copy(name="cross_entropy") 482 | cost = cross_entropy.copy(name="cost") 483 | graph = ComputationGraph([cost, cross_entropy, error_rate]) 484 | 485 | state_grads = dict((k, T.grad(cost, v)) for k, v in dummy_states.items()) # gradients w.r.t. the zero-valued dummy states expose per-timestep state gradients 486 | 487 | extensions = [] 488 | # extensions = [ 489 | # DumpVariables("%s_hiddens" % situation, graph.inputs, 490 | # [v.copy(name="%s%s" % (k, suffix)) 491 | # for suffix, things in [("", outputs), ("_grad", state_grads)] 492 | # for k, v in things.items()], 493 | # batch=next(get_stream(which_set="train", 494 | # batch_size=args.batch_size, 495 | # num_examples=args.batch_size) 496 | # .get_epoch_iterator(as_dict=True)), 497 | # before_training=True, every_n_epochs=10)] 498 | return graph, extensions 499 | 500 | def construct_graphs(args, nclasses, length): 501 | constructor = LSTM if args.lstm else RNN # NOTE: no RNN class is defined in this file (construct_rnn above is a bare function), so --lstm is effectively required 502 | 503 | Wy = theano.shared(orthogonal((args.num_hidden, nclasses)), name="Wy") 504 | by = theano.shared(np.zeros((nclasses,), dtype=theano.config.floatX), name="by") 505 | 506 | ### graph construction 507 | x, y = T.tensor3("x"), T.tensor3("y") 508 | x_mask, y_mask = T.matrix("x_mask"), T.matrix("y_mask") 509 | # data_dbg = next(get_stream(which_set="valid", 510 | # min_interval=args.interval, 511 | # max_interval=args.interval, 512 | # batch_size=batch_size).get_epoch_iterator()) 513 | #x.tag.test_value = data_dbg[2] 514 | #y.tag.test_value = data_dbg[0] 515 | #x_mask.tag.test_value = data_dbg[3] 516 | #y_mask.tag.test_value = data_dbg[1] 517 | 518 | mask = x_mask + 0 * y_mask # need to use both in the graph or theano and blocks will complain 519 | 520 | # move time axis before batch axis 521 | x = x.dimshuffle(1, 0, 2) 522 | y = y.dimshuffle(1, 0, 2) 523 | mask = mask.dimshuffle(1, 0) 524 | 525 | args.use_population_statistics = False 526 | turd = constructor(args, nclasses) 527 | (outputs, training_updates, dummy_states, popstats) = turd.construct_graph_popstats(args, x, length) 528 | training_graph, training_extensions = construct_common_graph("training", args, outputs, mask, dummy_states, Wy, by, y) 529 | 530 | args.use_population_statistics = True 531 | (inf_outputs, inference_updates, dummy_states, _) = turd.construct_graph_popstats(args, x, length, popstats=popstats) 532 | inference_graph, inference_extensions =
construct_common_graph("inference", args, inf_outputs, mask, dummy_states, Wy, by, y) 533 | 534 | add_role(Wy, PARAMETER) 535 | add_role(by, PARAMETER) 536 | args.use_population_statistics = False 537 | return (dict(training=training_graph, inference=inference_graph), 538 | dict(training=training_extensions, inference=inference_extensions), 539 | dict(training=training_updates, inference=inference_updates)) 540 | 541 | 542 | if __name__ == "__main__": 543 | nclasses = 10 544 | batch_size = 100 545 | 546 | activations = dict( 547 | tanh=T.tanh, 548 | identity=lambda x: x, 549 | relu=lambda x: T.maximum(0, x)) 550 | 551 | import argparse 552 | parser = argparse.ArgumentParser() 553 | parser.add_argument("--seed", type=int, default=1) 554 | parser.add_argument("--num-epochs", type=int, default=100) 555 | parser.add_argument("--learning-rate", type=float, default=1e-4) 556 | parser.add_argument("--epsilon", type=float, default=1e-5) 557 | parser.add_argument("--batch-size", type=int, default=100) 558 | parser.add_argument("--noise", type=float, default=None) 559 | parser.add_argument("--summarize", action="store_true") 560 | parser.add_argument("--num-hidden", type=int, default=40) 561 | parser.add_argument("--baseline", action="store_true") 562 | parser.add_argument("--lstm", action="store_true") 563 | parser.add_argument("--initial-gamma", type=float, default=0.1) 564 | parser.add_argument("--initial-beta", type=float, default=0) 565 | parser.add_argument("--cluster", action="store_true") 566 | parser.add_argument("--activation", choices=list(activations.keys()), default="tanh") 567 | parser.add_argument("--init", type=str, default="ortho") 568 | parser.add_argument("--clipvar", action="store_true") 569 | parser.add_argument("--continue-from") 570 | parser.add_argument("--interval", type=int, default=100) 571 | args = parser.parse_args() 572 | 573 | #assert not (args.noise and args.summarize) 574 | np.random.seed(args.seed) 575 | blocks.config.config.default_seed = args.seed 576 | 577 | sequence_length = args.interval + 20 578 | 579 | 580 | if args.continue_from: 581 | from blocks.serialization import load 582 | main_loop = load(args.continue_from) 583 | main_loop.run() 584 | sys.exit(0) 585 | 586 | graphs, extensions, updates = construct_graphs(args, nclasses, sequence_length) 587 | 588 | ### optimization algorithm definition 589 | step_rule = CompositeRule([ 590 | StepClipping(1.), 591 | #Momentum(learning_rate=args.learning_rate, momentum=0.9), 592 | RMSProp(learning_rate=args.learning_rate, decay_rate=0.9), 593 | ]) 594 | 595 | algorithm = GradientDescent(cost=graphs["training"].outputs[0], 596 | parameters=graphs["training"].parameters, 597 | step_rule=step_rule) 598 | algorithm.add_updates(updates["training"]) 599 | model = Model(graphs["training"].outputs[0]) 600 | extensions = extensions["training"] + extensions["inference"] 601 | 602 | 603 | # step monitor (after epoch to limit the log size) 604 | step_channels = [] 605 | step_channels.extend([ 606 | algorithm.steps[param].norm(2).copy(name="step_norm:%s" % name) 607 | for name, param in model.get_parameter_dict().items()]) 608 | step_channels.append(algorithm.total_step_norm.copy(name="total_step_norm")) 609 | step_channels.append(algorithm.total_gradient_norm.copy(name="total_gradient_norm")) 610 | step_channels.extend(graphs["training"].outputs) 611 | logger.warning("constructing training data monitor") 612 | extensions.append(TrainingDataMonitoring( 613 | step_channels, prefix="iteration", after_batch=False)) 614 | 615 | # parameter
monitor 616 | extensions.append(DataStreamMonitoring( 617 | [param.norm(2).copy(name="parameter.norm:%s" % name) 618 | for name, param in model.get_parameter_dict().items()], 619 | data_stream=None, after_epoch=True)) 620 | 621 | # performance monitor 622 | for situation in "training".split(): # add inference 623 | for which_set in "train valid test".split(): 624 | logger.warning("constructing %s %s monitor" % (which_set, situation)) 625 | channels = list(graphs[situation].outputs) 626 | extensions.append(DataStreamMonitoring( 627 | channels, 628 | prefix="%s_%s" % (which_set, situation), after_epoch=True, 629 | data_stream=get_stream(which_set=which_set, 630 | min_interval=args.interval, max_interval=args.interval, 631 | batch_size=args.batch_size)))#, num_examples=1000))) 632 | for situation in "inference".split(): # add inference 633 | for which_set in "valid test".split(): 634 | logger.warning("constructing %s %s monitor" % (which_set, situation)) 635 | channels = list(graphs[situation].outputs) 636 | extensions.append(DataStreamMonitoring( 637 | channels, 638 | prefix="%s_%s" % (which_set, situation), after_epoch=True, 639 | data_stream=get_stream(which_set=which_set, 640 | min_interval=args.interval, max_interval=args.interval, 641 | batch_size=args.batch_size)))#, num_examples=1000))) 642 | 643 | extensions.extend([ 644 | TrackTheBest("valid_training_error_rate", "best_valid_training_error_rate"), 645 | DumpBest("best_valid_training_error_rate", "best.zip"), 646 | FinishAfter(after_n_epochs=args.num_epochs), 647 | #FinishIfNoImprovementAfter("best_valid_error_rate", epochs=50), 648 | Checkpoint("checkpoint.zip", on_interrupt=False, every_n_epochs=1, use_cpickle=True), 649 | DumpLog("log.pkl", after_epoch=True)]) 650 | 651 | if not args.cluster: 652 | extensions.append(ProgressBar()) 653 | 654 | extensions.extend([ 655 | Timing(), 656 | Printing(), 657 | PrintingTo("log"), 658 | ]) 659 | main_loop = MainLoop( 660 | data_stream=get_stream(which_set="train", 661 | min_interval=args.interval, max_interval=args.interval, 662 | batch_size=args.batch_size), 663 | algorithm=algorithm, extensions=extensions, model=model) 664 | main_loop.run() 665 | 666 | --------------------------------------------------------------------------------
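A minimal sketch (not part of the repository) of the population-statistic re-estimation idea used in evaluate() above: rather than trusting the exponential moving average accumulated during training, the per-timestep batch statistics are re-averaged exactly over a pass through the training set. All names here (nbatches, batch_means, etc.) are illustrative, and NumPy stands in for the Theano shared variables:

import numpy as np

T_, H, nbatches = 50, 40, 100                  # timesteps, hidden units, batches in one epoch
rng = np.random.RandomState(0)
batch_means = rng.rand(nbatches, T_, H)        # stand-in for per-batch, per-timestep means

# during training: popstat <- alpha * batchstat + (1 - alpha) * popstat
alpha = 1e-2
ema = np.zeros((T_, H))
for m in batch_means:
    ema = alpha * m + (1 - alpha) * ema

# at evaluation time: zero the popstat, then accumulate batchstat / nbatches,
# which is what the rewritten updates in evaluate() compute
exact = np.zeros((T_, H))
for m in batch_means:
    exact += m / float(nbatches)

print("max |EMA - exact|: %g" % np.abs(ema - exact).max())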