├── .gitignore ├── LICENSE ├── R ├── process.R └── processruns.R ├── README.md ├── exp ├── exp1a ├── exp1b ├── exp1c ├── exp1d ├── exp2a ├── exp2b ├── exp2c ├── exp2d ├── exp3a ├── exp3b ├── exp3c └── exp3d ├── model.py ├── results └── results-folder.txt ├── runrnn ├── sample.py ├── save └── save-folder.txt ├── stats.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Edwin D. de Jong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /R/process.R: -------------------------------------------------------------------------------- 1 | 2 | process <- function(exp, mode, folder){ 3 | 4 | if(F){ 5 | exp='xc810at5r1' 6 | mode='train' 7 | folder = '~/code/digits/rnn' 8 | } 9 | 10 | command = paste('cd ', folder ,';./processresults out/',exp,'.txt',sep='') 11 | print(command) 12 | system(command) 13 | 14 | fn = paste('./out/',exp,'-',mode,sep=''); 15 | print(paste('reading ', fn)) 16 | epoch=read.csv(fn,sep='',row.names=NULL) 17 | 18 | #remove columns with too many missing values 19 | nrmis = apply(is.na( epoch ),2,sum) 20 | nrrows=dim(epoch)[1] 21 | sel=nrmis < .5 * nrrows 22 | epoch = epoch[,sel] 23 | 24 | if (colnames(epoch)[1]=='row.names') 25 | { 26 | colnames(epoch)=c(colnames(epoch)[-1],'x') 27 | } 28 | 29 | rmse_own_output_0=NULL 30 | rmse_own_output_1=NULL 31 | 32 | fn = paste('./results/',exp,'/rmse-own-output0.txt',sep='') 33 | if (file.exists(fn)){ 34 | rmse_own_output_0 = tryCatch({ 35 | print(paste('reading ', fn)) 36 | read.table(fn,sep='') 37 | }, error = function(e) { 38 | print(paste('error reading file',e)) 39 | }) 40 | } 41 | 42 | fn = paste('./results/',exp,'/rmse-own-output1.txt',sep='') 43 | if (file.exists(fn)){ 44 | rmse_own_output_1 = tryCatch({ 45 | print(paste('reading ', fn)) 46 | read.table(fn,sep='') 47 | }, error = function(e) { 48 | print(paste('error reading file',e)) 49 | }) 50 | } 51 | 52 | result={} 53 | result$epoch=epoch 54 | result$rmse_own_output_0=rmse_own_output_0 55 | result$rmse_own_output_1=rmse_own_output_1 56 | 57 | return(result) 58 | } 59 | -------------------------------------------------------------------------------- /R/processruns.R: -------------------------------------------------------------------------------- 1 | processruns <- function(exp, mode, runnrlist, binsize_nrpoints, windowsize, folder, required_frac_available_timepoints = .8){ 2 | deltarunnr=max(0,1-min(runnrlist)) 3 | print(paste('deltarunnr',deltarunnr)) 4 | 5 | if(F){ 6 | exp='xc810t5' 7 | mode='train' 8 | runnrlist=1 9 | binsize_nrpoints=10000 10 | runnr=1 11 | required_frac_available_timepoints=.9 12 | windowsize=100 13 | folder = '~/code/digits/rnn' 14 | 15 | } 16 | 17 | library(zoo) 18 | 19 | result=NULL 20 | nmat=NULL 21 | meanmat=NULL 22 | M2mat=NULL 23 | nrpoints = NULL 24 | tablist=list() 25 | for (runnr in runnrlist){ 26 | print(paste('r',runnr)) 27 | flush.console() 28 | 29 | print('before process') 30 | tab=process(paste(exp,'r',runnr,sep=''), mode, folder ) 31 | print('after process') 32 | 33 | 34 | tablist[[runnr+deltarunnr]]=tab$epoch 35 | if (is.null(result)){ 36 | result=tab 37 | 38 | e1=as.matrix(result$epoch) 39 | nrmis = apply(is.na(e1),1,sum) 40 | sel = which(nrmis==0) 41 | e1 = e1[sel,] 42 | 43 | xmat = result$epoch[ sel, "totnrpoints_trained" ] 44 | maxx = max( xmat, na.rm=T ) 45 | xrange = seq( 0, maxx, binsize_nrpoints ) 46 | nrbins = length( xrange ) 47 | print(paste('maxx',maxx,'binsize_nrpoints',binsize_nrpoints,'nrbins',nrbins)) 48 | 49 | e1r = nrbins 50 | e1c = dim(e1)[2] 51 | print(paste('dims',e1r,e1c)) 52 | e1new = matrix( 0, e1r, e1c ) 53 | 54 | for( c in 1:e1c ){ 55 | y = e1[ , c ] 56 | if ( is.factor( y ) ) 57 | { 58 | y = levels( y )[ y ] 59 | } 60 | y = as.numeric( y ) 61 | 62 | z = rollmean(zoo(y), k = windowsize, fill = "extend", partial = T ) 63 | zx = rollmean(zoo(xmat), k = windowsize, fill = "extend", partial = T ) 64 | z = approx( zx, z, xout = xrange )$y 65 | 66 | e1new[ , c ] = z 
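# Note: the two rollmean calls above smooth both the measured column (y) and its x-axis
# (the totnrpoints_trained column) with a centered window of width windowsize, and approx()
# then resamples the smoothed series onto the common grid xrange (steps of binsize_nrpoints,
# e.g. 0, 10000, 20000, ... for binsize_nrpoints = 10000), so that runs logged at different
# numbers of training points can be averaged element-wise further below.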
67 | } 68 | 69 | nrmis = apply(is.na(e1new),1,sum) 70 | sel = which(nrmis==0) 71 | 72 | e1new = e1new[sel,] 73 | result$epoch=e1new 74 | 75 | nmat = (0 * e1new) + 1 #it 1 of https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm 76 | delta= e1new 77 | meanmat=delta 78 | M2mat = e1new*e1new 79 | 80 | print(paste('dims nrpoints:',dim(result$epoch)[1])) 81 | nrpointsepoch = c( matrix( 1, dim( result$epoch)[1],1), matrix( 0, dim( result$epoch )[1],1) ) 82 | 83 | 84 | } 85 | else{ 86 | 87 | e1=as.matrix(result$epoch) 88 | e2=as.matrix(tab$epoch) 89 | 90 | nrmis = apply( is.na( e2 ), 1, sum ) 91 | sel = which( nrmis == 0 ) 92 | e2 = e2[ sel, ] 93 | 94 | e2r = nrbins 95 | e2c = dim(e2)[2] 96 | 97 | e1r = dim(e1)[1] 98 | e1c = dim(e1)[2] 99 | 100 | e2new=matrix(0,e2r,e2c) 101 | xmat = tab$epoch[ sel , "totnrpoints_trained" ] 102 | 103 | for( c in 1:e2c ){ 104 | y = e2[ , c ] 105 | if ( is.factor( y ) ) 106 | { 107 | y = levels( y )[ y ] 108 | } 109 | y = as.numeric( y ) 110 | 111 | z = rollmean(zoo(y), k = windowsize, fill = "extend", partial = T ) 112 | zx = rollmean(zoo(xmat), k = windowsize, fill = "extend", partial = T ) 113 | z = approx( zx, z, xout = xrange )$y 114 | 115 | e2new[ , c ] = z 116 | } 117 | 118 | nrmis = apply( is.na( e2new ), 1,sum ) 119 | sel = which( nrmis == 0 ) 120 | e2new = e2new[ sel, ] 121 | 122 | e2 = e2new 123 | e2r = dim( e2 )[1] 124 | e2c = dim( e2 )[2] 125 | 126 | res = matrix( 0, max( e1r, e2r ), max( e1c, e2c )) 127 | res[1:e1r,1:e1c]=e1 128 | res[1:e2r,1:e2c]=res[1:e2r,1:e2c]+e2 129 | 130 | 131 | nmatnew=matrix(0,max(e1r,e2r),max(e1c,e2c)) 132 | nmatnew[1:e1r,1:e1c]=nmat 133 | nmatnew[1:e2r,1:e2c]=nmatnew[1:e2r,1:e2c]+1 134 | nmat=nmatnew 135 | 136 | meanmatnew=matrix(0,max(e1r,e2r),max(e1c,e2c)) 137 | meanmatnew[1:e1r,1:e1c]=meanmat 138 | 139 | meanmat=meanmatnew 140 | 141 | delta = e2 - meanmat[1:e2r,1:e2c] 142 | 143 | meanmat[1:e2r,1:e2c] = meanmat[1:e2r,1:e2c] + delta / nmat[1:e2r,1:e2c] 144 | 145 | M2matnew=matrix(0,max(e1r,e2r),max(e1c,e2c)) 146 | M2matnew[1:e1r,1:e1c]=M2mat 147 | M2matnew[1:e2r,1:e2c]=M2matnew[1:e2r,1:e2c]+delta*(e2-meanmat[1:e2r,1:e2c]) 148 | M2mat=M2matnew 149 | 150 | nrpointsepoch[1:e2r] = nrpointsepoch[1:e2r] + 1 151 | result$epoch = res 152 | 153 | } 154 | } 155 | 156 | nrr=dim(result$epoch)[1] 157 | nrc=dim(result$epoch)[2] 158 | nrp=matrix(0,nrr,nrc) 159 | for(i in 1:nrc){ 160 | nrp[,i]=nrpointsepoch[1:nrr] 161 | } 162 | 163 | sel=which(nrpointsepoch >= required_frac_available_timepoints * length(runnrlist)) 164 | nrpointsepoch=nrpointsepoch[sel] 165 | 166 | avg = result$epoch[sel,] / nrp[sel,] 167 | colnames(avg)=colnames(tab$epoch) 168 | 169 | nmat[is.na(nmat)]=0 170 | sel = nmat >= 2 171 | std = NA * meanmat 172 | std[sel]=sqrt(M2mat[sel]/(nmat[sel]-1)) 173 | colnames(std) = colnames(avg) 174 | 175 | result={} 176 | result$avg=avg 177 | result$std=std 178 | result$tabs=tablist 179 | result$nrpoints = nrpointsepoch 180 | return(result) 181 | } 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Incremental-sequence-learning 2 | Implementation of the Incremental Sequence Learning algorithms described in the Incremental Sequence Learning article. 3 | 4 | #Requirements 5 | Python 3.5 6 | 7 | Tensorflow 0.9 8 | 9 | 10 | #Getting started 11 | Parameter files for the first 3 experiments described in the article are available as exp/exp1a..d, exp/exp2a..d, and exp/exp3a..d. 
The a, b, c, and d variants represent the four different configurations compared in the article. 12 | 13 | To start a run for experiment 1a, use: 14 | 15 | ./runrnn exp1a --runnr 1 16 | 17 | 18 | #Data 19 | This project makes use of the MNIST stroke sequence data set, available here: 20 | 21 | https://github.com/edwin-de-jong/mnist-digits-stroke-sequence-data/wiki/MNIST-digits-stroke-sequence-data 22 | 23 | 24 | #Results 25 | 26 | I have included the R scripts used to extract results from the output files. To process the results, you can use: 27 | 28 | source('R/process.R') 29 | 30 | source('R/processruns.R') 31 | 32 | binsize = 1000 33 | 34 | requiredfraction = .9 #fraction of the files required to be available for reporting output 35 | 36 | windowsize = 1 37 | 38 | folder = '~/code/digits/rnn' 39 | 40 | exp1atrain = processruns( 'exp1a', 'train', 1, binsize, windowsize, folder, requiredfraction ) 41 | 42 | exp1atest = processruns( 'exp1a', 'test', 1, binsize, windowsize, folder, requiredfraction ) 43 | 44 | #Acknowledgements 45 | 46 | The network architecture used in this work is based on the article [Generating Sequences With Recurrent Neural Networks](https://arxiv.org/pdf/1308.0850v5.pdf) by Alex Graves. 47 | 48 | The implementation is based on the [write-rnn-tensorflow](https://github.com/hardmaru/write-rnn-tensorflow) implementation by [hardmaru](https://github.com/hardmaru), which in turn is based on 49 | the [char-rnn-tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow) implementation by [sherjilozair](https://github.com/sherjilozair). See the blog post [Handwriting Generation Demo in TensorFlow](http://blog.otoro.net/2015/12/12/handwriting-generation-demo-in-tensorflow/). 50 | 51 | -------------------------------------------------------------------------------- /exp/exp1a: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp1b: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate
0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 1 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp1c: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 1 38 | --curnrdigits 1 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp1d: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 10 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 1 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp2a: 
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp2b: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 1 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp2c: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | 
--useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 1 38 | --curnrdigits 1 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp2d: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 10 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 1 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --stat_windowsize_nrsequences 100 42 | " 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /exp/exp3a: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --model ffnn 42 | --stat_windowsize_nrsequences 1000 43 | " 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /exp/exp3b: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 
1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 1 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --model ffnn 42 | --stat_windowsize_nrsequences 1000 43 | " 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /exp/exp3c: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 5000 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 0 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 1 38 | --curnrdigits 1 39 | --test_every_nrbatches 10 40 | --maxnrpoints 10000000 41 | --model ffnn 42 | --stat_windowsize_nrsequences 1000 43 | " 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /exp/exp3d: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | params=" 3 | --num_mixture 17 4 | --rnn_size 200 5 | --num_layers 2 6 | --nrseq_per_batch 50 7 | --nrpoints_per_batch 2000 8 | --nrinputfiles_train 10000 9 | --nrinputfiles_test 5000 10 | --max_seq_length 100 11 | --num_epochs 100000 12 | --eval_every 1 13 | --report_every 100 14 | --save_every_nrbatches 500 15 | --keep_prob 1 16 | --learning_rate 0.0025 17 | --decay_rate .99995 18 | --train_on_own_output_method 0 19 | --sample_from_output 0 20 | --regularization_factor 0.25 21 | --l2_weight_regularization 0 22 | --max_weight_regularization 1 23 | --nrClassOutputVars 10 24 | --discard_classvar_inputs 1 25 | --randomizeSequenceOrder 1 26 | --classweightfactor 0 27 | --curnrtrainexamples 10 28 | --correctfrac_threshold_inc_nrtrainex 0.8 29 | --useStrokeOutputVars 1 30 | --useClassInputVars 0 31 | --useStrokeLoss 1 32 | --usernn 1 33 | --incremental_nr_trainexamples 1 34 | --incremental_seq_length 0 35 | --current_seq_length 2 36 | --threshold_rmse_stroke 4 37 | --incremental_nr_digits 0 38 | --curnrdigits 10 39 | --test_every_nrbatches 10 40 | --maxnrpoints 
10000000 41 | --model ffnn 42 | --stat_windowsize_nrsequences 1000 43 | " 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | 5 | def get_pi_idx( x, pdf ): 6 | N = pdf.size 7 | accumulate = 0 8 | for i in range( 0, N ): 9 | accumulate += pdf[ i ] 10 | if ( accumulate >= x ): 11 | return i 12 | print( 'error with sampling ensemble' ) 13 | return -1 14 | 15 | class Model( ): 16 | 17 | def get_classvars( self, args, output ): 18 | z = output 19 | 20 | last = args.nroutputvars_raw - args.nrClassOutputVars 21 | 22 | classvars = tf.zeros( 1, dtype = tf.float32, name = None ) 23 | classpred = tf.zeros( 1, dtype = tf.float32, name = None ) 24 | 25 | if args.nrClassOutputVars > 0: 26 | classvars = z[ :, last: ] 27 | classpred = tf.nn.softmax( classvars ) 28 | 29 | return [ classvars, classpred ] 30 | 31 | # below is where we need to do MDN splitting of distribution params 32 | def get_mixture_coef( self, args, output ): 33 | # returns the tf slices containing mdn dist params 34 | # ie, eq 18 -> 23 of http://arxiv.org/abs/1308.0850 35 | z = output 36 | 37 | #get the remaining parameters 38 | last = args.nroutputvars_raw - args.nrClassOutputVars 39 | 40 | z_eos = z[ :, 0 ] 41 | z_eos = tf.sigmoid( z_eos ) #eos: sigmoid, eq 18 42 | 43 | z_eod = z[ :, 1 ] 44 | z_eod = tf.sigmoid( z_eod ) #eod: sigmoid 45 | 46 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split( z[ :, 2:last ], 6, 1 ) #eq 20: mu1, mu2: no transformation required 47 | 48 | # process output z's into MDN parameters 49 | 50 | # softmax all the pi's: 51 | max_pi = tf.reduce_max( z_pi, 1, keep_dims = True ) 52 | z_pi = tf.subtract( z_pi, max_pi ) #EdJ: subtract max pi for numerical stabilization 53 | 54 | z_pi = tf.exp( z_pi ) #eq 19 55 | normalize_pi = tf.reciprocal( tf.reduce_sum( z_pi, 1, keep_dims = True ) ) 56 | z_pi = tf.multiply( normalize_pi, z_pi ) #19 57 | 58 | # exponentiate the sigmas and also make corr between -1 and 1. 
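# (Graves 2013, eq. 21-22: sigma = exp(sigma_hat), rho = tanh(rho_hat); the 0.95 scaling below keeps |rho| strictly below 1, so the 1 - rho^2 term in tf_2d_normal stays positive and its denominator cannot become zero)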
59 | z_sigma1 = tf.exp( z_sigma1 ) #eq 21 60 | z_sigma2 = tf.exp( z_sigma2 ) 61 | z_corr_tanh = tf.tanh( z_corr ) #eq 22 62 | z_corr_tanh = .95 * z_corr_tanh #avoid -1 and 1 63 | 64 | z_corr_tanh_adj = z_corr_tanh 65 | 66 | return [ z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr_tanh_adj, z_eos, z_eod ] 67 | 68 | def sample_gaussian_2d( self, mu1, mu2, s1, s2, rho ): 69 | mean = [ mu1, mu2 ] 70 | cov = [ [ s1 * s1, rho * s1 * s2 ], [ rho * s1 * s2, s2 * s2 ] ] 71 | x = np.random.multivariate_normal( mean, cov, 1 ) 72 | return x[ 0 ][ 0 ], x[ 0 ][ 1 ] 73 | 74 | def tf_2d_normal( self, x1, x2, mu1, mu2, s1, s2, rho ): 75 | # eq # 24 and 25 of http://arxiv.org/abs/1308.0850 76 | #dims: mu1, mu2: batch_nrpoints x nrmixtures 77 | norm1 = tf.subtract( x1, mu1 ) #batch_nrpoints x nrmixtures 78 | norm2 = tf.subtract( x2, mu2 ) 79 | s1s2 = tf.multiply( s1, s2 ) 80 | normprod = tf.multiply( norm1, norm2 ) #batch_nrpoints x nrmixtures; here x1 and x2 are combined 81 | 82 | epsilon = 1e-10 83 | self.z = tf.square( tf.div( norm1, s1 + epsilon ) ) + tf.square( tf.div( norm2, s2 + epsilon ) ) - 2 * tf.div( tf.multiply( rho, normprod ), s1s2 + epsilon ) #batch_nrpoints x nrmixtures 84 | negRho = 1 - tf.square( rho ) #EdJ: Problem: can become 0 if corr is 1 --> denom becomes zero --> nan result, resolved by multiplying z_corr_tanh with 0.95 85 | result5 = tf.exp( tf.div( - self.z, 2 * negRho ) ) 86 | 87 | self.denom = 2 * np.pi * tf.multiply( s1s2, tf.sqrt( negRho ) ) 88 | self.result6 = tf.div( result5, self.denom ) 89 | 90 | return self.result6 #still batch_nrpoints x nrmixtures 91 | 92 | def getRegularizationTerm( self, args ): 93 | 94 | trainablevars = tf.trainable_variables( ) 95 | self.weights = [ ] 96 | weightsum = tf.zeros( 1, dtype = tf.float32, name = None ) 97 | nrweights = tf.zeros( 1, dtype = tf.int32, name = None ) 98 | self.maxabsweight = tf.zeros( 1, dtype = tf.float32, name = None ) 99 | 100 | for var in trainablevars: 101 | isBias = var.name.find( "Bias" ) >= 0 102 | if isBias: 103 | print ( "Found trainable variable: ", var.name ) 104 | else: 105 | print ( "Found trainable variable: ", var.name , "; adding to regularization term" ) 106 | self.weights.append( var ) 107 | weightsum = weightsum + tf.reduce_sum( tf.abs( var ) ) 108 | nrweights = tf.add( nrweights , tf.reduce_prod( tf.shape( var ) ) ) 109 | maxval = tf.reduce_max( tf.abs( var ) ) 110 | self.maxabsweight = tf.maximum( maxval, self.maxabsweight ) 111 | self.avgweight = weightsum / tf.to_float( nrweights ) 112 | 113 | regularization_term = tf.zeros( 1, dtype = tf.float32, name = None ) 114 | nrvalues = tf.zeros( 1, dtype = tf.int32, name = None ) 115 | for weight in self.weights: 116 | if args.l2_weight_regularization: 117 | regularization_term = regularization_term + tf.nn.l2_loss( weight ) 118 | nrvalues = tf.add( nrvalues, tf.reduce_prod( tf.shape( weight ) ) ) 119 | if args.max_weight_regularization: 120 | regularization_term = tf.maximum( regularization_term, tf.reduce_max( weight ) ) 121 | if args.l2_weight_regularization: 122 | regularization_term = tf.div( regularization_term, nrvalues ) 123 | return args.regularization_factor * regularization_term 124 | 125 | def get_stroke_loss( self, args, z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_eos, z_eod, x1_data, x2_data, eos_data, eod_data, targetdata_classvars ): 126 | 127 | self.mask = tf.sign( tf.abs( tf.reduce_max( targetdata_classvars, reduction_indices = 1 ) ) ) 128 | self.result0 = tf.squeeze( self.tf_2d_normal( x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2, 
z_corr ) ) #batch_nrpoints x nrmixtures 129 | 130 | # implementing eq # 26 of http://arxiv.org/abs/1308.0850 131 | epsilon = 1e-10 132 | self.result1 = tf.multiply( self.result0, z_pi ) 133 | self.lossvector = self.result1 134 | self.result1 = tf.reduce_sum( self.result1, 1, keep_dims = True ) #batch_nrpoints x 1 135 | self.result1 = tf.squeeze( -tf.log( self.result1 + epsilon ) ) # at the beginning, some errors are exactly zero. 136 | self.result1_nomask = self.result1 137 | 138 | eos_data = tf.squeeze( eos_data ) 139 | self.z_eos = z_eos 140 | self.eos_data = eos_data 141 | self.result2 = tf.multiply( z_eos, eos_data ) + tf.multiply( 1 - z_eos, 1 - eos_data ) #eq 26 rightmost part 142 | self.result2 = -tf.log( self.result2 + epsilon ) 143 | 144 | eod_data = tf.squeeze( eod_data ) 145 | self.result3 = tf.multiply( z_eod, eod_data ) + tf.multiply( 1 - z_eod, 1 - eod_data ) #analogous for eod 146 | self.result3 = -tf.log( self.result3 + epsilon ) 147 | 148 | self.result = self.result1 + self.result2 + self.result3 149 | 150 | self.result_before_mask = self.result 151 | self.result *= self.mask #checked EdJ Oct 15: correctly applies mask to include loss for used points only, depending on current sequence length 152 | 153 | self.lossnrpoints = tf.reduce_sum( self.mask ) 154 | 155 | stroke_loss = tf.reduce_sum( self.result ) / self.lossnrpoints 156 | return stroke_loss 157 | 158 | def get_class_loss( self, args, z_classvars, z_classpred, targetdata_classvars ): 159 | self.mask = tf.sign( tf.abs( tf.reduce_max( targetdata_classvars, reduction_indices = 1 ) ) ) 160 | 161 | self.result4 = tf.zeros( 1, dtype = tf.float32, name = None ) 162 | if args.nrClassOutputVars > 0 and args.classweightfactor > 0: 163 | self.crossentropy = tf.nn.softmax_cross_entropy_with_logits( z_classvars, targetdata_classvars ) 164 | self.result4 = args.classweightfactor * self.crossentropy 165 | self.result4 = tf.multiply( self.mask, self.result4 ) 166 | self.targetdata_classvars = targetdata_classvars 167 | 168 | self.result = self.result4 169 | 170 | self.result_before_mask = self.result 171 | self.result *= self.mask #checked EdJ Sept 2: correctly only measures loss up to last point of actual sequence. 
172 | self.lossvector = self.result 173 | 174 | self.lossnrpoints = tf.reduce_sum( self.mask ) 175 | 176 | classloss = tf.reduce_sum( self.result ) / self.lossnrpoints 177 | return classloss 178 | 179 | def __init__( self, args, trainpredictmode, infer = False, nrinputvars_network = 1, nroutputvars_raw = 1, nrtargetvars = 1, nrauxoutputvars = 0, rangemin = 0, rangelen = 1, maxdigitlength_nrpoints = 1 ): 180 | self.args = args 181 | 182 | self.result0 = tf.zeros( 1, dtype = tf.float32, name = None ) 183 | self.result1 = tf.zeros( 1, dtype = tf.float32, name = None ) 184 | self.result1_nomask = tf.zeros( 1, dtype = tf.float32, name = None ) 185 | self.result2 = tf.zeros( 1, dtype = tf.float32, name = None ) 186 | self.result3 = tf.zeros( 1, dtype = tf.float32, name = None ) 187 | self.result4 = tf.zeros( 1, dtype = tf.float32, name = None ) 188 | self.crossentropy = tf.zeros( 1, dtype = tf.float32, name = None ) 189 | self.lossvector = tf.zeros( 1, dtype = tf.float32, name = None ) 190 | self.targetdata_classvars = tf.zeros( 1, dtype = tf.float32, name = None ) 191 | 192 | self.nrinputvars_network = nrinputvars_network 193 | self.nroutputvars_raw = nroutputvars_raw 194 | self.nrauxoutputvars = nrauxoutputvars 195 | self.maxdigitlength_nrpoints = maxdigitlength_nrpoints 196 | self.max_seq_length = args.max_seq_length 197 | self.seq_length = min( self.max_seq_length, self.maxdigitlength_nrpoints ) 198 | self.regularization_term = tf.zeros( 1, dtype = tf.float32, name = None ) 199 | o_classvars = tf.zeros( 2, dtype = tf.float32, name = None ) 200 | o_classpred = tf.zeros( 2, dtype = tf.float32, name = None ) 201 | 202 | if infer: 203 | self.seq_length = 2 #will be reduced by 1 204 | 205 | self.batch_size_ph = tf.placeholder( dtype = tf.int32 ) 206 | self.seq_length_ph = tf.placeholder( dtype = tf.int32 ) 207 | 208 | if args.model == 'rnn': 209 | cell_fn = tf.nn.rnn_cell.BasicRNNCell 210 | elif args.model == 'gru': 211 | cell_fn = tf.nn.rnn_cell.GRUCell 212 | elif args.model == 'basiclstm': 213 | cell_fn = tf.nn.rnn_cell.BasicLSTMCell 214 | elif args.model == 'lstm': 215 | cell_fn = tf.nn.rnn_cell.LSTMCell 216 | elif args.model == 'ffnn': 217 | cell_fn = 0 218 | else: 219 | raise Exception( "model type not supported: {}".format( args.model ) ) 220 | 221 | useInitializers = False 222 | if hasattr( args, 'useInitializers' ): 223 | useInitializers = args.useInitializers 224 | 225 | if args.model == 'ffnn': #regular variables, no rnn 226 | nrinputs = nrinputvars_network 227 | nrhidden = args.rnn_size 228 | nroutputs = self.nroutputvars_raw 229 | 230 | if useInitializers: 231 | self.init_op_weights_ffnn = tf.random_normal( [ nrinputs, nrhidden ], dtype = tf.float32, name = None, seed = random.random( ) ) 232 | init_op_bias_ffnn = tf.zeros( [ nrhidden ], dtype = tf.float32, name = None ) 233 | if args.num_layers > 0: 234 | if useInitializers: 235 | weightsh1 = tf.get_variable( "weightsh1", initializer = self.init_op_weights_ffnn ) 236 | biasesh1 = tf.get_variable( "biasesh1", initializer = init_op_bias_ffnn ) 237 | else: 238 | weightsh1 = tf.get_variable( "weightsh1", [ nrinputs, nrhidden ] ) 239 | biasesh1 = tf.get_variable( "biasesh1", [ nrhidden ] ) 240 | 241 | if args.num_layers > 1: 242 | if useInitializers: 243 | weightsh2 = tf.get_variable( "weightsh2", initializer = self.init_op_weights_ffnn ) 244 | biasesh2 = tf.get_variable( "biasesh2", initializer = init_op_bias_ffnn ) 245 | else: 246 | weightsh2 = tf.get_variable( "weightsh2", [ nrhidden, nrhidden ] ) 247 | biasesh2 = tf.get_variable( 
"biasesh2", [ nrhidden ] ) 248 | layers = tf.zeros( [ 1 ] ) 249 | 250 | else: 251 | if args.model == 'lstm': 252 | layers = cell_fn( args.rnn_size, use_peepholes = True ) 253 | else: 254 | layers = cell_fn( args.rnn_size ) 255 | 256 | if args.num_layers > 0: 257 | 258 | rnn_layers= [] 259 | for li in range( args.num_layers ): 260 | if args.model == 'lstm': 261 | layer = cell_fn(args.rnn_size, use_peepholes=True) 262 | else: 263 | layer = cell_fn(args.rnn_size) 264 | 265 | rnn_layers.append(layer) 266 | layers = tf.contrib.rnn.MultiRNNCell(cells=rnn_layers, state_is_tuple=True) 267 | 268 | else: 269 | if args.model == 'lstm': 270 | layers = cell_fn(args.rnn_size, use_peepholes=True) 271 | else: 272 | layers = cell_fn(args.rnn_size) 273 | 274 | if ( infer == False and args.keep_prob < 1 ): # training mode 275 | layers = tf.nn.rnn_cell.DropoutWrapper( layers, output_keep_prob = args.keep_prob ) 276 | 277 | self.layers = layers 278 | 279 | if infer: 280 | self.input_data = tf.placeholder( dtype = tf.float32, shape = [ None, 1, nrinputvars_network ] ) 281 | self.target_data = tf.placeholder( dtype = tf.float32, shape = [ None, 1, nrtargetvars ] ) 282 | else: 283 | self.input_data = tf.placeholder( dtype = tf.float32, shape = [ None, self.seq_length - 1, nrinputvars_network ] ) 284 | self.target_data = tf.placeholder( dtype = tf.float32, shape = [ None, self.seq_length - 1, nrtargetvars ] ) 285 | self.batch_size_ph = tf.placeholder( tf.int32, [] ) 286 | 287 | if args.model == "ffnn": 288 | self.initial_state = tf.zeros( [ 1 ] ) 289 | else: 290 | self.initial_state = state = layers.zero_state( batch_size = self.batch_size_ph, dtype = tf.float32 ) 291 | 292 | seqlen = self.seq_length - 1 293 | self.inputdatasize = tf.shape( self.input_data ) 294 | inputs = tf.split( self.input_data, seqlen, 1) 295 | self.inputssize1 = tf.shape( inputs ) 296 | inputs = [ tf.squeeze( input_, [ 1 ] ) for input_ in inputs ] 297 | self.inputssize2 = tf.shape( inputs ) 298 | 299 | if useInitializers: 300 | self.init_op_weights = tf.random_normal( [ args.rnn_size, self.nroutputvars_raw ], dtype = tf.float32, name = None, seed = random.random( ) ) 301 | init_op_bias = tf.zeros( [ self.nroutputvars_raw ], dtype = tf.float32, name = None ) 302 | with tf.variable_scope( trainpredictmode ): 303 | if useInitializers: 304 | outputWeight = tf.get_variable( "outputWeight", initializer = self.init_op_weights ) 305 | outputBias = tf.get_variable( "outputBias", initializer = init_op_bias ) 306 | else: 307 | outputWeight = tf.get_variable( "outputWeight", [ args.rnn_size, self.nroutputvars_raw ] ) 308 | outputBias = tf.get_variable( "outputBias", [ self.nroutputvars_raw ] ) 309 | self.outputWeight = outputWeight 310 | self.outputBias = outputBias 311 | 312 | 313 | if args.model == 'ffnn': #regular variables, no rnn 314 | print( 'nrinputvars_network', nrinputvars_network ) 315 | inputs_2d = tf.reshape( inputs, [ -1, nrinputvars_network ] ) # make 2d: ( nrseq * seq_length ) x nrinputvars_network 316 | 317 | if args.num_layers > 0: 318 | hidden1 = tf.nn.relu( tf.matmul( inputs_2d, weightsh1 ) + biasesh1 ) 319 | output = hidden1 320 | if args.num_layers > 1: 321 | hidden2 = tf.nn.relu( tf.matmul( output, weightsh2 ) + biasesh2 ) 322 | output = hidden2 323 | last_state = tf.zeros( [ 1 ] ) 324 | elif args.usernn: #See https://www.tensorflow.org/versions/r0.10/tutorials/recurrent/index.html 325 | output, last_state = tf.contrib.rnn.static_rnn( layers, inputs, initial_state = self.initial_state, scope = trainpredictmode ) 326 | else: 327 | 
output, last_state = tf.nn.seq2seq.rnn_decoder( inputs, self.initial_state, layers, loop_function = None, scope = trainpredictmode ) 328 | 329 | output = tf.reshape( tf.concat( output, 1 ), [ -1, args.rnn_size ] ) 330 | output = tf.nn.xw_plus_b( output, outputWeight, outputBias ) 331 | 332 | self.num_mixture = args.num_mixture 333 | self.output = output 334 | self.final_state = last_state 335 | 336 | # reshape target data so that it is compatible with prediction shape 337 | flat_target_data = tf.reshape( self.target_data, [ -1, nrtargetvars ] ) # make 2d: ( nrseq * seq_length ) x nrinputvars_network 338 | targetdata_classvars = flat_target_data[ :, :self.nrauxoutputvars ] 339 | [ x1_data, x2_data, eos_data, eod_data ] = tf.split( flat_target_data[ :, self.nrauxoutputvars: ], 4, 1 ) #classvars dx dy eos eod 340 | 341 | loss = tf.zeros( 1, dtype = tf.float32, name = None ) 342 | if args.nrClassOutputVars > 0 and args.classweightfactor > 0: 343 | [ o_classvars, o_classpred ] = self.get_classvars( args, output ) #does same as when strokevars are used, but skips extracting those 344 | classloss = self.get_class_loss( args, o_classvars, o_classpred, targetdata_classvars ) 345 | loss += classloss 346 | if args.useStrokeOutputVars and args.useStrokeLoss: 347 | [ o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod ] = self.get_mixture_coef( args, output ) 348 | 349 | self.pi = o_pi 350 | self.mu1 = o_mu1 351 | self.mu2 = o_mu2 352 | self.sigma1 = o_sigma1 353 | self.sigma2 = o_sigma2 354 | self.corr = o_corr 355 | self.eos = o_eos 356 | self.eod = o_eod 357 | strokeloss = self.get_stroke_loss( args, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod, x1_data, x2_data, eos_data, eod_data, targetdata_classvars ) 358 | loss += strokeloss 359 | 360 | self.loss_plain = loss 361 | self.regularization_term = self.getRegularizationTerm( args ) 362 | loss += self.regularization_term 363 | self.loss_total = loss 364 | 365 | self.classvars = o_classvars 366 | self.classpred = o_classpred 367 | 368 | self.learningratevar = tf.Variable( 0.0, trainable = False ) 369 | self.learningrate_ph = tf.placeholder( dtype = tf.float32 ) #placeholder to feed new values for the learning rate to avoid adding an assignment op for each change 370 | self.learningrateop = tf.assign( self.learningratevar, self.learningrate_ph ) 371 | 372 | tvars = tf.trainable_variables( ) 373 | 374 | with tf.variable_scope( "gradient" ): 375 | self.gradient_org = tf.gradients( loss, tvars ) 376 | self.gradient_clipped, _ = tf.clip_by_global_norm( self.gradient_org, args.grad_clip ) 377 | optimizer = tf.train.AdamOptimizer( self.learningratevar, epsilon = 1e-05 ) 378 | 379 | self.train_op = optimizer.apply_gradients( zip( self.gradient_clipped, tvars ) ) 380 | 381 | def sample( self, sess, dataloader, args, nrbatches, use_own_output_as_input, outputdir ): #to see how network behaves given perfect prediction by itself on each previous step, feed input so that we can see output on each step 382 | print( 'sample' ) 383 | 384 | fn = outputdir +"output.txt" 385 | outputfile = open( fn, "w" ) 386 | 387 | prev_state = sess.run( self.layers.zero_state( 1, tf.float32 ) ) 388 | 389 | nrpointsperseq = args.maxdigitlength_nrpoints 390 | 391 | nrpoints = int ( nrbatches * args.nrseq_per_batch * nrpointsperseq ) 392 | nrsequenceinputs = 4 #dx dy eos eod 393 | strokes = np.zeros( ( nrpoints, nrsequenceinputs ), dtype = np.float32 ) 394 | mixture_params = [ ] 395 | 396 | dataloader.reset_batch_pointer( args ) 397 | state = sess.run( 
self.initial_state, feed_dict = { self.batch_size_ph: args.nrseq_per_batch, self.seq_length_ph: self.seq_length } ) 398 | 399 | strokeindex = 0 400 | sequencenr = 0 401 | nrseq = dataloader.curnrexamples 402 | rmse_strokes = np.zeros( ( nrseq ), dtype = np.float32 ) 403 | rmse_classes = np.zeros( ( nrseq ), dtype = np.float32 ) 404 | correctfracs = np.zeros( ( nrseq ), dtype = np.float32 ) 405 | 406 | mode = "test" 407 | nrbatches = 100 408 | sample_nrseq = args.nrseq_per_batch 409 | 410 | if use_own_output_as_input: 411 | sample_nrseq = 500 412 | 413 | for batchnr in range( nrbatches ): 414 | print( 'batch', batchnr ) 415 | x, y, sequence_index = dataloader.next_batch( args, args.seq_length ) 416 | 417 | for batch_seqnr in range( sample_nrseq ): 418 | print( 'batch', batchnr, 'seq', batch_seqnr, 'of', sample_nrseq, 'filenr', sequence_index[ batch_seqnr ] ) 419 | xseq = x[ batch_seqnr ] 420 | yseq = y[ batch_seqnr ] 421 | 422 | nrpoints = min( len( xseq ), nrpointsperseq ) 423 | outputmat = np.zeros( ( nrpoints, args.nroutputvars_final ), dtype = np.float32 ) 424 | 425 | if use_own_output_as_input: 426 | maxnrrows = 100 427 | else: 428 | maxnrrows = nrpoints 429 | rownr = 0 430 | cont = True 431 | while cont: 432 | if ( not use_own_output_as_input ) or ( rownr == 0 ): 433 | inputrow = xseq[ rownr, : ] 434 | print( 'getting row', rownr, ' of inputdata:', inputrow ) 435 | else: 436 | inputrow = [ next_x1, next_x2, eos, eod ] 437 | 438 | inputrow_scaledback = np.copy( inputrow ) 439 | inputrow_scaledback[ 0:2 ] *= args.rangelen 440 | inputrow_scaledback[ 0:2 ] += args.rangemin 441 | 442 | print( 'feeding inputrow, scaled back:', inputrow_scaledback ) 443 | feed = {self.input_data: [ [ inputrow ] ], self.initial_state:prev_state} 444 | 445 | [ o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod, o_classvars, o_classpred, next_state, output ] = sess.run( [ self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.eos, self.eod, self.classvars, self.classpred, self.final_state, self.output ], feed ) 446 | prev_state = next_state 447 | 448 | if args.nrClassOutputVars > 0: 449 | classvars = o_classvars[ 0, ] 450 | classpred = o_classpred[ 0, ] 451 | 452 | batch_pointnr = 0 453 | 454 | targetmat = np.copy( yseq ) 455 | 456 | absrowsum = np.absolute( targetmat ).sum( 1 ) 457 | mask = np.sign( absrowsum ) 458 | nzrows = np.nonzero( mask ) 459 | nzrows = nzrows[ 0 ] 460 | if len( nzrows )>0: 461 | last = len( nzrows ) - 1 462 | nrtargetrows = nzrows[ last ] + 1 463 | else: 464 | nrtargetrows = 0 465 | print( 'found nrtargetrows:', nrtargetrows ) 466 | 467 | outputmat = np.zeros( ( 1, args.nroutputvars_final ), dtype = np.float32 ) 468 | outputmat_sampled = np.zeros( ( 1, args.nroutputvars_final ), dtype = np.float32 ) 469 | 470 | if args.useStrokeOutputVars: 471 | if args.nrClassOutputVars > 0 and args.classweightfactor > 0: 472 | outputmat[ 0, :args.nrClassOutputVars ] = o_classpred[ batch_pointnr, ] 473 | outputmat_sampled[ 0, :args.nrClassOutputVars ] = o_classpred[ batch_pointnr, ] 474 | if args.useStrokeLoss: 475 | idx = get_pi_idx( dataloader.getRandValue( ), o_pi[ batch_pointnr ] ) 476 | next_x1, next_x2 = self.sample_gaussian_2d( o_mu1[ batch_pointnr, idx ], o_mu2[ batch_pointnr, idx ], o_sigma1[ batch_pointnr, idx ], o_sigma2[ batch_pointnr, idx ], o_corr[ batch_pointnr, idx ] ) 477 | eos = 1 if dataloader.getRandValue( ) < o_eos[ batch_pointnr ] else 0 478 | eod = 1 if dataloader.getRandValue( ) < o_eod[ batch_pointnr ] else 0 479 | outputmat[ 0, 
args.nrClassOutputVars:args.nrClassOutputVars + 4 ] = [ o_mu1[ batch_pointnr, idx ], o_mu2[ batch_pointnr, idx ], o_sigma1[ batch_pointnr, idx ], o_sigma2[ batch_pointnr, idx ] ] 480 | outputmat_sampled[ 0, args.nrClassOutputVars:args.nrClassOutputVars+4 ] = [ next_x1, next_x2, eos, eod ] 481 | else: 482 | outputmat_sampled[ 0, ] = o_classpred[ batch_pointnr, ] 483 | 484 | print( 'output unscaled:', [ o_mu1[ batch_pointnr, idx ], o_mu2[ batch_pointnr, idx ], o_eos[ batch_pointnr ], o_eod[ batch_pointnr ] ] ) 485 | outputrow = np.asarray( [ next_x1, next_x2, eos, eod ] ) 486 | print( 'sampled output unscaled:', outputrow ) 487 | outputrow[ 0:2 ] *= args.rangelen 488 | outputrow[ 0:2 ] += args.rangemin 489 | print( 'sampled output scaled', outputrow ) 490 | outputfile.write( str( outputrow[ 0 ] ) + " " + str( outputrow[ 1 ] ) + " " + str( outputrow[ 2 ] ) + " " + str( outputrow[ 3 ] ) + "\n" ) 491 | 492 | if not use_own_output_as_input: 493 | stroketarget = np.copy( targetmat[ rownr, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] ) 494 | classtarget = np.copy( targetmat[ rownr, :args.nrClassOutputVars ] ) 495 | print( 'classtarget:', classtarget ) 496 | 497 | if args.useStrokeOutputVars: 498 | 499 | if args.useStrokeLoss: 500 | outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] *= args.rangelen 501 | outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] += args.rangemin 502 | outputmat[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] *= args.rangelen 503 | outputmat[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] += args.rangemin 504 | 505 | print( 'sampled outputmat_sample scaled back:' ) 506 | print( outputmat_sampled ) 507 | 508 | stroketarget *= args.rangelen 509 | stroketarget += args.rangemin 510 | 511 | err_stroke = outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ]-stroketarget 512 | 513 | print( 'prediction', mode ) 514 | print( outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] ) 515 | print( 'stroketarget', mode ) 516 | print( stroketarget ) 517 | print( 'error', mode ) 518 | print( err_stroke ) 519 | 520 | sse_stroke = ( err_stroke ** 2 ).sum( ) 521 | 522 | if args.nrClassOutputVars > 0 and not use_own_output_as_input: 523 | classindex_true = np.argmax( classtarget ) 524 | classindex_pred = np.argmax( outputmat_sampled[ 0, :args.nrClassOutputVars ] ) 525 | print( 'batch', batchnr, 'seq', batch_seqnr, 'row', rownr, "classindex_true", classindex_true, 'pred', classindex_pred ) 526 | class_logits = outputmat_sampled[ 0, :args.nrClassOutputVars ] 527 | 528 | correct = np.equal( classindex_pred, classindex_true ) 529 | 530 | print( 'output', outputmat_sampled[ :args.nrClassOutputVars ] ) 531 | last = args.nroutputvars_raw - args.nrClassOutputVars 532 | logits_str = [ str( a ) for a in class_logits ] 533 | print( "batch", batchnr, "class", classindex_true, 'pred', classindex_pred, "class_logits", " " . 
join( logits_str ) ) 534 | print( 'correct:', correct ) 535 | 536 | rownr += 1 537 | if eod or ( rownr >= maxnrrows ): 538 | cont = False 539 | prev_state = sess.run( self.layers.zero_state( 1, tf.float32 ) ) 540 | 541 | print( 'end of batch', batchnr ) 542 | print( 'done' ) #after batch for loop 543 | outputfile.close( ) 544 | 545 | 546 | 547 | -------------------------------------------------------------------------------- /results/results-folder.txt: -------------------------------------------------------------------------------- 1 | file to ensure git creates the results folder 2 | -------------------------------------------------------------------------------- /runrnn: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | runnr=$3 5 | 6 | python --version 7 | which python 8 | 9 | if [ $# -eq 0 ] 10 | then 11 | echo "No arguments supplied" 12 | else 13 | echo found param file: $1 14 | rm -rf ./results/$1r$3 15 | rm -rf ./save/$1r$3 16 | mkdir ./results/$1r$3 17 | source ./exp/$1 18 | mkdir ./save/$1r$3 19 | 20 | echo found params: $params 21 | 22 | command="python train.py $params --explabel $*" 23 | echo command: $command 24 | $command 25 | 26 | fi 27 | 28 | #python train.py --num_mixture 1 --rnn_size 1 --num_layers 1 --batch_size 1 --num_epochs 10000 --seq_length 10 --data_scale 5 --keep_prob 1 --learning_rate 0.005 $* 29 | 30 | #parameters can be adapted above, and more parameters from below can be added. 31 | #to predict: use ./runrnn --predict 1 32 | 33 | #params: 34 | #--rnn_size 35 | #--num_layers 36 | #--model 37 | #--batch_size 38 | #--seq_length 39 | #--num_epochs 40 | #--save_every 41 | #--grad_clip 42 | #--learning_rate 43 | #--decay_rate 44 | #--num_mixture 45 | #--data_scale 46 | #--keep_prob 47 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import time 5 | import os 6 | import cPickle 7 | import argparse 8 | 9 | from utils import * 10 | from model import Model 11 | import random 12 | 13 | import svgwrite 14 | from IPython.display import SVG, display 15 | 16 | # main code (not in a main function since I want to run this script in IPython as well). 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--filename', type=str, default='sample', 20 | help='filename of .svg file to output, without .svg') 21 | parser.add_argument('--sample_length', type=int, default=800, 22 | help='number of strokes to sample') 23 | parser.add_argument('--scale_factor', type=int, default=10, 24 | help='factor to scale down by for svg output. 
smaller means bigger output') 25 | sample_args = parser.parse_args() 26 | 27 | with open(os.path.join('save', 'config.pkl')) as f: 28 | saved_args = cPickle.load(f) 29 | 30 | model = Model(saved_args, True) 31 | sess = tf.InteractiveSession() 32 | saver = tf.train.Saver(tf.all_variables()) 33 | 34 | ckpt = tf.train.get_checkpoint_state('save') 35 | print "loading model: ",ckpt.model_checkpoint_path 36 | 37 | saver.restore(sess, ckpt.model_checkpoint_path) 38 | 39 | def sample_stroke(): 40 | [strokes, params] = model.sample(sess, sample_args.sample_length) 41 | draw_strokes(strokes, factor=sample_args.scale_factor, svg_filename = sample_args.filename+'.normal.svg') 42 | draw_strokes_random_color(strokes, factor=sample_args.scale_factor, svg_filename = sample_args.filename+'.color.svg') 43 | draw_strokes_random_color(strokes, factor=sample_args.scale_factor, per_stroke_mode = False, svg_filename = sample_args.filename+'.multi_color.svg') 44 | draw_strokes_eos_weighted(strokes, params, factor=sample_args.scale_factor, svg_filename = sample_args.filename+'.eos_pdf.svg') 45 | draw_strokes_pdf(strokes, params, factor=sample_args.scale_factor, svg_filename = sample_args.filename+'.pdf.svg') 46 | return [strokes, params] 47 | 48 | [strokes, params] = sample_stroke() 49 | 50 | 51 | -------------------------------------------------------------------------------- /save/save-folder.txt: -------------------------------------------------------------------------------- 1 | file to ensure git creates the save folder 2 | -------------------------------------------------------------------------------- /stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.framework import ops 4 | 5 | class Stats(): 6 | def __init__( self, args, nrseq, label ): 7 | self.nrseq = nrseq 8 | self.label = label 9 | self.stats_stroke = RMSEStat( args, nrseq, label + 'stroke' ) 10 | self.stats_correct = AvgStat( args, nrseq, label + 'correct' ) 11 | self.stats_correctfrac = AvgStat( args, nrseq, label + 'correctfrac' ) 12 | 13 | def reset( self ): 14 | self.stats_stroke.reset() 15 | self.stats_correct.reset() 16 | self.stats_correctfrac.reset() 17 | 18 | class RMSEStat(): 19 | def __init__( self, args, nrseq, label ): 20 | self.args = args 21 | self.nrseq = nrseq 22 | self.label = label 23 | self.reset() 24 | 25 | def log_sse_sequential( self, sse, nrpoints ): #no index given --> cycle through all sequences 26 | self.log_sse( self.pointer, sse, nrpoints) 27 | self.pointer += 1 28 | if ( self.pointer >= self.nrseq): 29 | self.pointer = 0 30 | 31 | def log_sse( self, sequence_index, sse, nrpoints ): #replace previous value of this specific example, so that stats always reflect all examples 32 | curval = self.sse[ sequence_index ] 33 | diff = sse - curval 34 | self.sse[ sequence_index ] = sse 35 | self.sse_sum += diff 36 | 37 | curval = self.nrpoints[ sequence_index ] 38 | diff = nrpoints - curval 39 | self.nrpoints[ sequence_index ] = nrpoints 40 | self.totnrpoints += diff 41 | 42 | def reset( self ): 43 | self.sse = np.zeros( (self.nrseq), dtype=np.float32 ) 44 | self.sse_sum = 0 45 | 46 | self.nrpoints = np.zeros( (self.nrseq), dtype=np.float32 ) 47 | self.totnrpoints = 0 48 | 49 | self.pointer = 0 50 | 51 | def rmse( self ): 52 | return np.sqrt( self.sse_sum / max( 1, self.totnrpoints ) ) 53 | 54 | class AvgStat(): 55 | def __init__( self, args, nrseq, label ): 56 | self.nrseq = nrseq 57 | self.label = label 58 | 
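# AvgStat keeps one slot per sequence and overwrites that sequence's previous value on each new log_value call, so average() always reflects the most recent value and point count for every sequence rather than a cumulative mean over all batches.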
self.reset() 59 | 60 | def log_value_sequential( self, value, nrpoints ): 61 | self.log_value( self.pointer, value, nrpoints) 62 | self.pointer += 1 63 | if ( self.pointer >= self.nrseq): 64 | self.pointer = 0 65 | 66 | def log_value( self, sequence_index, value, nrpoints ): 67 | curval = self.values[ sequence_index ] 68 | diff = value - curval 69 | self.values[ sequence_index ] = value 70 | self.values_sum += diff 71 | 72 | curval = self.nrpoints[ sequence_index ] 73 | diff = nrpoints - curval 74 | self.nrpoints[ sequence_index ] = nrpoints 75 | self.totnrpoints += diff 76 | 77 | def reset( self ): 78 | self.nrpoints = 0 79 | 80 | self.values = np.zeros( (self.nrseq), dtype=np.float32 ) 81 | self.values_sum = 0 82 | 83 | self.nrpoints = np.zeros( (self.nrseq), dtype=np.float32 ) 84 | self.totnrpoints = 0 85 | 86 | self.pointer = 0 87 | 88 | def average( self ): 89 | return self.values_sum / max( 1, self.totnrpoints ) 90 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.framework import ops 4 | 5 | import getpass 6 | import argparse 7 | import time 8 | import os 9 | import numpy.ma as MA 10 | import sys 11 | import pickle 12 | import random 13 | import math 14 | from time import gmtime, strftime 15 | 16 | from utils import DataLoader 17 | import utils 18 | 19 | from model import Model 20 | from model import get_pi_idx 21 | 22 | from stats import Stats 23 | 24 | import resource 25 | 26 | def memusage( point = "") : 27 | usage = resource.getrusage( resource.RUSAGE_SELF) 28 | return '''%s: usertime = %s systime = %s mem = %s mb 29 | '''%( point, usage[ 0 ], usage[ 1 ], 30 | ( usage[ 2 ]*resource.getpagesize( ) ) /1000000.0 ) 31 | 32 | def main( ) : 33 | 34 | print_input = 0 35 | user = getpass.getuser( ) 36 | print ( "user: ", user) 37 | #logdir = "/home/"+user+"/code/digits/rnn/log/" 38 | logdir = '~/projects/incremental-sequence-learning/log/' 39 | 40 | parser = argparse.ArgumentParser( ) 41 | parser.add_argument( '--rnn_size', type = int, default = 256, 42 | help = 'size of RNN hidden state') 43 | parser.add_argument( '--num_layers', type = int, default = 2, 44 | help = 'number of layers in the RNN') 45 | parser.add_argument( '--model', type = str, default = 'basiclstm', 46 | help = 'rnn, gru, or lstm, or ffnn') 47 | parser.add_argument( '--nrseq_per_batch', type = int, default = 50, 48 | help = 'minibatch size') 49 | parser.add_argument( '--nrseq_per_batch_test', type = int, default = 50, 50 | help = 'minibatch size') 51 | parser.add_argument( '--nrpoints_per_batch', type = int, default = 0, 52 | help = 'Number of points ( sequence steps) per batch') 53 | parser.add_argument( '--num_epochs', type = int, default = 30, 54 | help = 'number of epochs') 55 | parser.add_argument( '--report_every', type = int, default = 50, 56 | help = 'report frequency') 57 | parser.add_argument( '--save_every_nrbatches', type = int, default = 200, 58 | help = 'save frequency') 59 | parser.add_argument( '--save_maxnrmodels_keep', type = int, default = 5, 60 | help = 'Max nr of models to keep') 61 | parser.add_argument( '--eval_every', type = int, default = 50, 62 | help = 'evaluation frequency') 63 | parser.add_argument( '--test_every_nrbatches', type = int, default = 0, 64 | help = 'testing frequency') 65 | parser.add_argument( '--grad_clip', type = float, default = 10., 66 | help = 'clip gradients at this 
value') 67 | parser.add_argument( '--learning_rate', type = float, default = 0.005, 68 | help = 'learning rate') 69 | parser.add_argument( '--decay_rate', type = float, default = 0.95, 70 | help = 'decay rate for rmsprop') 71 | parser.add_argument( '--num_mixture', type = int, default = 20, 72 | help = 'number of gaussian mixtures') 73 | parser.add_argument( '--keep_prob', type = float, default = 0.8, 74 | help = 'dropout keep probability') 75 | parser.add_argument( '--predict', type = int, default = 0, 76 | help = 'predict instead of training') 77 | parser.add_argument( '--predictideal', type = int, default = 0, 78 | help = 'predict given ideal input') 79 | parser.add_argument( '--evaluate', type = int, default = 0, 80 | help = 'Run evaluation process that monitors the checkpoint files of a concurrently running training process.') 81 | parser.add_argument( '--nrinputfiles_train', type = int, default = 0, 82 | help = 'number of training input data files to use') 83 | parser.add_argument( '--nrinputfiles_test', type = int, default = 0, 84 | help = 'number of test input data files to use') 85 | 86 | parser.add_argument( '--explabel', type = str, default = 0, 87 | help = 'experiment label') 88 | parser.add_argument( '--max_seq_length', type = int, default = 0, 89 | help = 'max amount of points per sequence that will be used') 90 | parser.add_argument( '--file_label', type = str, default = "", 91 | help = 'input file label') 92 | parser.add_argument( '--train_on_own_output_method', type = int, default = 0, 93 | help = 'Various methods for training on own output, governed by current network performance') 94 | parser.add_argument( '--model_checkpointfile', type = str, default = "", 95 | help = 'checkpoint file to load') 96 | parser.add_argument( '--sample_from_output', type = int, default = 0, 97 | help = 'If set, when using train_on_own_output_method, the output will be sampled from first before passing it on as the next input; if not set, the output is used directly.') 98 | parser.add_argument( '--regularization_factor', type = float, default = .01) 99 | parser.add_argument( '--l2_weight_regularization', type = int, default = 1, help = 'Use the average of all weights as a regularization component') 100 | parser.add_argument( '--max_weight_regularization', type = int, default = 0, help = 'Use the maximum of all weights as a basis for regularization') 101 | parser.add_argument( '--discard_inputs', type = int, default = 0, help = 'Discard the input data; network must produce the output balistically.') 102 | parser.add_argument( '--discard_classvar_inputs', type = int, default = 0, help = 'Discard the 10 boolean class indicators indicating the current class as input.') 103 | parser.add_argument( '--nrClassOutputVars', type = int, default = 0, help = 'Use up to 10 binary class indicator outputs that the model has to predict at each step') 104 | parser.add_argument( '--useStrokeOutputVars', type = int, default = 1, help = 'Generate strokes.') 105 | parser.add_argument( '--useStrokeLoss', type = int, default = 1, help = 'Use loss component based on stroke output.') 106 | parser.add_argument( '--useClassInputVars', type = int, default = 1, help = 'Use 10 binary input variable representing the digit class ( one-hot representation) .') 107 | parser.add_argument( '--incremental_min_nrpoints', type = int, default = 5000, help = 'Min number of points to evaluate before considering next increment') 108 | parser.add_argument( '--useInitializers', type = int, default = 0, help = 'Use initializers for network 
parameters to ensure reproducibility') 109 | parser.add_argument( '--usePreviousEndState', type = int, default = 0, help = 'Use end state after previous batch as initial state for next batch') 110 | parser.add_argument( '--print_length_correct', type = int, default = 0, help = 'Use initializers for network parameters to ensure reproducibility') 111 | 112 | parser.add_argument( '--startingpoint', type = str, default = '', help = 'Start from saved state.') 113 | parser.add_argument( '--randomizeSequenceOrder', type = int, default = 1, help = 'Randomize order of sequences to prevent learning order.') 114 | parser.add_argument( '--classweightfactor', type = float, default = 10, help = 'weight of classvar loss') 115 | parser.add_argument( '--curnrtrainexamples', type = float, default = 10) 116 | parser.add_argument( '--current_seq_length', type = int, default = 0, 117 | help = 'Used in combination with incremental_seq_length') 118 | parser.add_argument( '--curnrdigits', type = int, default = 10) 119 | parser.add_argument( '--correctfrac_threshold_inc_nrtrainex', type = float, default = .8) 120 | parser.add_argument( '--threshold_rmse_stroke', type = float, default = 2) 121 | parser.add_argument( '--usernn', type = int, default = 0) 122 | parser.add_argument( '--fileselection', type = str, default = '', nargs = '+') #representative 10 digits: 1, 3, 25, 7, 89, 0, 62, 96, 85, 43 123 | parser.add_argument( '--incremental_nr_trainexamples', type = int, default = 0) 124 | parser.add_argument( '--incremental_seq_length', type = int, default = 0) 125 | parser.add_argument( '--incremental_nr_digits', type = int, default = 0) 126 | parser.add_argument( '--runnr', type = int, default = 1) 127 | parser.add_argument( '--maxnrpoints', type = int, default = 0) 128 | parser.add_argument( '--stat_windowsize_nrsequences', type = int, default = 1000) 129 | parser.add_argument( '--firsttrainstep', type = int, default = 0, help = 'Loss is calculated from this sequence step onwards; preceding points are ignored ( fed, but not contributing to loss) ') 130 | parser.add_argument( '--stopcrit_threshold_stroke_rmse_train', type = float, default = 0) 131 | parser.add_argument( '--testovertrain', type = int, default = 0, help = 'Control experiment to check that overtraining can occur. Uses digits 0-4 for training, 5-9 for testing.') 132 | parser.add_argument( '--reportstate', type = int, default = 0, help = 'report complete internal state ( weights, state) before/after each train/test batch') 133 | parser.add_argument( '--reportmixture', type = int, default = 0, help = 'report mixture') 134 | 135 | #arguments set internally: 136 | parser.add_argument( '--maxdigitlength_nrpoints', type = int, default = 0, help = 'max sequence length ( nr points) that was encountered in the training data; calculated parameter.' ) 137 | parser.add_argument( '--rangemin', type = float, default = -22.6) #determined based on full MNIST stroke sequence data set 138 | parser.add_argument( '--rangelen', type = float, default = 55.2) #determined based on full MNIST stroke sequence data set 139 | parser.add_argument( '--seq_length', type = int, default = 0) 140 | parser.add_argument( '--nroutputvars', type = int, default = 0) 141 | parser.add_argument( '--nrtargetvars', type = int, default = 0) 142 | parser.add_argument( '--nrauxinputvars', type = int, default = 0) 143 | parser.add_argument( '--debuginfo', type = int, default = 0) 144 | 145 | #variable sizes: 146 | #o_pi: nrrowsperbatch x nrmixtures, i.e. 
pointnr x mixturenr 147 | #targetdata: nrseq x seqlen x nrinputvars_network 148 | #input: dx dy eos eod 149 | #output: eos eod nr_mixtures*distribution-params classvars 150 | 151 | args = parser.parse_args( ) 152 | 153 | file_label = args.file_label 154 | 155 | explabel = args.explabel 156 | 157 | outputdir = "./results/"+explabel+"r"+str( args.runnr) +"/" 158 | 159 | if args.incremental_nr_trainexamples: 160 | args.curnrtrainexamples = min( args.curnrtrainexamples, args.nrinputfiles_train ) 161 | args.incremental_min_nrpoints = 50 * args.curnrtrainexamples 162 | else: 163 | args.curnrtrainexamples = args.nrinputfiles_train 164 | 165 | # datadir = "/home/"+user+"/code/digits/sequences/" 166 | datadir = './data/sequences/' 167 | print( "using data dir: ", datadir ) 168 | 169 | seqlenarg = 0 170 | trainarg = 1 171 | dataloader_train = DataLoader( datadir, args, args.nrinputfiles_train, args.curnrtrainexamples, seqlenarg, trainarg, file_label, print_input, args.rangemin, args.rangelen) 172 | args.nrauxinputvars = 10 * args.useClassInputVars 173 | 174 | trainarg = 0 175 | dataloader_test = DataLoader( datadir, args, args.nrinputfiles_test, args.nrinputfiles_test, dataloader_train.seq_length, trainarg, file_label, print_input, args.rangemin, args.rangelen) 176 | dataloader_test.createRandValues( ) 177 | 178 | args.nrtargetvars = 4*args.useStrokeOutputVars + args.nrClassOutputVars 179 | if ( not args.incremental_seq_length) : 180 | args.current_seq_length = dataloader_train.seq_length 181 | 182 | if ( args.evaluate or args.predict or args.predictideal) : 183 | configfile = os.path.join( 'save/'+args.explabel+'r'+str( args.runnr) , 'config.pkl') 184 | if len( args.startingpoint ) > 0: 185 | pos = args.startingpoint.find( 'model' ) 186 | savedfolder = args.startingpoint[ 0 : pos - 1 ] 187 | print( 'savedfolder', savedfolder) 188 | slashpos = savedfolder.find( '/') #find first slash 189 | savedfolderlist = list( savedfolder) 190 | savedfolderlist[ slashpos ] = 'x' 191 | savedfolder = "".join( savedfolderlist) 192 | slashpos = savedfolder.find( '/') #find second slash 193 | savedfolder = args.startingpoint[ 0 : pos - 1 ] 194 | savedfolderlist = list( savedfolder) 195 | savedfolderlist = savedfolderlist[ :slashpos+1 ] 196 | savedfolder = "".join( savedfolderlist) #get the path 197 | configfile = os.path.join( savedfolder, 'config.pkl') 198 | print( 'configfile', configfile) 199 | fileexists = os.path.exists( configfile ) 200 | while not fileexists: 201 | print ( "Waiting for config file", configfile) 202 | time.sleep( 5) 203 | fileexists = os.path.exists( configfile ) 204 | 205 | f = open( configfile, "rb") 206 | saved_args = pickle.load( f) 207 | saved_args.nrseq_per_batch = args.nrseq_per_batch 208 | f.close( ) 209 | 210 | trainpredictmode = "Predict" 211 | nrsequenceinputs = 4 #dx dy eos eod 212 | nrinputvars_network = nrsequenceinputs + args.nrauxinputvars; 213 | 214 | model_predict = Model( saved_args, trainpredictmode, True, nrinputvars_network, saved_args.nroutputvars_raw, args.nrtargetvars, args.nrClassOutputVars, maxdigitlength_nrpoints = saved_args.maxdigitlength_nrpoints) 215 | 216 | if args.predict: 217 | nrbatches = args.nrinputfiles_test / args.nrseq_per_batch 218 | use_own_output_as_input = 1 219 | with tf.Session( ) as sess: 220 | performPrediction( sess, saved_args, args, model_predict, dataloader_test, nrbatches, use_own_output_as_input, outputdir, parser ) 221 | elif ( args.predictideal) : 222 | nrbatches = args.nrinputfiles_test / args.nrseq_per_batch 223 | use_own_output_as_input 
= 0 224 | with tf.Session( ) as sess: 225 | performPrediction( sess, saved_args, args, model_predict, dataloader_test, nrbatches, use_own_output_as_input, outputdir, parser ) 226 | else: 227 | train( dataloader_train, dataloader_test, args, logdir, outputdir ) 228 | 229 | def savemodel( saver, sess, dataloader, args, batchnr ) : 230 | checkpoint_path = os.path.join( 'save/'+args.explabel+'r'+str( args.runnr) , 'model.ckpt') 231 | print( ( "saving model to {}".format( checkpoint_path) ) ) 232 | saver.save( sess, checkpoint_path, global_step = batchnr) 233 | checkpoint_fullpath = checkpoint_path + "-" + str( batchnr ) 234 | print( ( 'saved checkpoint: '+checkpoint_fullpath) ) 235 | 236 | def restoreModel( sess, args, model_predict, dataloader, checkpoint_fullpath = "") : 237 | 238 | saver = tf.train.Saver( tf.trainable_variables( ) , max_to_keep = args.save_maxnrmodels_keep) 239 | 240 | print( 'restoreModel: checkpoint_fullpath', checkpoint_fullpath) 241 | if len( checkpoint_fullpath) >0: #model path provided 242 | saver.restore( sess, checkpoint_fullpath) 243 | else: #load most recent state of own run 244 | modelfile = 'save/'+args.explabel+'r'+str( args.runnr) #same folder convention as savemodel 245 | if len( args.model_checkpointfile) > 1 : 246 | modelfile = args.model_checkpointfile 247 | ckpt = tf.train.get_checkpoint_state( modelfile) 248 | print( 'found checkpoint' ) 249 | print ( "loading model: ", ckpt.model_checkpoint_path) 250 | saver.restore( sess, ckpt.model_checkpoint_path) 251 | 252 | def performPrediction( sess, saved_args, args, model_predict, dataloader, nrbatches, use_own_output_as_input, outputdir, parser) : 253 | 254 | print( 'performprediction') 255 | 256 | saver = tf.train.Saver( tf.trainable_variables( ) , max_to_keep = args.save_maxnrmodels_keep) 257 | 258 | restoreModel( sess, args, model_predict, dataloader, args.startingpoint) 259 | 260 | print( 'restored model') 261 | 262 | [ strokes, params ] = model_predict.sample( sess, dataloader, saved_args, nrbatches, use_own_output_as_input, outputdir) 263 | print( 'Completed performPrediction' ) 264 | 265 | def writeOutputTarget( args, outputdir, batchnr, sequence_index, batch_seqnr, outputmat, outputmat_sampled, targetmat, stroketarget, lossvec, model, loss, mode, inputdata) : 266 | fn = outputdir + "output-"+mode+"-batch-" + str( batchnr) + "-seqnr-"+str( batch_seqnr) +"-filenr-" + str( sequence_index[ batch_seqnr ] ) + ".txt" 267 | np.savetxt( fn, outputmat, fmt = '%.3f') 268 | 269 | fn = outputdir + "output-sampled-"+mode+"-batch-" + str( batchnr) + "-seqnr-"+str( batch_seqnr) +"-filenr-" + str( sequence_index[ batch_seqnr ] ) + ".txt" 270 | np.savetxt( fn, outputmat_sampled, fmt = '%.3f') 271 | 272 | fn = outputdir + "classtarget-" +mode+ "-batch-" + str( batchnr) + "-seqnr-" +str( batch_seqnr) +"-filenr-" + str( sequence_index[ batch_seqnr ] ) + ".txt" 273 | np.savetxt( fn, targetmat[ :, 0:10 ], fmt = '%.3f') 274 | 275 | fn = outputdir + "stroketarget-"+ mode + "-batch-" + str( batchnr) + "-seqnr-" +str( batch_seqnr) +"-filenr-" + str( sequence_index[ batch_seqnr ] ) + ".txt" 276 | np.savetxt( fn, stroketarget, fmt = '%.3f') 277 | 278 | fn = outputdir + "input-"+ mode + "-batch-" + str( batchnr) + "-seqnr-" +str( batch_seqnr) +"-filenr-" + str( sequence_index[ batch_seqnr ] ) + ".txt" 279 | np.savetxt( fn, inputdata, fmt = '%.3f') 280 | 281 | if batch_seqnr == 0: 282 | if ( args.useStrokeOutputVars and args.useStrokeLoss) : 283 | fn = outputdir + "lossvec-" +mode+ str( batchnr) + ".txt" 284 | np.savetxt( fn, lossvec, fmt = '%.3f') 285 | fn = outputdir + 
mode+"loss-" +mode+ str( batchnr) + ".txt" 286 | file = open( fn, "w") 287 | file.write( str( loss) + "\n" ) 288 | file.close( ) 289 | 290 | def writeMixture( args, outputdir, batchnr, sequence_index, batch_seqnr, mode, mixture, seq_pointnr) : 291 | fn = outputdir + "mixture-"+mode+"-batch-" + str( batchnr) + "-seqnr-"+str( batch_seqnr) +"-filenr-" + str( sequence_index[ batch_seqnr ] ) + ".txt" 292 | if seq_pointnr == 0: 293 | f = open( fn, 'wb') 294 | else: 295 | f = open( fn, 'ab') 296 | np.savetxt( f, mixture[ None ], fmt = '%.3f', delimiter = ", ") 297 | f.close( ) 298 | 299 | def softmax( x) : 300 | """Compute softmax values for each sets of scores in x.""" 301 | e_x = np.exp( x - np.max( x) ) 302 | return e_x / e_x.sum( ) 303 | 304 | def evaluate( sess, args, stats, stats_alldata, stats_inc, sequence_index, trainpredictmode, model, dataloader, outputdir, outputs, state, lossvec, train_loss, regularization_term, loss_plain, loss_total, weights, nrinputvars_network, targetdata, maxabsweight, avgweight, learningrate_value, train_on_output, epochnr, totbatchnr, totnrpoints_trained, writefiles, runtime, mode, printstate, batchsize_nrseq, x) : 305 | nanfound = False 306 | avgmu1 = 0 307 | avgmu2 = 0 308 | maxabscorr = 0 309 | sse_stroke = 0 310 | nrrowsused = 0 311 | report_nrsequences = 10 312 | 313 | with tf.variable_scope( trainpredictmode) : 314 | 315 | if args.useStrokeOutputVars and args.useStrokeLoss: 316 | [ o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod, o_classvars, o_classpred ] = outputs 317 | else: 318 | [ o_classvars, o_classpred ] = outputs 319 | 320 | batch_pointnr = 0 321 | 322 | for batch_seqnr in range( 0, batchsize_nrseq ) : #for each sequence in the batch 323 | targetmat = np.copy( targetdata[ batch_seqnr ]) 324 | absrowsum = np.absolute( targetmat) .sum( 1 ) 325 | mask = np.sign( absrowsum) #checked EdJ Sept 2 326 | nzrows = np.nonzero( mask) #not binary; seems: this returns indices of nonzero rows 327 | nzrows = nzrows[ 0 ] 328 | if len( nzrows) >0: 329 | last = len( nzrows) - 1 330 | nrtargetrows = nzrows[ last ] + 1 331 | else: 332 | nrtargetrows = 0 333 | 334 | if mode == "train": 335 | evalseqlength = min( args.current_seq_length + 1, nrtargetrows + 1) 336 | else: 337 | evalseqlength = min( model.seq_length, nrtargetrows + 1) 338 | 339 | outputmat = np.zeros( ( evalseqlength - 1, args.nroutputvars_final) , dtype = np.float32) 340 | outputmat_sampled = np.zeros( ( evalseqlength - 1, args.nroutputvars_final) , dtype = np.float32) 341 | 342 | mixture = 0 343 | for p in range( evalseqlength - 1) :#process used points first 344 | 345 | if args.useStrokeOutputVars: 346 | if args.nrClassOutputVars > 0 and args.classweightfactor > 0: 347 | outputmat[ p, :args.nrClassOutputVars ] = o_classpred[ batch_pointnr, ] 348 | outputmat_sampled[ p, :args.nrClassOutputVars ] = o_classpred[ batch_pointnr, ] 349 | if args.useStrokeLoss: 350 | idx = get_pi_idx( dataloader.getRandValue( ) , o_pi[ batch_pointnr ]) 351 | next_x1, next_x2 = model.sample_gaussian_2d( o_mu1[ batch_pointnr, idx ], o_mu2[ batch_pointnr, idx ], o_sigma1[ batch_pointnr, idx ], o_sigma2[ batch_pointnr, idx ], o_corr[ batch_pointnr, idx ]) 352 | eos = 1 if dataloader.getRandValue( ) < o_eos[ batch_pointnr ] else 0 353 | eod = 1 if dataloader.getRandValue( ) < o_eod[ batch_pointnr ] else 0 354 | outputmat[ p, args.nrClassOutputVars:args.nrClassOutputVars+4 ] = [ o_mu1[ batch_pointnr, idx ], o_mu2[ batch_pointnr, idx ], o_sigma1[ batch_pointnr, idx ], o_sigma2[ batch_pointnr, idx ] ] 355 | 
outputmat_sampled[ p, args.nrClassOutputVars:args.nrClassOutputVars+4 ] = [ next_x1, next_x2, eos, eod ] 356 | if writefiles and args.reportmixture and ( sequence_index[ batch_seqnr ] < report_nrsequences) : 357 | nrparams = args.num_mixture * 6 358 | mixture = np.zeros( ( nrparams ) , dtype = np.float32 ) 359 | for m in range( args.num_mixture ) : 360 | mixture[ m*6:( m+1) *6 ] = [ o_pi[ batch_pointnr, m ], o_mu1[ batch_pointnr, m ], o_mu2[ batch_pointnr, m ], o_sigma1[ batch_pointnr, m ], o_sigma2[ batch_pointnr, m ], o_corr[ batch_pointnr, m ] ] 361 | writeMixture( args, outputdir, totbatchnr, sequence_index, batch_seqnr, mode, mixture, p) 362 | 363 | else: 364 | outputmat_sampled[ p, ] = o_classpred[ batch_pointnr, ] 365 | 366 | batch_pointnr += 1 367 | batch_pointnr += model.seq_length - evalseqlength #after cur seq len, skip to end of seq ( = seq_length - 1) 368 | 369 | stroketarget = np.copy( targetmat[ :evalseqlength - 1, args.nrClassOutputVars:args.nrClassOutputVars + 4 ]) 370 | nrrowsused = nrtargetrows 371 | stats_inc_rmse = 0 372 | 373 | if args.useStrokeOutputVars and ( nrrowsused > 0) : 374 | 375 | if args.useStrokeLoss: 376 | outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] *= args.rangelen 377 | outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] += args.rangemin 378 | outputmat[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] *= args.rangelen 379 | outputmat[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ] += args.rangemin 380 | 381 | stroketarget[ :, 0:2 ] *= args.rangelen 382 | stroketarget[ :, 0:2 ] += args.rangemin 383 | 384 | err_stroke = outputmat_sampled[ :, args.nrClassOutputVars:args.nrClassOutputVars + 2 ]-stroketarget[ :, 0:2 ] #was: 1:2 385 | 386 | sse_stroke = ( err_stroke ** 2) .sum( ) 387 | 388 | stats.stats_stroke.log_sse_sequential( sse_stroke, 2 * nrrowsused ) 389 | stats_alldata.stats_stroke.log_sse( sequence_index[ batch_seqnr ], sse_stroke, 2 * nrrowsused ) 390 | 391 | if mode == "train": 392 | stats_inc.stats_stroke.log_sse_sequential( sse_stroke, 2 * nrrowsused ) #sequential counter; window of last n values 393 | stats_inc_rmse = stats_inc.stats_stroke.rmse( ) 394 | 395 | if args.nrClassOutputVars > 0 and ( nrrowsused > 0) : 396 | classindex_true = np.argmax( targetmat[ :evalseqlength - 1, :args.nrClassOutputVars ], 1) 397 | classindex_pred = np.argmax( outputmat_sampled[ :, :args.nrClassOutputVars ], 1) 398 | 399 | correct = np.equal( classindex_pred, classindex_true) 400 | last_correct = correct[ nrrowsused - 1 ] #model.seq_length 401 | 402 | if args.print_length_correct: 403 | seqindex = sequence_index[ batch_seqnr ] 404 | print( 'len-correct', mode, 'seq', seqindex, 'len', dataloader.seqlengthlist[ seqindex ], 'correct', 1*last_correct) 405 | 406 | stats.stats_correct.log_value_sequential( last_correct, 1 ) 407 | stats.stats_correctfrac.log_value_sequential( correct.sum( ) , nrrowsused ) 408 | stats_alldata.stats_correct.log_value( sequence_index[ batch_seqnr ], last_correct, 1 ) 409 | stats_alldata.stats_correctfrac.log_value( sequence_index[ batch_seqnr ], correct.sum( ) , nrrowsused ) 410 | 411 | else: 412 | avgcorrectfrac = 0 413 | correctpreds = 0 414 | 415 | if writefiles and ( sequence_index[ batch_seqnr ] < report_nrsequences) : 416 | inputdata = np.copy( x[ batch_seqnr ] ) 417 | inputdata[ :, 10 * args.useClassInputVars : 10 * args.useClassInputVars + 2 ] *= args.rangelen 418 | inputdata[ :, 10 * args.useClassInputVars : 10 * args.useClassInputVars + 2 ] += args.rangemin 419 | 
writeOutputTarget( args, outputdir, totbatchnr, sequence_index, batch_seqnr, outputmat, outputmat_sampled, targetmat, stroketarget, lossvec, model, train_loss, mode, inputdata) 420 | 421 | weights_o = sess.run( model.outputWeight) ; 422 | bias = sess.run( model.outputBias) ; 423 | avgbias = bias.mean( ) 424 | maxabsbias = np.absolute( bias) .max( ) 425 | avgstate = np.asarray( state) .mean( ) 426 | maxabsstate = np.absolute( state) .max( ) 427 | if args.useStrokeOutputVars and args.useStrokeLoss: 428 | avgmu1 = outputmat[ :, 0 ].mean( ) 429 | avgmu2 = outputmat[ :, 1 ].mean( ) 430 | maxabscorr = np.absolute( o_corr) .max( ) 431 | 432 | avgw = avgweight 433 | if ( len( avgweight) >1) : 434 | avgw = avgweight.mean( ) 435 | maxabsw = maxabsweight 436 | if ( len( maxabsweight) >1) : 437 | maxabsw = maxabsweight.mean( ) 438 | print ( 'eval', mode, ': epoch', epochnr, 'totbatches', totbatchnr, 'totnrpoints_trained', totnrpoints_trained, 'nrtrainex', args.curnrtrainexamples, 'curseqlen', args.current_seq_length, 'curnrdigits', args.curnrdigits, 'rmse_stroke', stats.stats_stroke.rmse( ) , 'rmse_stroke_alldata', stats_alldata.stats_stroke.rmse( ) , 'rmse_stroke_inc', stats_inc_rmse, "correct", stats.stats_correct.average( ) , "correct_alldata", stats_alldata.stats_correct.average( ) , 'regularization', regularization_term[ 0 ], 'loss_total', loss_total, 'avgbias', avgbias, 'maxabsbias', maxabsbias, 'avgstate', avgstate, 'maxabsstate', maxabsstate, 'learningrate', learningrate_value, 'maxabscorr', maxabscorr, 'maxabsweight', maxabsw[ 0 ], 'avgweight', avgw[ 0 ], 'runtime', runtime ) 439 | 440 | #stats 441 | if epochnr % 100 == 0: 442 | graph = tf.get_default_graph( ) 443 | ops = graph.get_operations( ) 444 | print( ( 'mem nr ops: ', len( ops) ) ) 445 | print( 'mem usage:') 446 | print( ( memusage( "eval") ) ) 447 | print( 'rand e', epochnr, dataloader.getRandValue( ) ) 448 | 449 | return nanfound 450 | 451 | 452 | def print_model( model ) : 453 | print ( "model structure: " ) 454 | print ( "gradient vars: " ) 455 | 456 | for var in tf.get_collection( tf.GraphKeys.VARIABLES, scope = 'gradient') : # tf.variable_scope( "gradient") : 457 | print ( "var: ", var.name ) 458 | print ( "all vars: " ) 459 | params = tf.all_variables( ) 460 | for var in params: 461 | print ( "var: ", var.name ) 462 | 463 | def recordState( model, sess ) : 464 | params = tf.all_variables( ) 465 | state = [ ] 466 | varnames = [ ] 467 | for var in params: 468 | varnames.append( var.name ) 469 | value = sess.run( var) ; 470 | state.append( value ) 471 | return state, varnames 472 | 473 | def printState ( state, varnames, fn = '' ) : 474 | i = 0 475 | statefile = open( fn, "w" ) 476 | for var in state: 477 | print( 'var: ', varnames[ i ], file = statefile ) 478 | print( var.sum( ) , file = statefile ) 479 | i += 1 480 | statefile.close( ) 481 | 482 | def constructInputFromOutput( args, model, dataloader, x, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod) : #dataloader supplies the pre-generated random values used for sampling 483 | xnrseq = np.shape( x) [ 0 ] 484 | xnrpointsperseq = np.shape( x) [ 1 ] 485 | getbatch = False 486 | 487 | point = 0 488 | for s in range( xnrseq) : #for each sequence in previous batch 489 | 490 | outputmat = np.zeros( ( model.seq_length, 4) , dtype = np.float32) 491 | 492 | batch_pointnr = 0 493 | for p in range( xnrpointsperseq) : 494 | idx = get_pi_idx( dataloader.getRandValue( ) , o_pi[ batch_pointnr ]) 495 | 496 | mu1out = o_mu1[ batch_pointnr, idx ] #these are regular numpy floats, not tensors 497 | mu2out = o_mu2[ batch_pointnr, idx ] 498 | sigma1out = o_sigma1[ 
batch_pointnr, idx ] 499 | sigma2out = o_sigma2[ batch_pointnr, idx ] 500 | corrout = o_corr[ batch_pointnr, idx ] 501 | eosout = o_eos[ batch_pointnr ] 502 | eodout = o_eod[ batch_pointnr ] 503 | next_x1, next_x2 = model.sample_gaussian_2d( mu1out, mu2out, sigma1out, sigma2out, corrout) 504 | eos = 1 if dataloader.getRandValue( ) < eosout else 0 505 | eod = 1 if dataloader.getRandValue( ) < eodout else 0 506 | 507 | if args.sample_from_output: 508 | outputmat[ p, ] = [ next_x1, next_x2, eos, eod ] 509 | else: 510 | outputmat[ p, ] = [ mu1out, mu2out, eosout, eodout ] 511 | 512 | batch_pointnr += 1 513 | 514 | fromval = s*xnrpointsperseq 515 | toval = ( s+1) *xnrpointsperseq-1 #last output: not used 516 | xvalues = np.array( x[ s ]) 517 | xvalues[ 1:xnrpointsperseq, args.nrClassOutputVars:args.nrClassOutputVars+4 ] = outputmat[ 0:xnrpointsperseq-1, ] 518 | x[ s ] = xvalues 519 | return x 520 | 521 | def printInputsTargets( args, x, y ) : 522 | print ( "x") 523 | xvalues = np.array( x) 524 | xvalues[ :, args.nrClassOutputVars:args.nrClassOutputVars+2 ] *= args.rangelen 525 | xvalues[ :, args.nrClassOutputVars:args.nrClassOutputVars+2 ] += args.rangemin 526 | print( xvalues) 527 | print( "y") 528 | yvalues = np.array( y) 529 | yvalues[ :, args.nrClassOutputVars:args.nrClassOutputVars+2 ] *= args.rangelen 530 | yvalues[ :, args.nrClassOutputVars:args.nrClassOutputVars+2 ] += args.rangemin 531 | print( yvalues) 532 | 533 | def printWeightsGradients( sess ) : 534 | 535 | allvars = tf.all_variables( ) 536 | for var in allvars: 537 | isBias = var.name.find( "Bias") >= 0 538 | if not isBias: 539 | print( "var: ", var.name) 540 | value = sess.run( var) 541 | print( value ) 542 | 543 | def train( dataloader_train, dataloader_test, args, logdir, outputdir ) : 544 | 545 | stats_train = Stats( args, args.stat_windowsize_nrsequences, 'stats_train' ) #stats over recent training data 546 | stats_test = Stats( args, args.stat_windowsize_nrsequences, 'stats_test' ) #stats over recent test data 547 | stats_train_alldata = Stats( args, args.nrinputfiles_train, 'stats_train' ) #stats over the most recent set of |trainingset| examples 548 | stats_test_alldata = Stats( args, args.nrinputfiles_test, 'stats_test' ) #stats over the most recent set of |testset| examples 549 | 550 | nrseq_inc = np.ceil( args.incremental_min_nrpoints / min( args.current_seq_length, dataloader_train.avgseqlength) ) 551 | stats_train_inc = Stats( args, nrseq_inc, 'stats_train_inc' ) #stats over most recent incremental_min_nrpoints, for incremental methods 552 | 553 | random.seed( 100 * args.runnr ) 554 | np.random.seed( 100 * args.runnr ) 555 | tf.set_random_seed( 100 * args.runnr ) 556 | print( 'runnr', args.runnr, 'after seed, rand:', random.random( ) , 'np rand', np.random.rand( ) ) 557 | 558 | print( 'starting time: ', strftime( "%Y-%m-%d %H:%M:%S") ) 559 | 560 | nrsequenceinputs = 4 #dx dy eos eod 561 | nrinputvars_network = nrsequenceinputs + args.nrauxinputvars; 562 | args.nroutputvars_raw = ( 2 + args.num_mixture * 6) * args.useStrokeOutputVars + args.nrClassOutputVars 563 | args.nroutputvars_final = ( 2 + 2) * args.useStrokeOutputVars + args.nrClassOutputVars 564 | 565 | print( "nrinputvars_network", nrinputvars_network) 566 | print( "nrauxinputvars", args.nrauxinputvars) 567 | print( "args.nroutputvars_final", args.nroutputvars_final) 568 | 569 | trainpredictmode = "Predict" 570 | 571 | model = Model( args, trainpredictmode, False, nrinputvars_network, args.nroutputvars_raw, args.nrtargetvars, args.nrClassOutputVars, 
dataloader_train.rangemin, dataloader_train.rangelen, args.maxdigitlength_nrpoints ) 572 | 573 | #store info from model in args so it's saved: 574 | args.seq_length = model.seq_length 575 | 576 | print( 'about to save config in', os.path.join( 'save/'+args.explabel+'r'+str( args.runnr) , 'config.pkl') ) 577 | with open( os.path.join( 'save/'+args.explabel+'r'+str( args.runnr) , 'config.pkl') , 'wb') as f: 578 | pickle.dump( args, f) 579 | 580 | print_model( model ) 581 | 582 | checkpoint_fullpath = "" 583 | nanfound = False 584 | nrnanbatches = 0 585 | train_on_output = 0 586 | 587 | printstate = args.reportstate 588 | 589 | with tf.Session( ) as sess: 590 | 591 | random.seed( 100 * args.runnr ) 592 | np.random.seed( 100 * args.runnr ) 593 | randop = tf.random_normal( [ 1 ], seed = random.random( ) ) #, seed = 1234 594 | print( 'runnr', args.runnr, 'after seed, rand:', random.random( ) , 'np rand', np.random.rand( ) , 'tf rand', sess.run( randop) ) 595 | 596 | dataloader_train.createRandValues( ) 597 | dataloader_test.createRandValues( ) 598 | 599 | tf.initialize_all_variables( ) .run( ) 600 | saver = tf.train.Saver( tf.trainable_variables( ) , max_to_keep = args.save_maxnrmodels_keep) 601 | 602 | if args.useInitializers: 603 | init_op_weights = sess.run( model.init_op_weights) 604 | 605 | if ( len( args.startingpoint) >0) : 606 | print( "Starting from saved model: ", args.startingpoint) 607 | restoreModel( sess, args, model, dataloader_train, args.startingpoint) 608 | 609 | if args.useInitializers: 610 | init_op_weights = sess.run( model.init_op_weights) 611 | 612 | totnrbatches = 0 613 | totnrpoints = 0 614 | totnrpoints_trained = 0 615 | tstart = time.time( ) 616 | 617 | nrbatches_per_epoch = dataloader_train.nrbatches_per_epoch 618 | epochnr = 0 619 | cont = True 620 | 621 | while cont: #batch loop 622 | 623 | if ( totnrbatches % nrbatches_per_epoch == 0 ) : 624 | learningrate_value = args.learning_rate * ( args.decay_rate ** epochnr) 625 | learningrate = sess.run( model.learningrateop, feed_dict = {model.learningrate_ph: learningrate_value}) 626 | epochnr += 1 627 | 628 | modes = [ "train" ] 629 | if totnrbatches > 0 and ( totnrbatches % args.test_every_nrbatches == 0 ) : 630 | modes = [ "train", "test" ] 631 | 632 | for mode in modes: 633 | 634 | if mode == "train": 635 | dataloader = dataloader_train 636 | runseqlength = args.current_seq_length 637 | else: 638 | dataloader = dataloader_test 639 | runseqlength = model.seq_length 640 | 641 | if mode == "train": 642 | totnrbatches += 1 643 | stats = stats_train 644 | stats_alldata = stats_train_alldata 645 | stats_inc = stats_train_inc 646 | else: 647 | stats = stats_test 648 | stats_alldata = stats_test_alldata 649 | stats_inc = 0 650 | 651 | start_batch = time.time( ) 652 | 653 | tstartdata = time.time( ) 654 | getbatch = True 655 | if epochnr > 0: 656 | if args.train_on_own_output_method == 1 and mode == "train": 657 | train_on_output = dataloader.getRandValue( ) < 1.0 / ( 2. 
+ np.mean( [ model.batch_rmse_stroke, model.batch_rmse_class ] ) ) #training --> rand ok 658 | if train_on_output: 659 | x = constructInputFromOutput( args, model, dataloader, x, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod) 660 | getbatch = False 661 | 662 | if getbatch: #not training on own output 663 | x, y, sequence_index = dataloader.next_batch( args, runseqlength ) #can contain multiple sequences 664 | tgetdata = ( time.time( ) - tstartdata) / 60 665 | start_train = time.time( ) 666 | 667 | batchsize_nrseq = len( x ) 668 | 669 | zero_initial_state = sess.run( model.initial_state, feed_dict = {model.batch_size_ph: batchsize_nrseq, model.seq_length_ph: runseqlength}) #Get zero state given current batchsz 670 | 671 | if ( mode == "test") or ( totnrbatches == 1) or ( not args.usePreviousEndState) : 672 | state = zero_initial_state 673 | else: 674 | state = last_train_state 675 | 676 | feed = {model.input_data: x, model.target_data: y, model.initial_state: state, model.batch_size_ph: batchsize_nrseq, model.seq_length_ph: args.current_seq_length} #model.seq_length 677 | 678 | if args.useStrokeOutputVars and args.useStrokeLoss: 679 | if mode == "train": 680 | train_loss, last_train_state, lossvec, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod, o_classvars, o_classpred, regularization_term, loss_plain, lossnrpoints, maxabsweight, avgweight, _ = sess.run( [ model.loss_total, model.final_state, model.lossvector, model.pi, model.mu1, model.mu2, model.sigma1, model.sigma2, model.corr, model.eos, model.eod, model.classvars, model.classpred, model.regularization_term, model.loss_plain, model.lossnrpoints, model.maxabsweight, model.avgweight, model.train_op ], feed) 681 | state_report = last_train_state 682 | else: #test --> omit train op, and don't replace state 683 | train_loss, state_report, lossvec, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod, o_classvars, o_classpred, regularization_term, loss_plain, lossnrpoints, maxabsweight, avgweight = sess.run( [ model.loss_total, model.final_state, model.lossvector, model.pi, model.mu1, model.mu2, model.sigma1, model.sigma2, model.corr, model.eos, model.eod, model.classvars, model.classpred, model.regularization_term, model.loss_plain, model.lossnrpoints, model.maxabsweight, model.avgweight ], feed) 684 | 685 | outputs = [ o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod, o_classvars, o_classpred ] 686 | 687 | else: #no stroke loss, only learn classes 688 | z = np.zeros( ( 1) , dtype = np.float32 ) 689 | [ o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_eos, o_eod ] = [ z, z, z, z, z, z, z, z ] 690 | if mode == "train": 691 | train_loss, last_train_state, output, lossvec, o_classvars, o_classpred, regularization_term, loss_plain, result4, result, result_before_mask, lossnrpoints, mask, classpred, targetdata_classvars, crossentropy, maxabsweight, avgweight, _ = sess.run( [ model.loss_total, model.final_state, model.output, model.lossvector, model.classvars, model.classpred, model.regularization_term, model.loss_plain, model.result4, model.result, model.result_before_mask, model.lossnrpoints, model.mask, model.classpred, model.targetdata_classvars, model.crossentropy, model.maxabsweight, model.avgweight, model.train_op ], feed) 692 | state_report = last_train_state 693 | else: 694 | train_loss, state_report, output, lossvec, o_classvars, o_classpred, regularization_term, loss_plain, result4, result, mask, maxabsweight, avgweight, classpred, targetdata_classvars = sess.run( [ model.loss_total, model.final_state, 
model.output, model.lossvector, model.classvars, model.classpred, model.regularization_term, model.loss_plain, model.result4, model.result, model.mask, model.maxabsweight, model.avgweight, model.classpred, model.targetdata_classvars ], feed) 695 | outputs = [ o_classvars, o_classpred ] 696 | 697 | if mode == "train": 698 | totnrpoints_trained += lossnrpoints 699 | 700 | weights_o = sess.run( model.outputWeight) ; 701 | 702 | nanfound = math.isnan( train_loss) 703 | 704 | if nanfound: 705 | print( ( "NAN encountered --> stopping.") ) 706 | sys.exit( ) ; 707 | 708 | end_train = time.time( ) 709 | train_loss = train_loss.mean( ) 710 | 711 | start_eval = time.time( ) 712 | 713 | #evaluation 714 | if ( epochnr % args.eval_every == 0 ) : 715 | writefiles = totnrbatches % args.report_every == 0 and ( totnrbatches > 0) 716 | weights = 0 717 | runtime = ( time.time( ) - tstart) / 60 718 | nanfound = nanfound or evaluate( sess, args, stats, stats_alldata, stats_inc, sequence_index, trainpredictmode, model, dataloader, outputdir, outputs, state_report, lossvec, train_loss, regularization_term, loss_plain, train_loss, weights, nrinputvars_network, y, maxabsweight, avgweight, learningrate_value, train_on_output, epochnr, totnrbatches, totnrpoints_trained, writefiles, runtime, mode, printstate, batchsize_nrseq, x) 719 | stats.reset( ) 720 | 721 | end_eval = time.time( ) 722 | tot_time = end_train-start_train + end_eval-start_eval 723 | 724 | #saving: 725 | if ( not nanfound ) and ( mode == "train") : 726 | if totnrbatches % args.save_every_nrbatches == 0 and ( totnrbatches > 0) : 727 | savemodel( saver, sess, dataloader, args, totnrbatches) 728 | 729 | if nanfound and mode == "train": 730 | print( ( "NAN encountered --> stopping.") ) 731 | sys.exit( ) ; 732 | 733 | end_batch = time.time( ) 734 | print ( "End of batch: time_train", end_train-start_train, "time ev", end_eval-start_eval, "tdata", tgetdata, "tot", tot_time, 'batch time', end_batch-start_batch, "sequences/sec", dataloader.nrseq_per_batch/tot_time) 735 | 736 | if mode == "train": 737 | if stats_train_inc.stats_stroke.totnrpoints >= args.incremental_min_nrpoints: 738 | 739 | reached_threshold = False 740 | if args.incremental_seq_length: 741 | print( 'inc seq len: rmse', stats_train_inc.stats_stroke.rmse( ) , 'thr', args.threshold_rmse_stroke) 742 | if ( stats_train_inc.stats_stroke.rmse( ) < args.threshold_rmse_stroke) and ( args.current_seq_length < model.seq_length) : 743 | args.current_seq_length = min( model.seq_length, args.current_seq_length * 2 ) 744 | reached_threshold = True 745 | print( "REACHED THRESHOLD! --> increasing cur_seq_length to ", args.current_seq_length, ' max ', model.seq_length) 746 | 747 | if args.incremental_nr_trainexamples: 748 | print( 'inc nrtrainex: rmse stroke ', stats_train_inc.stats_stroke.rmse( ) , 'thr', args.threshold_rmse_stroke) 749 | if ( stats_train_inc.stats_stroke.rmse( ) < args.threshold_rmse_stroke) and ( args.curnrtrainexamples < args.nrinputfiles_train ) : 750 | args.curnrtrainexamples = min( 2 * args.curnrtrainexamples, args.nrinputfiles_train ) 751 | dataloader_train.curnrexamples = args.curnrtrainexamples 752 | reached_threshold = True 753 | print( "REACHED THRESHOLD! 
--> increasing curnrtrainexamples to ", args.curnrtrainexamples) 754 | dataloader.nrbatches_per_epoch = max( 1, int( args.curnrtrainexamples / dataloader.nrseq_per_batch) ) 755 | dataloader.reset_batch_pointer( args ) 756 | args.incremental_min_nrpoints = 50 * args.curnrtrainexamples 757 | print ( "setting new nrbatches_per_epoch to: ", dataloader.nrbatches_per_epoch) 758 | 759 | if args.incremental_nr_digits: 760 | print( 'inc nr digits: rmse ', stats_train_inc.stats_stroke.rmse( ) , ' thr', args.threshold_rmse_stroke) 761 | if ( stats_train_inc.stats_stroke.rmse( ) < args.threshold_rmse_stroke) and ( args.curnrdigits < 10 ) : 762 | args.curnrdigits = min( 2 * args.curnrdigits, 10 ) 763 | reached_threshold = True 764 | print( "REACHED THRESHOLD! --> increasing curnrdigits to ", args.curnrdigits) 765 | dataloader.findAvailableExamples( args ) 766 | dataloader.nrbatches_per_epoch = max( 1, int( args.curnrtrainexamples / dataloader.nrseq_per_batch) ) 767 | dataloader.reset_batch_pointer( args ) 768 | print ( "setting new nrbatches_per_epoch to: ", dataloader.nrbatches_per_epoch) 769 | 770 | if reached_threshold: #reset rmse counters used for incremental learning 771 | nrseq_inc = np.ceil( args.incremental_min_nrpoints / min( args.current_seq_length, dataloader_train.avgseqlength) ) 772 | stats_train_inc = Stats( args, nrseq_inc, 'stats_train_inc' ) #stats over most recent incremental_min_nrpoints, for incremental methods 773 | 774 | #end of while loop ( batch) : 775 | cont = totnrpoints_trained <= args.maxnrpoints 776 | if ( stats_train.stats_stroke.rmse( ) < args.stopcrit_threshold_stroke_rmse_train ) : 777 | cont = False 778 | 779 | #end of run: 780 | print( 'End of run --> saving model' ) 781 | savemodel( saver, sess, dataloader_train, args, totnrbatches) 782 | print( 'done' ) 783 | 784 | if __name__ == '__main__': 785 | main( ) 786 | 787 | 788 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import numpy as np 5 | import xml.etree.ElementTree as ET 6 | import random 7 | import svgwrite 8 | from IPython.display import SVG, display 9 | import tensorflow as tf 10 | 11 | 12 | def get_bounds( data, factor ): 13 | min_x = 0 14 | max_x = 0 15 | min_y = 0 16 | max_y = 0 17 | 18 | abs_x = 0 19 | abs_y = 0 20 | for i in range( len( data ) ): 21 | x = float( data[ i, 0 ] )/factor 22 | y = float( data[ i, 1 ] )/factor 23 | abs_x += x 24 | abs_y += y 25 | min_x = min( min_x, abs_x ) 26 | min_y = min( min_y, abs_y ) 27 | max_x = max( max_x, abs_x ) 28 | max_y = max( max_y, abs_y ) 29 | 30 | return ( min_x, max_x, min_y, max_y ) 31 | 32 | # version where each path is entire stroke ( smaller svg size, but have to keep same color ) 33 | def draw_strokes( data, factor = 10, svg_filename = 'sample.svg' ): 34 | min_x, max_x, min_y, max_y = get_bounds( data, factor ) 35 | dims = ( 50 + max_x - min_x, 50 + max_y - min_y ) 36 | 37 | dwg = svgwrite.Drawing( svg_filename, size = dims ) 38 | dwg.add( dwg.rect( insert = ( 0, 0 ), size = dims, fill = 'white' ) ) 39 | 40 | lift_pen = 1 41 | 42 | abs_x = 25 - min_x 43 | abs_y = 25 - min_y 44 | p = "M%s, %s " % ( abs_x, abs_y ) 45 | 46 | command = "m" 47 | 48 | for i in range( len( data ) ): 49 | if ( lift_pen == 1 ): 50 | command = "m" 51 | elif ( command != "l" ): 52 | command = "l" 53 | else: 54 | command = "" 55 | x = float( data[ i, 0 ] )/factor 56 | y = float( data[ i, 1 ] )/factor 57 | 
lift_pen = data[ i, 2 ] 58 | p += command+str( x )+", "+str( y )+" " 59 | 60 | the_color = "black" 61 | stroke_width = 1 62 | 63 | dwg.add( dwg.path( p ).stroke( the_color, stroke_width ).fill( "none" ) ) 64 | 65 | dwg.save( ) 66 | display( SVG( dwg.tostring( ) ) ) 67 | 68 | def draw_strokes_eos_weighted( stroke, param, factor = 10, svg_filename = 'sample_eos.svg' ): 69 | c_data_eos = np.zeros( ( len( stroke ), 3 ) ) 70 | for i in range( len( param ) ): 71 | c_data_eos[ i, : ] = ( 1-param[ i ][ 6 ][ 0 ] )*225 # make color gray scale, darker = more likely to eos 72 | draw_strokes_custom_color( stroke, factor = factor, svg_filename = svg_filename, color_data = c_data_eos, stroke_width = 3 ) 73 | 74 | def draw_strokes_random_color( stroke, factor = 10, svg_filename = 'sample_random_color.svg', per_stroke_mode = True ): 75 | c_data = np.array( np.random.rand( len( stroke ), 3 )*240, dtype = np.uint8 ) 76 | if per_stroke_mode: 77 | switch_color = False 78 | for i in range( len( stroke ) ): 79 | if switch_color == False and i > 0: 80 | c_data[ i ] = c_data[ i-1 ] 81 | if stroke[ i, 2 ] < 1: # same stroke 82 | switch_color = False 83 | else: 84 | switch_color = True 85 | draw_strokes_custom_color( stroke, factor = factor, svg_filename = svg_filename, color_data = c_data, stroke_width = 2 ) 86 | 87 | def draw_strokes_custom_color( data, factor = 10, svg_filename = 'test.svg', color_data = None, stroke_width = 1 ): 88 | min_x, max_x, min_y, max_y = get_bounds( data, factor ) 89 | dims = ( 50 + max_x - min_x, 50 + max_y - min_y ) 90 | 91 | dwg = svgwrite.Drawing( svg_filename, size = dims ) 92 | dwg.add( dwg.rect( insert = ( 0, 0 ), size = dims, fill = 'white' ) ) 93 | 94 | lift_pen = 1 95 | abs_x = 25 - min_x 96 | abs_y = 25 - min_y 97 | 98 | for i in range( len( data ) ): 99 | 100 | x = float( data[ i, 0 ] )/factor 101 | y = float( data[ i, 1 ] )/factor 102 | 103 | prev_x = abs_x 104 | prev_y = abs_y 105 | 106 | abs_x += x 107 | abs_y += y 108 | 109 | if ( lift_pen == 1 ): 110 | p = "M "+str( abs_x )+", "+str( abs_y )+" " 111 | else: 112 | p = "M +"+str( prev_x )+", "+str( prev_y )+" L "+str( abs_x )+", "+str( abs_y )+" " 113 | 114 | lift_pen = data[ i, 2 ] 115 | 116 | the_color = "black" 117 | 118 | if ( color_data is not None ): 119 | the_color = "rgb( "+str( int( color_data[ i, 0 ] ) )+", "+str( int( color_data[ i, 1 ] ) )+", "+str( int( color_data[ i, 2 ] ) )+" )" 120 | 121 | dwg.add( dwg.path( p ).stroke( the_color, stroke_width ).fill( the_color ) ) 122 | dwg.save( ) 123 | display( SVG( dwg.tostring( ) ) ) 124 | 125 | def draw_strokes_pdf( data, param, factor = 10, svg_filename = 'sample_pdf.svg' ): 126 | min_x, max_x, min_y, max_y = get_bounds( data, factor ) 127 | dims = ( 50 + max_x - min_x, 50 + max_y - min_y ) 128 | 129 | dwg = svgwrite.Drawing( svg_filename, size = dims ) 130 | dwg.add( dwg.rect( insert = ( 0, 0 ), size = dims, fill = 'white' ) ) 131 | 132 | abs_x = 25 - min_x 133 | abs_y = 25 - min_y 134 | 135 | num_mixture = len( param[ 0 ][ 0 ] ) 136 | 137 | for i in range( len( data ) ): 138 | 139 | x = float( data[ i, 0 ] )/factor 140 | y = float( data[ i, 1 ] )/factor 141 | 142 | for k in range( num_mixture ): 143 | pi = param[ i ][ 0 ][ k ] 144 | if pi > 0.01: # optimisation, ignore pi's less than 1% chance 145 | mu1 = param[ i ][ 1 ][ k ] 146 | mu2 = param[ i ][ 2 ][ k ] 147 | s1 = param[ i ][ 3 ][ k ] 148 | s2 = param[ i ][ 4 ][ k ] 149 | sigma = np.sqrt( s1*s2 ) 150 | dwg.add( dwg.circle( center = ( abs_x+mu1*factor, abs_y+mu2*factor ), r = int( sigma*factor ) ).fill( 
'red', opacity = pi/( sigma*sigma*factor ) ) ) 151 | 152 | prev_x = abs_x 153 | prev_y = abs_y 154 | 155 | abs_x += x 156 | abs_y += y 157 | 158 | 159 | dwg.save( ) 160 | display( SVG( dwg.tostring( ) ) ) 161 | 162 | class DataLoader( ): 163 | 164 | def getRandValue( self ): 165 | value = self.randvalues[ self.randvaluepointer ] 166 | self.randvaluepointer += 1 167 | if ( self.randvaluepointer >= self.nrrandvalues ): 168 | self.randvaluepointer = 0 169 | return value 170 | 171 | def createRandValues( self ): 172 | self.nrrandvalues = 1000 173 | self.randvalues = np.zeros( ( self.nrrandvalues ), dtype = np.float32 ) 174 | for i in range( self.nrrandvalues ): 175 | value = random.random( ) 176 | self.randvalues[ i ] = value 177 | self.randvaluepointer = 0 178 | 179 | def getClassLabels( self ): 180 | if self.train: 181 | fn = self.data_dir + "trainlabels.txt" 182 | else: 183 | fn = self.data_dir + "testlabels.txt" 184 | classlabels = np.loadtxt( fn ) 185 | classlabels = classlabels[ :self.nrinputfiles ] 186 | return classlabels 187 | 188 | def findAvailableExamples( self, args ): 189 | self.availableExamples = [ ] 190 | findexamples = True 191 | if findexamples: 192 | for i in range( len( self.classlabels ) ): 193 | if ( self.classlabels[ i ] < args.curnrdigits ): 194 | self.availableExamples.append( i ) 195 | self.availableExamples = np.array( self.availableExamples ) 196 | 197 | def __init__( self, datadir, args, totnrfiles, curnrexamples, seqlength = 0, train = 1, file_label = "", print_input = 0, rangemin = 0, rangelen = 0 ): 198 | 199 | random.seed( 100*args.runnr ) 200 | np.random.seed( 100*args.runnr ) 201 | tf.set_random_seed( 100*args.runnr ) 202 | self.args = args 203 | 204 | self.data_dir = datadir 205 | self.train = train 206 | if self.train: 207 | self.traintest = "train" 208 | else: 209 | self.traintest = "test" 210 | self.rangemin = rangemin 211 | self.rangelen = rangelen 212 | self.nrinputfiles = totnrfiles 213 | 214 | self.curnrexamples = curnrexamples 215 | self.nrseq_per_batch = args.nrseq_per_batch 216 | self.file_label = file_label 217 | self.print_input = print_input 218 | self.nrinputvars_data = self.getInputVectorLength( args ) 219 | self.max_seq_length = args.max_seq_length 220 | 221 | self.nrsequenceinputs = 4 #dx dy eos eod 222 | self.nrauxinputvars = args.nrClassOutputVars #either [ 0..9 dx dy eos eod ] or [ dx dy eos ] 223 | 224 | strokedatafile = os.path.join( self.data_dir, "strokes_"+self.traintest+"ing_data"+ file_label+args.explabel+ ".cpkl" ) 225 | raw_data_dir = self.data_dir+"/lineStrokes" 226 | 227 | print ( "creating data cpkl file from source data" ) 228 | self.preprocess( args, raw_data_dir, strokedatafile ) 229 | 230 | if ( seqlength > 0 ): #provided 231 | self.seq_length = seqlength 232 | else: 233 | self.seq_length = min( self.max_seq_length, args.maxdigitlength_nrpoints ) 234 | 235 | self.load_preprocessed( args, strokedatafile ) 236 | 237 | self.classlabels = self.getClassLabels( ) 238 | self.findAvailableExamples( args ) 239 | 240 | self.nrbatches_per_epoch = max( 1, int( self.curnrexamples / self.nrseq_per_batch ) ) 241 | print ( "curnrexamples", self.curnrexamples, "seq_length", self.seq_length, " --> nrbatches_per_epoch: ", self.nrbatches_per_epoch ) 242 | 243 | print ( "loaded data" ) 244 | self.reset_batch_pointer( args ) 245 | 246 | def constructInputFileName( self, args, file_label, imgnr ): 247 | filename = self.data_dir + self.traintest + 'img' + file_label + '-' + str( imgnr ) + '-targetdata.txt' #currently, we expect 14 inputs 248 | 
return filename 249 | 250 | def getInputVectorLength( self, args ): 251 | result = [ ] 252 | 253 | filename = self.constructInputFileName( args, self.file_label, imgnr = 0 ) 254 | 255 | with open( filename ) as f: 256 | points = [ ] 257 | line = f.readline( ) 258 | print ( "read sample line from inputdata file: ", line ) 259 | nrs = [ float( x ) for x in line.split( ) ] 260 | length = len( nrs ) 261 | print ( "Determined nrinputvars based on data: ", length ) 262 | self.nrinputvars_data = length 263 | return length 264 | 265 | def preprocess( self, args, data_dir, strokedatafile ): 266 | filelist = [ ] 267 | 268 | if len( args.fileselection )>0: 269 | fileselection = ' '.join( args.fileselection ) 270 | if len( fileselection )>0: 271 | fileselection = [ int( s ) for s in fileselection.split( ', ' ) ] 272 | 273 | for imgnr in range( 0, self.nrinputfiles ): 274 | if len( args.fileselection )>0: 275 | fname = self.constructInputFileName( args, self.file_label, fileselection[ imgnr ] ) 276 | else: 277 | fname = self.constructInputFileName( args, self.file_label, imgnr ) 278 | filelist.append( fname ) 279 | 280 | def getStrokes( filename, nrauxinputvars ): #returns array of arrays with points 281 | result_points = [ ] 282 | result_auxinputs = [ ] 283 | nrsequencevars = 4 284 | dxmin = 1e100 285 | dxmax = -1e100 286 | dymin = 1e100 287 | dymax = -1e100 288 | nrauxinputs_data = 10 289 | 290 | with open( filename ) as f: 291 | points = [ ] 292 | auxinputs = [ ] 293 | for line in f: # read rest of lines 294 | nrs = [ float( x ) for x in line.split( ) ] 295 | auxinputvalues = nrs[ 0:nrauxinputvars ] 296 | point = nrs[ nrauxinputs_data:nrauxinputs_data+nrsequencevars ] #currently: x, y, end-of-stroke 297 | points.append( point ) 298 | auxinputs.append( auxinputvalues ) 299 | result_points.append( points ) 300 | result_auxinputs.append( auxinputs ) 301 | pointarray = np.array( points ) 302 | digitlength_nrpoints = len( points ) 303 | dxmin = pointarray[ :, 0 ].min( ) 304 | dxmax = pointarray[ :, 0 ].max( ) 305 | dymin = pointarray[ :, 1 ].min( ) 306 | dymax = pointarray[ :, 1 ].max( ) 307 | ranges = [ dxmin, dxmax, dymin, dymax ] 308 | return result_auxinputs, result_points, ranges, digitlength_nrpoints 309 | 310 | # converts a list of arrays into a 2d numpy int16 array 311 | def convert_stroke_to_array( stroke ): 312 | 313 | n_point = 0 314 | for i in range( len( stroke ) ): 315 | n_point += len( stroke[ i ] ) 316 | 317 | prev_x = 0 318 | prev_y = 0 319 | counter = 0 320 | nrsequencevars = 4 321 | stroke_data = np.zeros( ( n_point, nrsequencevars ), dtype = np.int16 ) 322 | 323 | for j in range( len( stroke ) ): 324 | for k in range( len( stroke[ j ] ) ): 325 | for s in range( nrsequencevars ): 326 | stroke_data[ counter, s ] = int( stroke[ j ][ k ][ s ] ) 327 | counter += 1 328 | return stroke_data 329 | 330 | # converts a list of arrays into a 2d numpy int16 array 331 | def convert_auxinputs_to_array( auxinputs, nrauxinputvars ): 332 | 333 | n_point = 0 334 | for i in range( len( auxinputs ) ): 335 | n_point += len( auxinputs[ i ] ) 336 | auxinputdata = np.zeros( ( n_point, nrauxinputvars ), dtype = np.int16 ) 337 | 338 | prev_x = 0 339 | prev_y = 0 340 | counter = 0 341 | 342 | for j in range( len( auxinputs ) ): 343 | for k in range( len( auxinputs[ j ] ) ): 344 | for a in range( nrauxinputvars ): 345 | auxinputdata[ counter, a ] = int( auxinputs[ j ][ k ][ a ] ) 346 | counter += 1 347 | return auxinputdata 348 | 349 | # preprocess body: build stroke array 350 | strokearray = [ ] 351 | 
auxinputarray = [ ] 352 | rangelist = [ ] 353 | self.seqlengthlist = [ ] 354 | if self.train: 355 | args.maxdigitlength_nrpoints = 0 356 | digitlengthsum = 0 357 | for i in range( len( filelist ) ): 358 | print ( 'dataloader', self.traintest, 'processing '+filelist[ i ] ) 359 | [ auxinputs, strokeinputs, ranges, digitlength_nrpoints ] = getStrokes( filelist[ i ], self.nrauxinputvars ) 360 | strokearray.append( convert_stroke_to_array( strokeinputs ) ) 361 | auxinputarray.append( convert_auxinputs_to_array( auxinputs, self.nrauxinputvars ) ) 362 | rangelist.append( ranges ) 363 | self.seqlengthlist.append( digitlength_nrpoints ) 364 | if self.train: 365 | args.maxdigitlength_nrpoints = max( args.maxdigitlength_nrpoints, digitlength_nrpoints ) 366 | digitlengthsum += digitlength_nrpoints 367 | 368 | rangearray = np.array( rangelist ) 369 | ranges = [ rangearray[ :, 0 ].min( ), rangearray[ :, 1 ].max( ), rangearray[ :, 2 ].min( ), rangearray[ :, 3 ].max( ) ] 370 | print ( "found overall ranges", ranges ) 371 | self.avgseqlength = digitlengthsum / len( filelist ) 372 | print( "dataloader: found avg seq length: ", self.avgseqlength ) 373 | print ( "found maxdigitlength_nrpoints", args.maxdigitlength_nrpoints ) 374 | 375 | f = open( strokedatafile, "wb" ) 376 | pickle.dump( strokearray, f ) 377 | pickle.dump( auxinputarray, f ) 378 | pickle.dump( ranges, f ) 379 | pickle.dump( self.seqlengthlist, f ) 380 | f.close( ) 381 | 382 | def load_preprocessed( self, args, strokedatafile ): 383 | f = open( strokedatafile, "rb" ) 384 | self.strokedataraw = pickle.load( f ) 385 | self.auxdataraw = pickle.load( f ) 386 | self.ranges = pickle.load( f ) 387 | self.seqlengthlist = pickle.load( f ) 388 | f.close( ) 389 | 390 | print ( "loaded ranges", self.ranges ) 391 | print ( "rangemin", self.rangemin, "rangelen", self.rangelen ) 392 | 393 | self.strokedata = [ ] #contains one array per file 394 | self.auxdata = [ ] 395 | counter = 0 396 | 397 | for data_el in self.strokedataraw: 398 | data = np.array( np.zeros( ( self.seq_length, self.nrsequenceinputs ), dtype = np.float32 ) ) 399 | len_data = len( data ) 400 | nrpoints = min( self.seq_length, len( data_el ) ) 401 | data[ :nrpoints, ] = data_el[ :nrpoints ] 402 | if ( len( data_el ) > self.seq_length ) and ( self.seq_length >= args.max_seq_length ): 403 | data[ self.seq_length-1, 2:4 ] = np.ones( ( 1, 2 ), dtype = np.float32 ) #add eos and eod for sequences exceeding length 404 | data[ nrpoints:, 0:4 ] = np.zeros( ( len_data - nrpoints, 4 ), dtype = np.float32 ) #pad remainder with zero rows 405 | data[ :, 0:2 ] -= self.rangemin 406 | data[ :, 0:2 ] /= self.rangelen 407 | self.strokedata.append( data ) 408 | 409 | counter += 1 410 | for data_el in self.auxdataraw: 411 | data = np.array( np.zeros( ( self.seq_length, self.nrauxinputvars ), dtype = np.float32 ) ) 412 | nrpoints = min( self.seq_length, len( data_el ) ) 413 | data[ :nrpoints, ] = data_el[ :nrpoints ] 414 | data[ nrpoints:self.seq_length, ] = data[ nrpoints-1, ] 415 | self.auxdata.append( data ) 416 | print ( "#sequences found in data: ", counter ) 417 | 418 | def next_batch( self, args, curseqlength ): 419 | 420 | # returns a batch of the training data of nrseq_per_batch * seq_length points 421 | x_batch = [ ] 422 | y_batch = [ ] 423 | seqlen = self.seq_length 424 | sequence_index = [ ] 425 | use_points_stopcrit = False 426 | 427 | nrpoints_per_batch = 0 428 | if hasattr( args, 'nrpoints_per_batch' ): 429 | nrpoints_per_batch = args.nrpoints_per_batch 430 | if nrpoints_per_batch > 0: 431 | 
use_points_stopcrit = True 432 | batch_nrpoints = 0 433 | batch_sequencenr = 0 434 | done = False 435 | while not done: 436 | sequence_index.append( self.pointer ) 437 | strokes = np.copy( self.strokedata[ self.pointer ] ) 438 | auxvalues = np.copy( self.auxdata[ self.pointer ] ) 439 | 440 | if args.useStrokeOutputVars: 441 | ytab = np.copy( np.hstack( [ auxvalues[ 1:seqlen ], strokes[ 1:seqlen ] ] ) ) 442 | else: 443 | ytab = np.copy( np.hstack( [ auxvalues[ 1:seqlen ] ] ) ) 444 | 445 | if args.discard_classvar_inputs: 446 | auxvalues[ : ] = 0 447 | 448 | if args.useClassInputVars: 449 | xtab = np.hstack( [ auxvalues[ :seqlen-1 ], strokes[ :seqlen-1 ] ] ) 450 | else: 451 | xtab = strokes[ :seqlen-1 ] 452 | 453 | actual_seq_length = self.seqlengthlist[ self.pointer ] 454 | 455 | firsttrainstep = 0 456 | if hasattr( args, 'firsttrainstep' ): 457 | firsttrainstep = args.firsttrainstep 458 | firsttrainstep = min ( firsttrainstep, actual_seq_length - 1 ) 459 | if firsttrainstep > 0: #remove earlier part from _target_ data, so that it will not be used in loss. 460 | ytab[ :firsttrainstep, : ] = 0 461 | 462 | #only keep points up to current seq_length - 1; e.g. if sequence has 3 points, use 2 pairs of ( x, y ): k = 1.n-1 for x and k = 2..n for y 463 | firstafter = min( actual_seq_length - 1, curseqlength ) #zero out part after sequence 464 | 465 | xtab[ firstafter:, : ] = 0 466 | ytab[ firstafter:, : ] = 0 467 | 468 | nrusedpoints = firstafter - firsttrainstep 469 | 470 | if args.discard_inputs: 471 | xtab[ : ] = 0 472 | 473 | x_batch.append( np.copy( xtab ) ) 474 | y_batch.append( np.copy( ytab ) ) 475 | 476 | self.next_batch_pointer( args ) 477 | 478 | batch_sequencenr += 1 479 | nrseq_per_batch = self.nrseq_per_batch 480 | if ( not self.train ) and hasattr( args, 'nrseq_per_batch_test' ): 481 | nrseq_per_batch = args.nrseq_per_batch_test 482 | 483 | if use_points_stopcrit: 484 | batch_nrpoints += nrusedpoints 485 | done = batch_nrpoints >= nrpoints_per_batch 486 | else: 487 | done = batch_sequencenr >= nrseq_per_batch 488 | 489 | return x_batch, y_batch, sequence_index 490 | 491 | 492 | def selectExamples( self, nrdigits ): 493 | sample = np.random.permutation( len( self.availableExamples ) ) 494 | return self.availableExamples[ sample ] 495 | 496 | def next_batch_pointer( self, args ): 497 | self.index += 1 498 | if ( self.index >= len( self.example_permutation ) ): 499 | self.reset_batch_pointer( args ) 500 | self.pointer = self.example_permutation[ self.index ] 501 | 502 | def reset_batch_pointer( self, args ): 503 | self.index = 0 504 | 505 | if ( args.incremental_nr_digits and self.train ): 506 | self.example_permutation = self.selectExamples( args.curnrdigits ) 507 | else: 508 | if self.train: 509 | self.example_permutation = np.random.permutation( int( self.curnrexamples ) ) 510 | else: 511 | self.example_permutation = np.arange( 0, int( self.curnrexamples ) ) 512 | self.pointer = self.example_permutation[ self.index ] 513 | 514 | --------------------------------------------------------------------------------