├── .gitignore ├── Config ├── CONFIG ├── Grammar │ ├── mathexp.gram │ ├── sparels.gmm │ └── symbol.types └── SymRec │ ├── OFF.blstm │ ├── ON.blstm │ ├── duration.counts │ ├── off.mav │ ├── on.mav │ └── segmentation.gmm ├── LICENSE ├── Makefile ├── README ├── README.md ├── SampleMathExps ├── exp.inkml └── exp.scgink ├── cellcyk.cc ├── cellcyk.h ├── duration.cc ├── duration.h ├── featureson.cc ├── featureson.h ├── gmm.cc ├── gmm.h ├── gparser.cc ├── gparser.h ├── grammar.cc ├── grammar.h ├── hypothesis.cc ├── hypothesis.h ├── logspace.cc ├── logspace.h ├── meparser.cc ├── meparser.h ├── online.cc ├── online.h ├── production.cc ├── production.h ├── rnnlib4seshat ├── ActivationFunctions.hpp ├── BiasLayer.hpp ├── BlockLayer.hpp ├── ClassificationLayer.cpp ├── ClassificationLayer.hpp ├── CollapseLayer.hpp ├── ConfigFile.hpp ├── Connection.hpp ├── Container.hpp ├── CopyConnection.hpp ├── DataExporter.cpp ├── DataExporter.hpp ├── DataSequence.hpp ├── FullConnection.hpp ├── GatherLayer.hpp ├── Helpers.hpp ├── IdentityLayer.hpp ├── InputLayer.hpp ├── Layer.cpp ├── Layer.hpp ├── Log.hpp ├── LstmLayer.hpp ├── Matrix.hpp ├── Mdrnn.cpp ├── Mdrnn.hpp ├── MultiArray.hpp ├── MultilayerNet.hpp ├── Named.hpp ├── NetcdfDataset.hpp ├── NetworkOutput.hpp ├── NeuronLayer.hpp ├── Optimiser.cpp ├── Optimiser.hpp ├── Random.cpp ├── Random.hpp ├── Rprop.hpp ├── SeqBuffer.hpp ├── SoftmaxLayer.hpp ├── SteepestDescent.hpp ├── String.hpp ├── StringAlignment.hpp ├── Trainer.hpp ├── TranscriptionLayer.hpp ├── WeightContainer.cpp └── WeightContainer.hpp ├── sample.cc ├── sample.h ├── segmentation.cc ├── segmentation.h ├── seshat.cc ├── sparel.cc ├── sparel.h ├── stroke.cc ├── stroke.h ├── symfeatures.cc ├── symfeatures.h ├── symrec.cc ├── symrec.h ├── tablecyk.cc └── tablecyk.h /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Compiled Dynamic libraries 8 | *.so 9 | *.dylib 10 | 
*.dll 11 | 12 | # Compiled Static libraries 13 | *.lai 14 | *.la 15 | *.a 16 | *.lib 17 | 18 | # Executables 19 | *.exe 20 | *.out 21 | *.app 22 | -------------------------------------------------------------------------------- /Config/CONFIG: -------------------------------------------------------------------------------- 1 | GRAMMAR Config/Grammar/mathexp.gram 2 | SymbolTypes Config/Grammar/symbol.types 3 | SpatialRels Config/Grammar/sparels.gmm 4 | Duration Config/SymRec/duration.counts 5 | Segmentation Config/SymRec/segmentation.gmm 6 | RNNon Config/SymRec/ON.blstm 7 | RNNmavON Config/SymRec/on.mav 8 | RNNoff Config/SymRec/OFF.blstm 9 | RNNmavOFF Config/SymRec/off.mav 10 | RNNalpha 0.50 11 | MaxStrokes 4 12 | SegmentsTH 0.69474973 13 | InsPenalty 2.11917745 14 | SymbolSF 1.88593577 15 | DurationSF 0.93889491 16 | SegmentationSF 2.77746769 17 | RelationSF 0.63225349 18 | ProductionTSF 0.37846575 19 | ProductionBSF 0.14864657 20 | ClusterF 0.15680604 21 | -------------------------------------------------------------------------------- /Config/Grammar/sparels.gmm: -------------------------------------------------------------------------------- 1 | 6 9 5 2 | 0.718383234419 0.055574466826 0.078203856912 0.120754214335 0.025145411221 0.001938816286 3 | 0.0400465 0.0462119 0.158954 0.119728 0.338147 0.650734 0.0257464 0.0207452 0.0500226 4 | 6.38977e-29 0.0306662 3.11926 0.0920412 0.209213 12.3927 0.025514 0.035705 0.0255653 5 | 0.00471493 0.0423204 0.760158 0.10667 0.259005 2.78825 3.65786e-05 0.0447119 0.00472035 6 | 0.0153037 0.0569341 2.31413 1.11035 1.47678 10.5612 0.0245617 0.0298914 0.0315428 7 | 0.0650266 0.0394849 0.0599804 0.0243055 0.107061 0.13888 0.0461497 0.0425307 0.0662506 8 | 0.754189 -0.0702447 -1.61337 0.433981 1.18366 2.04308 -0.780692 0.133713 -0.0339323 9 | 1 0.0211677 -2.4223 0.3209 0.755544 4.08905 -0.759212 -0.293398 0.23767 10 | 0.911813 0.117074 -1.73303 0.395061 0.981709 2.48436 -0.996298 -0.248554 -0.0877496 11 | 0.875886 -0.0548455 
-3.70841 0.906565 1.66987 5.74695 -0.838009 0.0188168 0.0312145 12 | 0.617975 -0.0592687 -0.753022 0.160156 0.661045 0.845 -0.70423 0.213682 -0.0935981 13 | 0.114239 0.599775 0.1071 0.0580644 0.120821 14 | 0.00828608 0.015773 0.0118063 0.0110558 0.041533 0.0242046 0.0172115 0.00612675 0.0252167 15 | 0.0365525 0.0351008 0.205108 0.0887008 0.274506 0.462619 0.0513276 0.0421679 0.061057 16 | 0.0351226 0.0342489 0.0144722 0.0132679 0.0315408 0.0226689 0.0459325 0.016168 0.105993 17 | 0.0137349 0.0104148 0.03013 0.0261673 0.0972631 0.0898623 0.0251939 0.0126444 0.0289874 18 | 0.0152452 0.0199651 0.0132074 0.0233409 0.030481 0.0229981 0.0145562 0.0144918 0.0224842 19 | 0.363074 -0.372596 -0.622582 0.137781 0.79348 0.451685 -0.2065 0.619659 0.150917 20 | 0.764614 -0.193784 -1.16842 0.283389 1.14608 1.19075 -0.617145 0.161704 0.144741 21 | 0.530155 -0.27583 -0.540513 0.0911176 0.632953 0.448073 -0.43841 0.366114 0.0877896 22 | 0.515078 -0.366387 -0.940781 0.232304 1.06029 0.821276 -0.279965 0.476539 0.230442 23 | 0.314098 -0.410848 -0.302712 -0.0575113 0.411066 0.194358 -0.130819 0.677469 0.177394 24 | 0.288838 0.127026 0.191728 0.216489 0.17592 25 | 0.00416017 0.0144403 0.00849072 0.0100371 0.0209229 0.0139464 7.40093e-05 0.0256097 0.0041528 26 | 0.00692202 0.00897028 0.0481703 0.0250189 0.209415 0.0869749 2.14762e-05 0.0241996 0.00692406 27 | 0.00644034 0.0132445 0.0438229 0.0221977 0.0655353 0.114342 6.61175e-05 0.0224242 0.0064378 28 | 0.0555328 0.042701 0.311197 0.0605234 0.330936 1.00043 0.0092768 0.0875441 0.057176 29 | 0.00695217 0.0206648 0.00817966 0.0136203 0.0182913 0.0136939 3.53137e-05 0.0237944 0.00693021 30 | 0.394356 0.509785 -0.554344 0.124304 0.635004 0.473684 -0.992665 -0.344988 -0.605484 31 | 0.327123 0.479743 -0.851251 0.170694 1.12427 0.57823 -0.996442 -0.268476 -0.672745 32 | 0.561281 0.447368 -0.706288 0.142201 0.697138 0.715438 -0.994889 -0.448858 -0.438695 33 | 0.530929 0.400142 -0.993652 0.144152 0.871201 1.1161 -0.947414 -0.21008 -0.423558 34 | 
0.29066 0.483209 -0.331066 0.0143939 0.40378 0.258352 -0.994437 -0.255033 -0.709083 35 | 0.321126 0.149946 0.197312 0.0531342 0.278482 36 | 0.00877008 0.00682758 0.0945141 1.81377 0.799243 1.09448 0.0110191 0.00877008 0.00505034 37 | 0.00416901 0.00432588 0.00555767 0.310816 0.0172927 0.0265241 0.00382397 0.00416901 0.00680665 38 | 0.0239008 0.0174781 0.0628267 1.17046 0.24055 0.297461 0.0260537 0.0366781 0.0342021 39 | 0.00534745 0.0172469 0.015744 0.979295 0.0649992 0.0831585 0.00708846 0.00535608 0.00272675 40 | 0.00746677 0.0117582 0.00723016 0.0874107 0.0148005 0.0277919 0.00458986 0.00746677 0.0132412 41 | 0.579844 -0.662653 0.054634 -1.93479 0.759049 -0.868317 0.308506 0.420156 0.881954 42 | 0.520092 -0.310529 -0.0247561 -0.757472 -0.143297 0.192809 0.0996578 0.479908 0.616369 43 | 0.626503 -0.277981 -0.130995 -1.21765 -0.296303 0.558293 0.0276018 0.360311 0.650138 44 | 0.771504 -0.525694 -0.0108321 -1.29866 0.151349 -0.129685 0.138743 0.228479 0.90407 45 | 0.238862 -0.48239 -0.0176388 -0.594242 0.00578284 0.0294948 0.0981891 0.761138 0.332962 46 | 0.17711 0.323635 0.110824 0.287623 0.100808 47 | 0.00456953 0.014898 0.0179596 0.868124 0.0261201 0.0543883 0.00179963 0.00124693 0.00654775 48 | 0.0129368 0.00724917 0.0143949 0.198237 0.0327926 0.043634 0.00889696 0.00600127 0.021581 49 | 0.00308949 0.0103752 0.011871 0.832984 0.0197477 0.035784 0.0206501 0.00362518 0.0161138 50 | 0.030959 0.0298858 0.135843 2.08635 0.0898216 0.389771 0.0194944 0.0146531 0.0482674 51 | 0.00475755 0.0169366 0.024385 1.38767 0.0368425 0.0799561 0.000944127 0.00087138 0.0047204 52 | 0.632769 -0.108718 -0.213422 -1.52021 0.59879 -0.171946 -0.678616 0.304422 -0.0491304 53 | 0.510484 -0.187248 -0.191699 -0.883437 0.603435 -0.220038 -0.518879 0.421686 -0.0126336 54 | 0.812849 -0.0800241 -0.155431 -1.75566 0.429507 -0.118646 -0.70112 0.175381 0.108576 55 | 0.632944 -0.0152174 -0.0677986 -2.92811 0.635803 -0.500205 -0.709891 0.241568 -0.0821234 56 | 0.70463 -0.0262242 -0.218145 -2.13502 
0.576373 -0.140083 -0.793266 0.20121 -0.0923016 57 | 0.24826 0.266405 0.243034 0.0933728 0.148928 58 | 0.00786919 0.0102346 0.0270097 0.10339 0.00603527 0.0910677 0.00591163 0.0146196 0.012015 59 | 0.00733458 0.00576684 0.0132222 0.053173 0.00452787 0.0446252 0.00949105 0.0099381 0.00616056 60 | 0.00332344 0.00334341 0.0204401 0.0517023 0.0188001 0.0550334 0.00337059 0.0048259 0.0030544 61 | 0.0018621 0.0066091 0.0110949 0.0415123 0.00531947 0.0404361 0.00601645 0.00610832 0.00610636 62 | 0.0151138 0.00159216 0.00885645 0.115227 0.0206761 0.0203693 0.00292855 0.00346035 0.00906875 63 | 0.305245 0.3179 1.06157 -2.31918 -0.0453983 -2.07774 -0.941775 0.00133904 -0.646647 64 | 0.234649 0.093974 0.798506 -1.76711 -0.0250409 -1.57197 -0.707962 0.283856 -0.481494 65 | 0.262252 0.281145 0.682458 -1.56305 -0.0506481 -1.31427 -0.90568 0.0789939 -0.653442 66 | 0.194612 0.100877 0.472739 -1.16295 0.0465906 -0.992069 -0.703561 0.288808 -0.51658 67 | 0.58627 -0.0425914 -0.167828 -0.875758 0.552609 -0.216952 -0.736717 0.252042 -0.156716 68 | 0.13359 0.284972 0.270771 0.226698 0.0839695 69 | -------------------------------------------------------------------------------- /Config/Grammar/symbol.types: -------------------------------------------------------------------------------- 1 | 102 2 | | n 3 | - n 4 | ! 
n 5 | + n 6 | 0 m 7 | 1 m 8 | 2 m 9 | 3 m 10 | 4 m 11 | 5 m 12 | 6 m 13 | 7 m 14 | 8 m 15 | 9 m 16 | a n 17 | A m 18 | \alpha n 19 | b a 20 | B m 21 | \beta d 22 | c n 23 | C m 24 | comma n 25 | \cos n 26 | d a 27 | \Delta n 28 | \div n 29 | dbar n 30 | dot n 31 | dots n 32 | e n 33 | E m 34 | equal n 35 | \exists n 36 | f a 37 | F m 38 | \forall n 39 | g d 40 | G m 41 | \gamma n 42 | \geq n 43 | \gt n 44 | h a 45 | H m 46 | i n 47 | I m 48 | \in n 49 | \infty n 50 | \int n 51 | j d 52 | k n 53 | l a 54 | L m 55 | \lambda a 56 | \leq n 57 | \lim a 58 | \log n 59 | lbrace n 60 | lbracket n 61 | lpar m 62 | \lt n 63 | m n 64 | M m 65 | \mu d 66 | n n 67 | N m 68 | \neq n 69 | o n 70 | p d 71 | P m 72 | \phi n 73 | \pi n 74 | \pm n 75 | \prime n 76 | q d 77 | r n 78 | R m 79 | \rightarrow n 80 | rbrace n 81 | rbracket n 82 | rpar m 83 | s n 84 | S m 85 | \sigma n 86 | \sin n 87 | \sqrt m 88 | \sum n 89 | t a 90 | T m 91 | \tan a 92 | \tg n 93 | \theta n 94 | \times n 95 | u n 96 | v n 97 | V m 98 | w n 99 | x n 100 | X m 101 | y d 102 | Y m 103 | z n 104 | -------------------------------------------------------------------------------- /Config/SymRec/off.mav: -------------------------------------------------------------------------------- 1 | 7.02993124128 3.52505419111 2.40995287827 15.6933348012 25.9197746971 0.23328943032 -0.258304657176 2.39074069548 5.09417326893 2 | 6.33958932556 3.52067357357 2.51277605056 13.8040159364 14.0999303426 4.40930100559 4.41952702215 1.2427974007 6.27785917011 3 | -------------------------------------------------------------------------------- /Config/SymRec/on.mav: -------------------------------------------------------------------------------- 1 | 172.117024448 50.2455044567 0.210231937573 0.22706088601 0.00253827353868 0.00101277453039 -0.00624285481303 2 | 992.792362035 34.5984455584 0.652141168602 0.691247480093 0.144491332862 0.154175388881 0.163523232241 3 | 
-------------------------------------------------------------------------------- /Config/SymRec/segmentation.gmm: -------------------------------------------------------------------------------- 1 | 2 4 5 2 | 0.50 0.50 3 | 0.033161 0.0786581 0.392137 0.137392 4 | 0.0365729 0.0731591 0.0871862 0.0876254 5 | 0.0222252 0.0493807 0.0393027 0.0134685 6 | 0.0554554 0.298172 1.73159 0.074029 7 | 0.243548 1.54521 13.7089 0.397415 8 | 0.488604 0.598039 0.981182 0.622963 9 | 0.837157 1.24887 0.460718 0.396426 10 | 0.422674 0.748496 0.303515 0.159743 11 | 0.597509 1.12518 2.82888 0.455837 12 | 1.02212 1.85958 4.31523 0.908786 13 | 0.281137 0.220711 0.25993 0.152835 0.0853858 14 | 0.00040065 0.00164303 0.0186928 0.0017177 15 | 0.00257013 0.0449728 0.00873153 0.0022907 16 | 0.0761435 0.233879 0.436473 0.162027 17 | 0.016679 0.0112306 0.0737718 0.0313773 18 | 1.20403e-06 4.51055 16.0136 0.568196 19 | 0.0306131 0.0551428 0.15828 0.052466 20 | 0.0653452 0.301276 0.12245 0.0623376 21 | 0.395353 0.589039 0.72823 0.509573 22 | 0.157565 0.134665 0.321965 0.303616 23 | 4.90522e-05 1.13667 3.36801 0.723376 24 | 0.234553 0.17369 0.18201 0.395828 0.0139184 25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CC=g++ 2 | LINK=-lxerces-c -lm 3 | FLAGS = -O3 -Wno-unused-result #-I/path/to/boost/ 4 | 5 | OBJFEAS=symfeatures.o featureson.o online.o 6 | OBJMUESTRA=sample.o stroke.o 7 | OBJPARSE=seshat.o meparser.o gparser.o grammar.o production.o symrec.o duration.o segmentation.o sparel.o gmm.o 8 | OBJTABLA=tablecyk.o cellcyk.o hypothesis.o logspace.o 9 | OBJRNNLIB=Random.o DataExporter.o WeightContainer.o ClassificationLayer.o Layer.o Mdrnn.o Optimiser.o 10 | RNNLIBHEADERS=rnnlib4seshat/DataSequence.hpp rnnlib4seshat/NetcdfDataset.hpp rnnlib4seshat/Mdrnn.hpp rnnlib4seshat/MultilayerNet.hpp rnnlib4seshat/Rprop.hpp rnnlib4seshat/SteepestDescent.hpp 
rnnlib4seshat/Trainer.hpp rnnlib4seshat/WeightContainer.hpp 11 | OBJS=$(OBJFEAS) $(OBJMUESTRA) $(OBJPARSE) $(OBJTABLA) $(OBJRNNLIB) 12 | 13 | seshat: $(OBJS) 14 | $(CC) -o seshat $(OBJS) $(FLAGS) $(LINK) 15 | 16 | seshat.o: seshat.cc grammar.o sample.o meparser.o 17 | $(CC) -c seshat.cc $(FLAGS) 18 | 19 | production.o: production.h production.cc symrec.o 20 | $(CC) -c production.cc $(FLAGS) 21 | 22 | grammar.o: grammar.h grammar.cc production.o gparser.o symrec.o 23 | $(CC) -c grammar.cc $(FLAGS) 24 | 25 | meparser.o: meparser.h meparser.cc grammar.o production.o symrec.o tablecyk.o cellcyk.o logspace.o duration.o segmentation.o sparel.o sample.o hypothesis.o 26 | $(CC) -c meparser.cc $(FLAGS) 27 | 28 | gparser.o: gparser.h gparser.cc 29 | $(CC) -c gparser.cc $(FLAGS) 30 | 31 | sample.o: sample.h sample.cc tablecyk.o cellcyk.o stroke.o grammar.o 32 | $(CC) -c sample.cc $(FLAGS) 33 | 34 | symrec.o: symrec.h symrec.cc symfeatures.o $(RNNLIBHEADERS) 35 | $(CC) -c symrec.cc $(FLAGS) 36 | 37 | duration.o: duration.h duration.cc symrec.o 38 | $(CC) -c duration.cc $(FLAGS) 39 | 40 | segmentation.o: segmentation.h segmentation.cc cellcyk.o sample.o gmm.o 41 | $(CC) -c segmentation.cc $(FLAGS) 42 | 43 | tablecyk.o: tablecyk.h tablecyk.cc cellcyk.o hypothesis.o 44 | $(CC) -c tablecyk.cc $(FLAGS) 45 | 46 | cellcyk.o: cellcyk.h cellcyk.cc hypothesis.o 47 | $(CC) -c cellcyk.cc $(FLAGS) 48 | 49 | hypothesis.o: hypothesis.h hypothesis.cc production.o grammar.o 50 | $(CC) -c hypothesis.cc $(FLAGS) 51 | 52 | logspace.o: logspace.h logspace.cc cellcyk.o 53 | $(CC) -c logspace.cc $(FLAGS) 54 | 55 | sparel.o: sparel.h sparel.cc hypothesis.o cellcyk.o gmm.o sample.o 56 | $(CC) -c sparel.cc $(FLAGS) 57 | 58 | gmm.o: gmm.cc gmm.h 59 | $(CC) -c gmm.cc $(FLAGS) 60 | 61 | stroke.o: stroke.cc stroke.h 62 | $(CC) -c stroke.cc $(FLAGS) 63 | 64 | symfeatures.o: symfeatures.cc online.o featureson.o 65 | $(CC) -c symfeatures.cc $(FLAGS) 66 | 67 | featureson.o: featureson.cc featureson.h online.o 
68 | $(CC) -c featureson.cc $(FLAGS) 69 | 70 | online.o: online.cc online.h 71 | $(CC) -c online.cc $(FLAGS) 72 | 73 | #rnnlib4seshat 74 | Random.o: rnnlib4seshat/Random.cpp 75 | $(CC) -c rnnlib4seshat/Random.cpp $(FLAGS) 76 | 77 | DataExporter.o: rnnlib4seshat/DataExporter.cpp 78 | $(CC) -c rnnlib4seshat/DataExporter.cpp $(FLAGS) 79 | 80 | WeightContainer.o: rnnlib4seshat/WeightContainer.cpp 81 | $(CC) -c rnnlib4seshat/WeightContainer.cpp $(FLAGS) 82 | 83 | ClassificationLayer.o: rnnlib4seshat/ClassificationLayer.cpp 84 | $(CC) -c rnnlib4seshat/ClassificationLayer.cpp $(FLAGS) 85 | 86 | Layer.o: rnnlib4seshat/Layer.cpp 87 | $(CC) -c rnnlib4seshat/Layer.cpp $(FLAGS) 88 | 89 | Mdrnn.o: rnnlib4seshat/Mdrnn.cpp 90 | $(CC) -c rnnlib4seshat/Mdrnn.cpp $(FLAGS) 91 | 92 | Optimiser.o: rnnlib4seshat/Optimiser.cpp 93 | $(CC) -c rnnlib4seshat/Optimiser.cpp $(FLAGS) 94 | 95 | clean: 96 | rm -f *.o *~ \#*\# 97 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------ 2 | SESHAT - Handwritten math expression parser 3 | Copyright (C) 2014, Francisco Alvaro 4 | ------------------------------------------------------------------------ 5 | 6 | More information at https://github.com/falvaro/seshat 7 | 8 | ---------------- 9 | License: Seshat is released under the license GPLv3 (see LICENSE file) 10 | 11 | ---------------- 12 | Requirements: 13 | - Compilation tools: makefile and g++ 14 | - Xerces-c library for parsing XML. 15 | - Boost libraries for RNNLIB (must be in include path or it should be added 16 | in the FLAGS variable of Makefile file). 
 17 | 18 | ---------------- 19 | Usage: ./seshat -c config -i input [-o output] [-r render.pgm] [-d graph.dot] 20 | 21 | -c config: set the configuration file 22 | -i input: set the input math expression file 23 | -o output: save the recognized expression to the 'output' file (InkML format) 24 | -r render: save in 'render' the image representing the input expression (PGM format) 25 | -d graph: save in 'graph' the description of the recognized tree (DOT format) 26 | 27 | ---------------- 28 | Example: 29 | There are two example math expressions in the folder "SampleMathExps". The following command 30 | will recognize the expression (x+y)^2 encoded in "exp.scgink": 31 | 32 | $ ./seshat -c Config/CONFIG -i SampleMathExps/exp.scgink -o out.inkml -r render.pgm -d out.dot 33 | 34 | This command prints several pieces of information to the standard output; the last line 35 | is the LaTeX string of the recognized math expression. 36 | 37 | An image representation of the input strokes will be rendered in "render.pgm". 38 | 39 | The InkML file of the recognized math expression will be saved in "out.inkml". 40 | 41 | The derivation tree of the expression, represented as a graph in DOT format, will be saved in "out.dot". 42 | A rendering of the graph in, for example, PostScript format can be obtained as follows: 43 | 44 | $ dot -o out.ps out.dot -Tps 45 | 46 | Note that only the options "-c" and "-i" are mandatory; "-o", "-r" and "-d" are optional. 
47 | -------------------------------------------------------------------------------- /SampleMathExps/exp.scgink: -------------------------------------------------------------------------------- 1 | SCG_INK 2 | 8 3 | 27 4 | 78 54 5 | 80 55 6 | 80 57 7 | 82 58 8 | 83 60 9 | 85 63 10 | 86 65 11 | 89 68 12 | 91 70 13 | 96 75 14 | 100 78 15 | 104 82 16 | 108 86 17 | 112 90 18 | 113 94 19 | 117 98 20 | 121 101 21 | 125 105 22 | 128 109 23 | 132 111 24 | 136 113 25 | 140 115 26 | 144 119 27 | 145 121 28 | 147 122 29 | 148 122 30 | 149 122 31 | 24 32 | 136 64 33 | 134 66 34 | 133 68 35 | 131 71 36 | 130 73 37 | 128 77 38 | 124 82 39 | 120 87 40 | 117 93 41 | 113 98 42 | 109 104 43 | 105 108 44 | 103 112 45 | 102 115 46 | 99 117 47 | 97 119 48 | 96 121 49 | 94 123 50 | 93 124 51 | 92 124 52 | 92 125 53 | 92 126 54 | 91 126 55 | 91 127 56 | 20 57 | 188 87 58 | 190 87 59 | 192 88 60 | 194 88 61 | 196 88 62 | 199 88 63 | 202 88 64 | 205 88 65 | 210 89 66 | 214 89 67 | 219 89 68 | 226 89 69 | 233 89 70 | 238 89 71 | 244 89 72 | 249 89 73 | 253 89 74 | 254 89 75 | 255 89 76 | 256 89 77 | 85 78 | 297 65 79 | 297 66 80 | 298 68 81 | 298 71 82 | 299 73 83 | 301 75 84 | 302 79 85 | 304 85 86 | 306 90 87 | 309 95 88 | 311 101 89 | 313 106 90 | 317 112 91 | 321 113 92 | 325 115 93 | 329 116 94 | 331 116 95 | 334 116 96 | 336 116 97 | 338 116 98 | 340 116 99 | 342 115 100 | 344 114 101 | 345 112 102 | 346 110 103 | 347 109 104 | 347 107 105 | 347 105 106 | 347 103 107 | 347 102 108 | 346 100 109 | 345 99 110 | 345 98 111 | 344 97 112 | 343 96 113 | 342 95 114 | 341 95 115 | 341 97 116 | 341 99 117 | 341 102 118 | 341 106 119 | 341 111 120 | 341 116 121 | 341 121 122 | 341 126 123 | 341 131 124 | 341 137 125 | 341 141 126 | 341 143 127 | 340 145 128 | 340 146 129 | 340 148 130 | 340 150 131 | 340 151 132 | 340 153 133 | 340 154 134 | 339 155 135 | 339 156 136 | 338 157 137 | 337 157 138 | 336 157 139 | 334 155 140 | 332 154 141 | 332 152 142 | 331 150 143 | 330 147 144 | 329 145 145 | 
329 142 146 | 329 140 147 | 329 138 148 | 330 136 149 | 330 135 150 | 331 133 151 | 333 132 152 | 335 132 153 | 336 131 154 | 338 130 155 | 340 130 156 | 342 130 157 | 344 130 158 | 345 130 159 | 347 130 160 | 348 130 161 | 349 130 162 | 349 131 163 | 44 164 | 69 45 165 | 68 46 166 | 67 48 167 | 66 49 168 | 64 51 169 | 63 53 170 | 62 55 171 | 61 57 172 | 59 60 173 | 58 62 174 | 58 64 175 | 56 67 176 | 55 68 177 | 55 70 178 | 54 72 179 | 54 75 180 | 54 77 181 | 53 79 182 | 53 81 183 | 53 84 184 | 53 86 185 | 53 89 186 | 53 92 187 | 53 94 188 | 54 97 189 | 55 98 190 | 56 100 191 | 56 102 192 | 56 103 193 | 58 105 194 | 58 107 195 | 59 108 196 | 59 110 197 | 60 111 198 | 61 113 199 | 63 114 200 | 64 116 201 | 65 117 202 | 65 120 203 | 66 121 204 | 67 122 205 | 67 123 206 | 68 123 207 | 69 124 208 | 52 209 | 361 45 210 | 361 46 211 | 362 47 212 | 363 48 213 | 364 49 214 | 365 51 215 | 366 53 216 | 368 55 217 | 369 57 218 | 370 59 219 | 372 60 220 | 373 63 221 | 374 66 222 | 375 68 223 | 375 70 224 | 377 73 225 | 378 75 226 | 379 77 227 | 379 80 228 | 379 82 229 | 379 84 230 | 379 86 231 | 379 88 232 | 379 89 233 | 379 91 234 | 379 93 235 | 379 95 236 | 379 97 237 | 379 99 238 | 379 101 239 | 379 103 240 | 378 104 241 | 377 106 242 | 375 108 243 | 375 109 244 | 374 110 245 | 373 111 246 | 372 113 247 | 372 115 248 | 371 117 249 | 370 119 250 | 369 120 251 | 368 121 252 | 368 122 253 | 367 122 254 | 366 123 255 | 365 123 256 | 364 123 257 | 364 124 258 | 363 124 259 | 362 124 260 | 362 125 261 | 69 262 | 395 19 263 | 396 19 264 | 397 19 265 | 398 18 266 | 399 17 267 | 401 17 268 | 402 17 269 | 404 17 270 | 406 17 271 | 407 17 272 | 410 17 273 | 412 17 274 | 414 17 275 | 416 18 276 | 418 18 277 | 419 18 278 | 420 19 279 | 421 20 280 | 421 21 281 | 421 22 282 | 421 23 283 | 421 24 284 | 420 26 285 | 419 27 286 | 419 29 287 | 417 31 288 | 416 32 289 | 415 32 290 | 413 33 291 | 412 33 292 | 411 34 293 | 409 34 294 | 408 34 295 | 407 34 296 | 405 34 297 | 404 34 298 | 403 34 
299 | 402 34 300 | 401 34 301 | 400 34 302 | 400 33 303 | 399 33 304 | 399 32 305 | 399 31 306 | 400 31 307 | 401 31 308 | 402 31 309 | 403 32 310 | 405 33 311 | 406 34 312 | 408 34 313 | 409 35 314 | 410 35 315 | 411 36 316 | 413 36 317 | 415 37 318 | 417 37 319 | 419 38 320 | 420 38 321 | 421 38 322 | 423 39 323 | 424 39 324 | 425 39 325 | 426 39 326 | 427 39 327 | 428 39 328 | 429 39 329 | 430 39 330 | 431 39 331 | 36 332 | 221 52 333 | 222 52 334 | 222 53 335 | 222 54 336 | 222 57 337 | 223 59 338 | 223 61 339 | 223 63 340 | 223 66 341 | 223 68 342 | 223 71 343 | 223 73 344 | 223 76 345 | 223 79 346 | 223 83 347 | 223 86 348 | 223 89 349 | 223 92 350 | 223 96 351 | 223 99 352 | 223 101 353 | 223 103 354 | 223 104 355 | 223 106 356 | 223 107 357 | 223 108 358 | 223 110 359 | 223 111 360 | 223 112 361 | 223 113 362 | 223 114 363 | 223 115 364 | 223 116 365 | 223 117 366 | 224 118 367 | 225 118 368 | -------------------------------------------------------------------------------- /cellcyk.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #include "cellcyk.h" 19 | #include 20 | 21 | CellCYK::CellCYK(int n, int ncc) { 22 | sig = NULL; 23 | nnt = n; 24 | nc = ncc; 25 | talla = 0; 26 | 27 | //Create (empty) hypotheses 28 | noterm = new Hypothesis*[nnt]; 29 | for(int i=0; iccc[i] || B->ccc[i] ) ? true : false; 70 | } 71 | 72 | //Check if cell H covers the same strokes that this 73 | bool CellCYK::ccEqual(CellCYK *H) { 74 | if( talla != H->talla ) 75 | return false; 76 | 77 | for(int i=0; iccc[i] ) 79 | return false; 80 | 81 | return true; 82 | } 83 | 84 | //Check if the intersection between the strokes of this cell and H is empty 85 | bool CellCYK::compatible(CellCYK *H) { 86 | for(int i=0; iccc[i] ) 88 | return false; 89 | 90 | return true; 91 | } 92 | 93 | -------------------------------------------------------------------------------- /cellcyk.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #ifndef _CELLCYK_ 19 | #define _CELLCYK_ 20 | 21 | struct Hypothesis; 22 | 23 | #include 24 | #include "hypothesis.h" 25 | 26 | using namespace std; 27 | 28 | struct CellCYK{ 29 | //Bounding box spatial region coordinates 30 | int x,y; //top-left 31 | int s,t; //bottom-right 32 | 33 | //Hypotheses for every non-terminals 34 | int nnt; 35 | Hypothesis **noterm; 36 | 37 | //Strokes covered in this cell 38 | int nc; 39 | bool *ccc; 40 | int talla; //total number of strokes 41 | 42 | //Next cell in linked list (CYK table of same size) 43 | CellCYK *sig; 44 | 45 | 46 | //Methods 47 | CellCYK(int n, int ncc); 48 | ~CellCYK(); 49 | 50 | bool operator<(const CellCYK &C); 51 | void ccUnion(CellCYK *A, CellCYK *B); 52 | bool ccEqual(CellCYK *H); 53 | bool compatible(CellCYK *H); 54 | }; 55 | 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /duration.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #include "duration.h" 19 | 20 | 21 | DurationModel::DurationModel(char *str, int mxs, SymRec *sr) { 22 | FILE *fd = fopen(str,"r"); 23 | if( !fd ) { 24 | fprintf(stderr, "Error loading duration model '%s'\n", str); 25 | exit(-1); 26 | } 27 | 28 | max_strokes = mxs; 29 | Nsyms = sr->getNClases(); 30 | 31 | duration_prob = new float*[Nsyms]; 32 | for(int i=0; ikeyClase(str) ][nums-1] = count; 51 | } 52 | 53 | //Compute probabilities 54 | for(int i=0; i. 17 | */ 18 | #ifndef _DURATION_MODEL_ 19 | #define _DURATION_MODEL_ 20 | 21 | #include 22 | #include 23 | #include "symrec.h" 24 | 25 | using namespace std; 26 | 27 | class DurationModel{ 28 | int max_strokes; 29 | int Nsyms; 30 | float **duration_prob; 31 | 32 | void loadModel(FILE *fd, SymRec *sr); 33 | 34 | public: 35 | DurationModel(char *str, int mxs, SymRec *sr); 36 | ~DurationModel(); 37 | 38 | float prob(int symclas, int size); 39 | }; 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /featureson.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | 18 | This file is a modification of the online features original software 19 | covered by the following copyright and permission notice: 20 | 21 | */ 22 | /* 23 | Copyright (C) 2006,2007 Moisés Pastor 24 | 25 | This program is free software: you can redistribute it and/or modify 26 | it under the terms of the GNU General Public License as published by 27 | the Free Software Foundation, either version 3 of the License, or 28 | (at your option) any later version. 29 | 30 | This program is distributed in the hope that it will be useful, 31 | but WITHOUT ANY WARRANTY; without even the implied warranty of 32 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 33 | GNU General Public License for more details. 34 | 35 | You should have received a copy of the GNU General Public License 36 | along with this program. If not, see . 37 | */ 38 | #ifndef FEATURES_H 39 | #define FEATURES_H 40 | 41 | #include 42 | #include 43 | #include 44 | #include 45 | #include 46 | #include "online.h" 47 | 48 | #define MAXNUMHATS 200 49 | #define OFFSET_INS 20 50 | 51 | using namespace std; 52 | 53 | class frame { 54 | public: 55 | double x,y,dx,dy,ax,ay,k; 56 | 57 | void print(ostream & fd); 58 | int get_fr_dim(); 59 | 60 | double getFea(int i){ 61 | switch (i){ 62 | case 0: return x; 63 | case 1: return y; 64 | case 2: return dx; 65 | case 3: return dy; 66 | case 4: return ax; 67 | case 5: return ay; 68 | case 6: return k; 69 | default: 70 | fprintf(stderr, "Error: getFea(%d)\n", i); 71 | exit(-1); 72 | } 73 | } 74 | }; 75 | 76 | class sentenceF { 77 | public: 78 | string transcrip; 79 | int n_frames; 80 | frame * frames; 81 | 82 | sentenceF(); 83 | ~sentenceF(); 84 | 85 | bool data_plot(ostream & fd); 86 | bool print(ostream & fd); 87 | 88 | void calculate_features(sentence &s); 89 | 90 | private: 91 | vector normalizaAspect(vector & puntos); 92 | void calculate_derivatives(vector & points, bool norm=true); 93 | void calculate_kurvature(); 94 | }; 95 | 96 | 97 | #endif 98 | 
-------------------------------------------------------------------------------- /gmm.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include "gmm.h" 23 | 24 | #define PI 3.14159265359 25 | 26 | using namespace std; 27 | 28 | GMM::GMM( char *model ) { 29 | loadModel( model ); 30 | } 31 | 32 | void GMM::loadModel( char *str ) { 33 | FILE *fd = fopen(str, "r"); 34 | if( !fd ) { 35 | fprintf(stderr, "Error loading GMM model file '%s'\n", str); 36 | exit(-1); 37 | } 38 | 39 | //Read parameters 40 | fscanf(fd, "%d %d %d", &C, &D, &G); 41 | 42 | //Read prior probabilities 43 | prior = new float[C]; 44 | for(int i=0; i. 
17 | */ 18 | #ifndef __GMM__ 19 | #define __GMM__ 20 | /* NOTE(review): the header names of the four includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | /* GMM: model object constructed from a model file path. Private state is three ints (C, D, G) and the pointer members invcov, mean, weight, prior and det; loadModel() and pdf() are private helpers. Presumably C is the number of classes, D the feature dimension and G the number of mixture components per class - TODO confirm against gmm.cc. The only public operation besides construction and destruction is posterior(x, pr), which takes two float pointers; presumably x is the input vector and pr receives the result - TODO confirm. NOTE(review): the class holds raw pointer members and declares a destructor but no copy constructor or copy assignment; confirm that GMM objects are never copied. */ 27 | class GMM{ 28 | int C, D, G; 29 | float **invcov, **mean, **weight, *prior, *det; 30 | 31 | void loadModel( char *str ); 32 | float pdf(int c, float *v); 33 | 34 | public: 35 | 36 | GMM(char *model); 37 | ~GMM(); 38 | 39 | void posterior(float *x, float *pr); 40 | }; 41 | 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /gparser.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #include "gparser.h" 19 | #include 20 | #include 21 | 22 | #define SIZE 1024 23 | /* Constructor: stores the grammar pointer in g, keeps a private heap copy of path in pre (an empty string when path has length 0) and then immediately parses fd by calling parse(). path is passed to strlen without a null check, so it must be a valid NUL-terminated string. */ 24 | gParser::gParser(Grammar *gram, FILE *fd, char *path) { 25 | g = gram; 26 | 27 | int n = strlen(path); 28 | 29 | if( n > 0 ) { 30 | pre = new char[n+1]; 31 | strcpy(pre, path); 32 | } 33 | else { 34 | pre = new char[1]; 35 | pre[0] = 0; 36 | } 37 | 38 | parse( fd ); 39 | } 40 | /* Destructor: releases the path copy held in pre; g is not deleted here. */ 41 | gParser::~gParser() { 42 | delete[] pre; 43 | } 44 | /* isFillChar: returns true for space, tab, newline and carriage return, false for any other character. */ 45 | bool gParser::isFillChar(char c) { 46 | switch(c) { 47 | case ' ': 48 | case '\t': 49 | case '\n': 50 | case '\r': 51 | return true; 52 | default: 53 | return false; 54 | } 55 | } 56 | /* split: tokenises str on fill characters. A token that starts with a double quote runs up to the next double quote and may therefore contain fill characters; the quotes themselves are dropped. Tokens are first packed, NUL-separated, into the fixed local buffer tokensaux of 2*SIZE bytes while n counts them. NOTE(review): writes into tokensaux are not bounds-checked. NOTE(review): when an opening double quote is never closed, the unconditional i++ after the inner loop steps past the terminating NUL of str and the outer loop keeps reading beyond the string. NOTE(review): the end of this function, the functions that followed it and the beginning of parse() are missing from this copy - the text breaks off inside the for loop that fills toks. */ 57 | int gParser::split(char *str,char ***res){ 58 | char tokensaux[2*SIZE]; 59 | int n=0, i=0, j=0; 60 | 61 | while( isFillChar(str[i]) ) i++; 62 | 63 | while( str[i] ) { 64 | if( str[i] == '\"' ) { 65 | i++; 66 | while( str[i] && str[i] != '\"' ) { 67 | tokensaux[j] = str[i]; 68 | i++; j++; 69 | } 70 | i++; 71 | } 72 | else { 73 | while( str[i] && !isFillChar(str[i]) ) { 74 | tokensaux[j] = str[i]; 75 | i++; j++; 76 | } 77 | } 78 | tokensaux[j++] = 0; 79 | n++; 80 | while( str[i] && isFillChar(str[i]) ) i++; 81 | } 82 | 83 | char **toks=new char*[n]; 84 | for(i=0, j=0; iaddNoTerminal(tok1); 119 | } 120 | /* Visible tail of parse(): the loops below read start symbols until a PTERM line, then terminal productions until a PBIN line, then binary productions until nextLine() fails. A binary production line must split into exactly 7 tokens: token 0 is converted with atof, token 1 selects the rule type (H, V, Ve, Sup, Sub, SSE, Ins or Mrt) and tokens 2 to 6 are forwarded to the matching addRule method of the grammar; a wrong token count or an unknown type prints an error and terminates with exit(-1). NOTE(review): the sscanf calls use unbounded %s conversions and ignore the return value; the sizes of tok1, tok2 and aux are not visible in this copy. */ 121 | //Read start symbol(s) of the grammar 122 | while( nextLine(fd, linea) && strcmp(linea, "PTERM\n") ) { 123 | sscanf(linea, "%s", tok1); 124 | g->addInitSym(tok1); 125 | } 126 | 127 | //Read terminal productions 128 | while( nextLine(fd, linea) && strcmp(linea, "PBIN\n") ) { 129 | float pr; 130 | 131 | sscanf(linea, "%f %s %s %s", &pr, tok1, tok2, aux); 132 | 133 | g->addTerminal(pr, tok1, tok2, aux); 134 | } 135 | 136 | //Read binary productions 137 | while( nextLine(fd, linea) ) { 138 | char **tokens; 139 | int ntoks = split(linea, &tokens); 140 | 141 | if( ntoks != 7 ) { 142 | fprintf(stderr, "Error: Grammar not valid (PBIN)\n"); 143 | exit(-1); 144 | } 145 | 146 | if( !strcmp(tokens[1], "H") ) 147 | g->addRuleH(atof(tokens[0]), tokens[2], 
tokens[3], tokens[4], tokens[5], tokens[6]); 148 | else if( !strcmp(tokens[1], "V") ) 149 | g->addRuleV(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 150 | else if( !strcmp(tokens[1], "Ve") ) 151 | g->addRuleVe(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 152 | else if( !strcmp(tokens[1], "Sup") ) 153 | g->addRuleSup(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 154 | else if( !strcmp(tokens[1], "Sub") ) 155 | g->addRuleSub(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 156 | else if( !strcmp(tokens[1], "SSE") ) 157 | g->addRuleSSE(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 158 | else if( !strcmp(tokens[1], "Ins") ) 159 | g->addRuleIns(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 160 | else if( !strcmp(tokens[1], "Mrt") ) 161 | g->addRuleMrt(atof(tokens[0]), tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); 162 | else { 163 | fprintf(stderr, "Error: Binary rule type '%s' nor valid\n", tokens[1]); 164 | exit(-1); 165 | } 166 | 167 | //Free memory 168 | /* NOTE(review): the rest of parse() and the header of gparser.h up to the end of its license comment are missing from this copy; the text breaks off inside this loop. */ for(int j=0; j. 17 | */ 18 | #ifndef _G_PARSER_ 19 | #define _G_PARSER_ 20 | 21 | struct Grammar; 22 | 23 | #include 24 | #include 25 | #include "grammar.h" 26 | /* gParser: reads a grammar description from a FILE and populates a Grammar through its add methods. g is the target grammar and pre is a string copy of the path argument given to the constructor. Construction runs parse() straight away; isFillChar, split, nextLine and solvePath are private helpers. */ 27 | class gParser{ 28 | Grammar *g; 29 | char *pre; 30 | 31 | bool isFillChar(char c); 32 | int split(char *str,char ***res); 33 | bool nextLine(FILE *fd, char *lin); 34 | void solvePath(char *in, char *out); 35 | public: 36 | gParser(Grammar *gram, FILE *fd, char *path); 37 | ~gParser(); 38 | 39 | void parse(FILE *fd); 40 | }; 41 | 42 | #endif 43 | -------------------------------------------------------------------------------- /grammar.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #ifndef _GRAMMAR_ 19 | #define _GRAMMAR_ 20 | 21 | class gParser; 22 | /* NOTE(review): the header names of the four system includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include "production.h" 28 | #include "gparser.h" 29 | #include "symrec.h" 30 | 31 | using namespace std; 32 | /* Grammar: container for a grammar built from a configuration file path and a symbol recogniser (see the constructor). It keeps a map of nonterminals (noTerminales), a list of start symbols (initsyms), a bool array esInit, the SymRec pointer, one production list per binary rule type (H, Sup, Sub, V, Ve, Ins, Mrt, SSE) and prodTerms for terminal productions. Every addRule method takes the same arguments: a probability pr, three symbol names S, A and B (presumably the rule S -> A B - confirm in grammar.cc), an output string and a merge string. NOTE(review): the template arguments of the map and list members are missing in this copy; restore them from upstream before compiling. */ 33 | struct Grammar{ 34 | map noTerminales; 35 | list initsyms; 36 | bool *esInit; 37 | SymRec *sym_rec; 38 | 39 | list prodsH, prodsSup, prodsSub; 40 | list prodsV, prodsVe, prodsIns, prodsMrt, prodsSSE; 41 | list prodTerms; 42 | 43 | Grammar(char *conf, SymRec *SR); 44 | ~Grammar(); 45 | 46 | const char *key2str(int k); 47 | void addInitSym(char *str); 48 | void addNoTerminal(char *str); 49 | void addTerminal(float pr, char *S, char *T, char *tex); 50 | 51 | void addRuleH(float pr, char *S, char *A, char *B, char *out, char *merge); 52 | void addRuleV(float pr, char *S, char *A, char *B, char *out, char *merge); 53 | void addRuleVe(float pr, char *S, char *A, char *B, char *out, char *merge); 54 | void addRuleSup(float pr, char *S, char *A, char *B, char *out, char *merge); 55 | void addRuleSub(float pr, char *S, char *A, char *B, char *out, char *merge); 56 | void addRuleSSE(float pr, char *S, char *A, char *B, char *out, char *merge); 57 | void addRuleIns(float pr, char *S, char *A, char *B, char *out, char *merge); 58 | void addRuleMrt(float pr, char *S, char *A, char *B, char *out, char *merge); 59 | 
}; 60 | 61 | #endif 62 | -------------------------------------------------------------------------------- /hypothesis.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #include "hypothesis.h" 19 | /* Constructor: stores c in clase, p in pr, cd in parent and nt in ntid. Both child pointers (hi, hd) and the three production pointers (prod, prod_sse, pt) start as NULL, lcen and rcen start at 0 and inkml_id starts as the literal text none. */ 20 | Hypothesis::Hypothesis(int c, double p, CellCYK *cd, int nt) { 21 | clase = c; 22 | pr = p; 23 | hi = hd = NULL; 24 | prod = NULL; 25 | prod_sse = NULL; 26 | pt = NULL; 27 | lcen = rcen = 0; 28 | parent = cd; 29 | ntid = nt; 30 | inkml_id = "none"; 31 | } 32 | /* copy: member-wise copy of every field of H into this object. hi, hd, prod, prod_sse, pt and parent are copied as pointers, so afterwards both objects refer to the same children, productions and parent cell. H is dereferenced without a null check. */ 33 | void Hypothesis::copy(Hypothesis *H) { 34 | clase = H->clase; 35 | pr = H->pr; 36 | hi = H->hi; 37 | hd = H->hd; 38 | prod = H->prod; 39 | prod_sse = H->prod_sse; 40 | pt = H->pt; 41 | lcen = H->lcen; 42 | rcen = H->rcen; 43 | parent = H->parent; 44 | ntid = H->ntid; 45 | inkml_id = H->inkml_id; 46 | } 47 | /* Destructor: empty - none of the pointer members are deleted here. */ 48 | Hypothesis::~Hypothesis() { 49 | } 50 | -------------------------------------------------------------------------------- /hypothesis.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #ifndef _HYPOTHESIS_ 19 | #define _HYPOTHESIS_ 20 | 21 | class ProductionB; 22 | class ProductionT; 23 | 24 | struct CellCYK; 25 | struct Grammar; 26 | /* NOTE(review): the header names of the three system includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 27 | #include 28 | #include 29 | #include 30 | #include "production.h" 31 | #include "cellcyk.h" 32 | #include "grammar.h" 33 | 34 | using namespace std; 35 | /* Hypothesis: one node of the derivation tree. A node carries a log-probability, optional left and right children, the binary or terminal production that created it, and a back pointer to the cell that contains it together with its nonterminal id in that cell. All pointer members are plain non-owning-looking raw pointers; ownership is not visible from this header. */ 36 | struct Hypothesis{ 37 | int clase; //If the hypothesis encodes a terminal symbol this is the class id (-1 otherwise) 38 | double pr; //log-probability 39 | 40 | //References to left-child (hi) and right-child (hd) to create the derivation tree 41 | Hypothesis *hi, *hd; 42 | 43 | //The production used to create this hypothesis (either Binary or terminal) 44 | ProductionB *prod; 45 | ProductionT *pt; 46 | 47 | //INKML_id for terminal symbols in order to create the InkML output 48 | string inkml_id; 49 | //Auxiliary variable to retrieve the used production in the special SSE treatment 50 | ProductionB *prod_sse; 51 | 52 | //Vertical center left and right 53 | int lcen, rcen; 54 | 55 | CellCYK *parent; //Parent cell 56 | int ntid; //Nonterminal ID in parent 57 | 58 | //Methods 59 | Hypothesis(int c, double p, CellCYK *cd, int nt); 60 | ~Hypothesis(); 61 | 62 | void copy(Hypothesis *SYM); 63 | }; 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /logspace.h: 
-------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #ifndef _LOGSPACE_ 19 | #define _LOGSPACE_ 20 | /* NOTE(review): the header names of the two system includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 21 | #include 22 | #include 23 | #include "cellcyk.h" 24 | /* LogSpace: search structure over an array of CellCYK pointers (data) with three int fields (N, RX, RY) and private quicksort, partition and bsearch helpers. It is built from a cell pointer c and three ints (nr, dx, dy). The public getH, getV, getU, getI, getM and getS queries each take a cell c and a caller-supplied list pointer set; presumably they fill set with the cells related to c, and the suffix letters presumably correspond to the production types declared in production.h - TODO confirm both in logspace.cc. NOTE(review): the element type of every list parameter is missing in this copy; restore it from upstream before compiling. */ 25 | class LogSpace{ 26 | int N; 27 | int RX, RY; 28 | CellCYK **data; 29 | 30 | void quicksort(CellCYK **vec, int ini, int fin); 31 | int partition(CellCYK **vec, int ini, int fin); 32 | void bsearch(int sx, int sy, int ss, int st, list *set); 33 | void bsearchStv(int sx, int sy, int ss, int st, list *set, bool U_V, CellCYK *cd); 34 | void bsearchHBP(int sx, int sy, int ss, int st, list *set, CellCYK *cd); 35 | 36 | public: 37 | LogSpace(CellCYK *c, int nr, int dx, int dy); 38 | ~LogSpace(); 39 | 40 | void getH(CellCYK *c, list *set); 41 | void getV(CellCYK *c, list *set); 42 | void getU(CellCYK *c, list *set); 43 | void getI(CellCYK *c, list *set); 44 | void getM(CellCYK *c, list *set); 45 | void getS(CellCYK *c, list *set); 46 | }; 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /meparser.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #ifndef _MEPARSER_ 19 | #define _MEPARSER_ 20 | /* NOTE(review): the header names of the three system includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 21 | #include 22 | #include 23 | #include 24 | #include "symrec.h" 25 | #include "sample.h" 26 | #include "sparel.h" 27 | #include "duration.h" 28 | #include "segmentation.h" 29 | #include "logspace.h" 30 | #include "grammar.h" 31 | #include "hypothesis.h" 32 | #include "tablecyk.h" 33 | #include "cellcyk.h" 34 | /* meParser: math expression parser constructed from a configuration file path. It holds pointers to the grammar (G), the symbol recogniser (sym_rec), a GMM (gmm_spr), a duration model and a segmentation model, plus one int and nine float settings. The public interface is parse_me() on a Sample and three output helpers (print_symrec, print_latex, save_dot) that take a Hypothesis. Presumably the scalar members hold the values of the configuration keys MaxStrokes, ClusterF, SegmentsTH, the SF scale factors and InsPenalty - TODO confirm the mapping in meparser.cc. Ownership of the pointer members is not visible from this header. */ 35 | class meParser{ 36 | 37 | Grammar *G; 38 | 39 | int max_strokes; 40 | float clusterF, segmentsTH; 41 | float ptfactor, pbfactor, rfactor; 42 | float qfactor, dfactor, gfactor, InsPen; 43 | 44 | SymRec *sym_rec; 45 | GMM *gmm_spr; 46 | DurationModel *duration; 47 | SegmentationModelGMM *segmentation; 48 | 49 | //Private methods 50 | void loadSymRec(char *conf); 51 | int tree2dot(FILE *fd, Hypothesis *H, int id); 52 | 53 | void initCYKterms(Sample *m, TableCYK *tcyk, int N, int K); 54 | 55 | void combineStrokes(Sample *M, TableCYK *tcyk, LogSpace **LSP, int N); 56 | CellCYK* fusion(Sample *M, ProductionB *pd, Hypothesis *A, Hypothesis *B, int N, double prob); 57 | 58 | public: 59 | meParser(char *conf); 60 | ~meParser(); 61 | 62 | //Parse math expression 63 | void parse_me(Sample *M); 64 | 65 | //Output formatting methods 66 | void print_symrec(Hypothesis *H); 67 | void print_latex(Hypothesis *H); 68 | void save_dot( Hypothesis *H, char *outfile ); 69 | }; 70 | 71 | #endif 72 | 
-------------------------------------------------------------------------------- /online.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see <http://www.gnu.org/licenses/>. 17 | 18 | This file is a modification of the online features original software 19 | covered by the following copyright and permission notice: 20 | 21 | */ 22 | /* 23 | Copyright (C) 2006,2007 Moisés Pastor 24 | 25 | This program is free software: you can redistribute it and/or modify 26 | it under the terms of the GNU General Public License as published by 27 | the Free Software Foundation, either version 3 of the License, or 28 | (at your option) any later version. 29 | 30 | This program is distributed in the hope that it will be useful, 31 | but WITHOUT ANY WARRANTY; without even the implied warranty of 32 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 33 | GNU General Public License for more details. 34 | 35 | You should have received a copy of the GNU General Public License 36 | along with this program. If not, see <http://www.gnu.org/licenses/>. 
37 | */ 38 | #include "online.h" 39 | 40 | 41 | // Aux functions 42 | /* MAX: returns the larger of two ints (a when they are equal). */ 43 | inline int MAX(int a, int b) { 44 | if (a>=b) return a; 45 | else return b; 46 | } 47 | /* MIN: returns the smaller of two ints (a when they are equal). */ 48 | inline int MIN(int a, int b) { 49 | if (a<=b) return a; 50 | else return b; 51 | } 52 | 53 | // 54 | // "stroke" methods 55 | // 56 | /* Constructor: stores the three arguments in n_points, pen_down and is_hat; the points container is left default-constructed. */ 57 | stroke::stroke(int n_p, bool pen_d, bool is_ht): n_points(n_p), pen_down(pen_d), is_hat(is_ht) {} 58 | 59 | /* NOTE(review): from here on large parts of online.cc are missing in this copy - text between a less-than sign and the next greater-than sign appears to have been stripped, which removed loop conditions and several whole functions. What remains shows that F_XMIN starts from INT_MAX and F_XMAX from INT_MIN while scanning the x coordinates of points. Restore the file from upstream before relying on any of it. */ 60 | int stroke::F_XMIN() { 61 | int xmin=INT_MAX; 62 | for (int p=0; ppoints[p].x) xmin=points[p].x; 64 | return xmin; 65 | } 66 | 67 | int stroke::F_XMAX() { 68 | int xmax=INT_MIN; 69 | for (int p=0; p puntos=strokes[s].points; 95 | int np=strokes[s].n_points; 96 | for (int p=0; p puntos=strokes[i].points; 115 | int np=strokes[i].n_points; 116 | for (int p=0; p=np) { /* Visible tail of the smoothing routine: an index past the end of the stroke contributes the last point (index np-1); otherwise point c contributes. The sums are divided by cont_size*2+1 and truncated to int to form the smoothed point, which is appended to strokeNorm. pen_down is copied from the source stroke, n_points is refreshed from the size of the new points container, and the finished stroke is appended to the sentence behind the pointer sentNorm, which is returned. Presumably the caller owns the returned sentence - confirm where sentNorm is allocated. */ 123 | sum_x+=puntos[np-1].x; 124 | sum_y+=puntos[np-1].y; 125 | } else { 126 | sum_x+=puntos[c].x; 127 | sum_y+=puntos[c].y; 128 | } 129 | Point point(int(sum_x/(cont_size*2+1)),int(sum_y/(cont_size*2+1))); 130 | strokeNorm.points.push_back(point); 131 | } 132 | strokeNorm.pen_down=strokes[i].pen_down; 133 | strokeNorm.n_points=strokeNorm.points.size(); 134 | (*sentNorm).strokes.push_back(strokeNorm); 135 | } 136 | return sentNorm; 137 | } 138 | -------------------------------------------------------------------------------- /online.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | This file is a modification of the online features original software 19 | covered by the following copyright and permission notice: 20 | 21 | */ 22 | /* 23 | Copyright (C) 2006,2007 Moisés Pastor 24 | 25 | This program is free software: you can redistribute it and/or modify 26 | it under the terms of the GNU General Public License as published by 27 | the Free Software Foundation, either version 3 of the License, or 28 | (at your option) any later version. 29 | 30 | This program is distributed in the hope that it will be useful, 31 | but WITHOUT ANY WARRANTY; without even the implied warranty of 32 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 33 | GNU General Public License for more details. 34 | 35 | You should have received a copy of the GNU General Public License 36 | along with this program. If not, see . 37 | */ 38 | #ifndef ONLINE_H 39 | #define ONLINE_H 40 | /* NOTE(review): the header names of the seven system includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 41 | #include 42 | #include 43 | #include 44 | #include 45 | #include 46 | #include 47 | #include 48 | 49 | using namespace std; 50 | 51 | //Real point 52 | /* PointR: point with float coordinates plus a flag (point_pu). The flag starts false, is set by setpu() and cannot be cleared through this interface. Equality and inequality compare x and y only; the flag is ignored. Assignment copies x, y and the flag. NOTE(review): getpu() is not const, so it cannot be called on a const PointR. NOTE(review): the constructor initialiser list names point_pu last although it is declared first; members are still initialised in declaration order, but compilers warn about the mismatch. */ class PointR { 53 | // True if this is the last point of a stroke 54 | bool point_pu; 55 | 56 | public: 57 | float x, y; 58 | 59 | PointR(float _x, float _y): x(_x), y(_y), point_pu(false) {} 60 | PointR & operator= (const PointR & p) { 61 | x=p.x; y=p.y; 62 | point_pu=p.point_pu; 63 | return *this; 64 | } 65 | bool operator ==(const PointR & p) const { 66 | return p.x==x && p.y==y; 67 | } 68 | bool operator !=(const PointR & p) const { 69 | return p.x!=x || p.y!=y; 70 | } 71 | void setpu() { 72 | point_pu=1; 73 | } 74 | bool getpu() { 75 | return point_pu; 76 | } 77 | }; 78 | 79 | 80 | 81 | //Integer point 82 | /* Point: same interface as PointR but with int coordinates. The flag starts false and can only be set; comparisons look at x and y only. The same two review notes apply: getpu() is not const and the initialiser list order differs from the declaration order. */ class Point { 83 | // True if this is the last point of a stroke 84 | bool point_pu; 85 | public: 86 | int x, y; 87 | 88 | Point(int _x, int _y): x(_x), y(_y), point_pu(false) {} 89 | Point & operator= 
(const Point & p) { 90 | x=p.x; y=p.y; 91 | point_pu=p.point_pu; 92 | return *this; 93 | } 94 | bool operator == (const Point & p) const { 95 | return p.x==x && p.y==y; 96 | } 97 | bool operator !=(const Point & p) const { 98 | return p.x!=x || p.y!=y; 99 | } 100 | void setpu() { 101 | point_pu=1; 102 | } 103 | bool getpu() { 104 | return point_pu; 105 | } 106 | }; 107 | 108 | /* stroke: a container of points together with a separately stored count (n_points) and two flags (pen_down, is_hat). All three constructor arguments default to zero or false. F_XMIN, F_XMAX and F_XMED return ints and are only declared here. NOTE(review): n_points is its own field and is not derived from the points container in this header. NOTE(review): the element type of the points vector is missing in this copy; restore it from upstream. */ 109 | class stroke { 110 | public: 111 | int n_points; 112 | bool pen_down; 113 | bool is_hat; 114 | vector points; 115 | 116 | stroke(int n_p=0, bool pen_d=0, bool is_ht=0); 117 | 118 | int F_XMIN(); 119 | int F_XMAX(); 120 | int F_XMED(); 121 | }; 122 | 123 | /* sentence: a container of strokes together with a separately stored count (n_strokes). anula_rep_points() and suaviza_traza() each return a pointer to a sentence; presumably a new object that the caller must release - confirm in online.cc. suaviza_traza takes a window parameter cont_size that defaults to 2. NOTE(review): the element type of the strokes vector is missing in this copy; restore it from upstream. */ 124 | class sentence { 125 | public: 126 | int n_strokes; 127 | vector strokes; 128 | 129 | sentence(int n_s); 130 | 131 | sentence * anula_rep_points(); 132 | sentence * suaviza_traza(int cont_size=2); 133 | }; 134 | 135 | 136 | #endif 137 | 138 | -------------------------------------------------------------------------------- /production.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #ifndef _PRODUCTION_ 19 | #define _PRODUCTION_ 20 | 21 | class CellCYK; 22 | /* NOTE(review): the header names of the three system includes below are missing in this copy (angle-bracket text appears to have been stripped); restore them from upstream. */ 23 | #include 24 | #include 25 | #include 26 | #include "hypothesis.h" 27 | #include "cellcyk.h" 28 | #include "symrec.h" 29 | #include "grammar.h" 30 | 31 | //Binary productions of the grammar (2D-PCFG) 32 | /* ProductionB: abstract base of every binary production. S, A and B are int symbol ids and prior is a float; presumably they describe a rule S -> A B with that probability - confirm in production.cc. outStr and merge_cen are protected; presumably they are set from the out argument of the second constructor and from setMerges(). tipo(), print() and print_mathml() are pure virtual and are declared again by every derived class below. The destructor is virtual so that deleting a derived production through a ProductionB pointer is well defined. NOTE(review): mergeRegions() is NOT virtual although every derived class redeclares it, so a call through a ProductionB pointer or reference runs the base version. */ class ProductionB{ 33 | protected: 34 | char *outStr; 35 | char merge_cen; 36 | 37 | public: 38 | int S; 39 | int A, B; 40 | float prior; 41 | 42 | ProductionB(int s, int a, int b); 43 | ProductionB(int s, int a, int b, float pr, char *out); 44 | virtual ~ProductionB(); 45 | 46 | float solape(Hypothesis *a, Hypothesis *b); 47 | void printOut(Grammar *G, Hypothesis *H); 48 | void setMerges(char c); 49 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 50 | bool check_out(); 51 | char *get_outstr(); 52 | 53 | //Pure virtual functions 54 | virtual char tipo() = 0; 55 | virtual void print() = 0; 56 | virtual void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid) = 0; 57 | }; 58 | 59 | 60 | //Production S -> A : B 61 | /* The nine derived classes that follow (H, V, U, Ve, SSE, Sup, Sub, Ins, Mrt) all have the same shape: two constructors that mirror the base constructors, the three overrides of the pure virtual functions, and their own mergeRegions(). They add no data members. */ class ProductionH : public ProductionB{ 62 | 63 | public: 64 | ProductionH(int s, int a, int b); 65 | ProductionH(int s, int a, int b, float pr, char *out); 66 | 67 | void print(); 68 | char tipo(); 69 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 70 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 71 | }; 72 | 73 | 74 | //Production: S -> A / B 75 | class ProductionV : public ProductionB{ 76 | 77 | public: 78 | ProductionV(int s, int a, int b); 79 | ProductionV(int s, int a, int b, float pr, char *out); 80 | 81 | void print(); 82 | char tipo(); 83 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 84 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 85 | }; 86 | 87 | 88 | //Production: S -> A /u B 89 | class ProductionU : public ProductionB{ 90 | 91 | public: 92 | ProductionU(int s, int a, int b); 93 | ProductionU(int s, int a, int b, float pr, char *out); 94 | 95 | 
void print(); 96 | char tipo(); 97 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 98 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 99 | }; 100 | 101 | 102 | //Production: S -> A /e B 103 | class ProductionVe : public ProductionB{ 104 | 105 | public: 106 | ProductionVe(int s, int a, int b); 107 | ProductionVe(int s, int a, int b, float pr, char *out); 108 | 109 | void print(); 110 | char tipo(); 111 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 112 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 113 | }; 114 | 115 | 116 | 117 | //Production: S -> A sse B 118 | class ProductionSSE : public ProductionB{ 119 | 120 | public: 121 | ProductionSSE(int s, int a, int b); 122 | ProductionSSE(int s, int a, int b, float pr, char *out); 123 | 124 | void print(); 125 | char tipo(); 126 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 127 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 128 | }; 129 | 130 | 131 | 132 | //Production: S -> A ^ B 133 | class ProductionSup : public ProductionB{ 134 | 135 | public: 136 | ProductionSup(int s, int a, int b); 137 | ProductionSup(int s, int a, int b, float pr, char *out); 138 | 139 | void print(); 140 | char tipo(); 141 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 142 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 143 | }; 144 | 145 | 146 | //Production: S -> A _ B 147 | class ProductionSub : public ProductionB{ 148 | 149 | public: 150 | ProductionSub(int s, int a, int b); 151 | ProductionSub(int s, int a, int b, float pr, char *out); 152 | 153 | void print(); 154 | char tipo(); 155 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 156 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 157 | }; 158 | 159 | 160 | //Production: S -> A ins B 161 | class ProductionIns : public ProductionB{ 162 | 163 | public: 164 | ProductionIns(int s, int a, 
int b); 165 | ProductionIns(int s, int a, int b, float pr, char *out); 166 | 167 | void print(); 168 | char tipo(); 169 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 170 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 171 | }; 172 | 173 | 174 | //Production: S -> A mroot B 175 | class ProductionMrt : public ProductionB{ 176 | 177 | public: 178 | ProductionMrt(int s, int a, int b); 179 | ProductionMrt(int s, int a, int b, float pr, char *out); 180 | 181 | void print(); 182 | char tipo(); 183 | void mergeRegions(Hypothesis *a, Hypothesis *b, Hypothesis *s); 184 | void print_mathml(Grammar *G, Hypothesis *H, FILE *fout, int *nid); 185 | }; 186 | 187 | 188 | //Production S -> term ( N clases ) 189 | /* ProductionT: terminal production for nonterminal S over N classes (N is given to the constructor as nclases). Per class k it keeps a flag (clases), a TeX string (texStr), a one-character type (mltype, presumably the MathML element kind - confirm in production.cc) and a float (probs); setClase() takes exactly those four values for one k and the getters return them. NOTE(review): the class declares a destructor and holds raw pointer members but no copy constructor or copy assignment; confirm it is never copied. */ class ProductionT{ 190 | int S; 191 | bool *clases; 192 | char **texStr; 193 | char *mltype; 194 | float *probs; 195 | int N; 196 | 197 | public: 198 | ProductionT(int s, int nclases); 199 | ~ProductionT(); 200 | 201 | void setClase(int k, float pr, char *tex, char mlt); 202 | bool getClase(int k); 203 | float getPrior(int k); 204 | char *getTeX(int k); 205 | char getMLtype(int k); 206 | int getNoTerm(); 207 | void print(); 208 | }; 209 | 210 | #endif 211 | -------------------------------------------------------------------------------- /rnnlib4seshat/ActivationFunctions.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. If not, see .*/ 17 | 18 | #ifndef _INCLUDED_ActivationFunctions_h 19 | #define _INCLUDED_ActivationFunctions_h 20 | 21 | #include 22 | #include 23 | #include 24 | #include "Log.hpp" 25 | 26 | // Implement logistic function 27 | // f(x) = 1 / (1 + exp(-x)) 28 | struct Logistic { 29 | static real_t fn(real_t x) { 30 | if (x < Log::expLimit) { 31 | if (x > -Log::expLimit) { 32 | return 1.0 / (1.0 + exp(-x)); 33 | } 34 | return 0; 35 | } 36 | return 1; 37 | } 38 | static real_t deriv(real_t y) { 39 | return y*(1.0-y); 40 | } 41 | static real_t second_deriv(real_t y) { 42 | return deriv(y) * (1 - (2 * y)); 43 | } 44 | }; 45 | // Implements a soft version of the sign function with 46 | // first and second derivatives. 47 | // f(x) = x / (1 + |x|) 48 | struct Softsign { 49 | static real_t fn(real_t x) { 50 | if (x < realMax) { 51 | if (x > -realMax) { 52 | return (x/(1 + fabs(x))); 53 | } 54 | return -1; 55 | } 56 | return 1; 57 | } 58 | static real_t deriv(real_t y) { 59 | return squared(1 - fabs(y)); 60 | } 61 | static real_t second_deriv(real_t y) { 62 | return -2 * sign(y) * pow((1 - fabs(y)), 3); 63 | } 64 | }; 65 | // Identity activation function 66 | // f(n) = x 67 | // f'(n) = 1 68 | // f''(n) = 0 69 | struct Identity { 70 | static real_t fn(real_t x) { 71 | return x; 72 | } 73 | static real_t deriv(real_t y) { 74 | return 1; 75 | } 76 | static real_t second_deriv(real_t y) { 77 | return 0; 78 | } 79 | }; 80 | // Logistic unit in the range [-2, 2] 81 | struct Maxmin2 { 82 | static real_t fn(real_t x) { 83 | return (4 * Logistic::fn(x)) - 2; 84 | } 85 | static real_t deriv(real_t y) { 86 | if (y == -2 || y == 2) { 87 | return 0; 88 | } 89 | return (4 - (y * y)) / 4.0; 90 | } 91 | static real_t second_deriv(real_t y) { 92 | return -deriv(y) * 0.5 * y; 93 | } 94 | }; 95 | // Logistic unit in the range [-1, 1] 96 | struct Maxmin1 { 97 | static real_t fn(real_t x) { 98 | 
return (2 * Logistic::fn(x)) - 1; 99 | } 100 | static real_t deriv(real_t y) { 101 | if (y == -1 || y == 1) { 102 | return 0; 103 | } 104 | return (1.0 - (y * y)) / 2.0; 105 | } 106 | static real_t second_deriv(real_t y) { 107 | return -deriv(y) * y; 108 | } 109 | }; 110 | // Logistic unit in the range [0, 2] 111 | struct Max2min0 { 112 | static real_t fn(real_t x) { 113 | return (2 * Logistic::fn(x)); 114 | } 115 | static real_t deriv(real_t y) { 116 | if (y == -1 || y == 1) { 117 | return 0; 118 | } 119 | return y * (1 - (y / 2.0)); 120 | } 121 | static real_t second_deriv(real_t y) { 122 | return deriv(y) * (1 - y); 123 | } 124 | }; 125 | struct Tanh { 126 | static real_t fn(real_t x) { 127 | return Maxmin1::fn(2*x); 128 | } 129 | static real_t deriv(real_t y) { 130 | return 1.0 - (y * y); 131 | } 132 | static real_t second_deriv(real_t y) { 133 | return -2 * deriv(y) * y; 134 | } 135 | }; 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /rnnlib4seshat/BiasLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_BiasLayer_h 41 | #define _INCLUDED_BiasLayer_h 42 | 43 | #include 44 | #include "Layer.hpp" 45 | 46 | struct BiasLayer: public Layer { 47 | // data 48 | View acts; 49 | View errors; 50 | 51 | // functions 52 | BiasLayer(WeightContainer *wc, DataExportHandler *deh) : 53 | Layer("bias", 0, 0, 1, wc, deh), acts(this->outputActivations[0]), 54 | errors(this->outputErrors[0]) { 55 | acts.front() = 1; 56 | } 57 | 58 | ~BiasLayer() {} 59 | 60 | const View out_acts(const vector& coords) { 61 | return acts; 62 | } 63 | 64 | const View out_errs(const vector& coords) { 65 | return errors; 66 | } 67 | }; 68 | 69 | #endif 70 | -------------------------------------------------------------------------------- /rnnlib4seshat/BlockLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see <http://www.gnu.org/licenses/>. 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. 
If not, see .*/ 39 | 40 | #ifndef _INCLUDED_BlockLayer_h 41 | #define _INCLUDED_BlockLayer_h 42 | 43 | #include 44 | #include 45 | #include 46 | #include "Layer.hpp" 47 | 48 | struct BlockLayer: public Layer { 49 | // data 50 | vector blockShape; 51 | vector blockOffset; 52 | vector inCoords; 53 | size_t sourceSize; 54 | CVI blockIterator; 55 | vector outSeqShape; 56 | 57 | // functions 58 | BlockLayer(Layer* src, const vector& blockshape, WeightContainer *weight, DataExportHandler *deh): 59 | Layer( 60 | src->name + "_block", src->num_seq_dims(), 0, 61 | product(blockshape) * src->output_size(), weight, deh, src), 62 | blockShape(blockshape), 63 | blockOffset(this->num_seq_dims()), 64 | inCoords(this->num_seq_dims()), 65 | sourceSize(src->outputActivations.depth), 66 | blockIterator(blockShape), 67 | outSeqShape(this->num_seq_dims()) { 68 | assert(blockShape.size() == this->num_seq_dims()); 69 | assert(!in(blockShape, 0)); 70 | wc->link_layers( 71 | this->source->name, this->name, 72 | this->source->name + "_to_" + this->name); 73 | display(this->outputActivations, "activations"); 74 | display(this->outputErrors, "errors"); 75 | } 76 | void print(ostream& out = cout) const { 77 | Layer::print(out); 78 | out << " block " << blockShape; 79 | } 80 | void start_sequence() { 81 | for (int i = 0; i < outSeqShape.size(); ++i) { 82 | outSeqShape.at(i) = ceil( 83 | (real_t)this->source->output_seq_shape().at(i) / 84 | (real_t)blockShape.at(i)); 85 | } 86 | outputActivations.reshape(outSeqShape, 0); 87 | outputErrors.reshape(outputActivations, 0); 88 | } 89 | void feed_forward(const vector& outCoords) { 90 | real_t* outIt = this->outputActivations[outCoords].begin(); 91 | range_multiply(blockOffset, outCoords, blockShape); 92 | for (blockIterator.begin(); !blockIterator.end; ++blockIterator) { 93 | range_plus(inCoords, *blockIterator, blockOffset); 94 | View inActs = this->source->outputActivations.at(inCoords); 95 | if (inActs.begin()) { 96 | copy(inActs.begin(), 
inActs.end(), outIt); 97 | } else { 98 | fill(outIt, outIt + sourceSize, 0); 99 | } 100 | outIt += sourceSize; 101 | } 102 | } 103 | void feed_back(const vector& outCoords) { 104 | const real_t* outIt = this->outputErrors[outCoords].begin(); 105 | range_multiply(blockOffset, outCoords, blockShape); 106 | for (blockIterator.begin(); !blockIterator.end; ++blockIterator) { 107 | range_plus(inCoords, *blockIterator, blockOffset); 108 | real_t* inErr = this->source->outputErrors.at(inCoords).begin(); 109 | if (inErr) { 110 | transform(outIt, outIt + sourceSize, inErr, inErr, plus()); 111 | } 112 | outIt += sourceSize; 113 | } 114 | } 115 | }; 116 | 117 | #endif 118 | -------------------------------------------------------------------------------- /rnnlib4seshat/ClassificationLayer.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 
26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #include "ClassificationLayer.hpp" 41 | 42 | ClassificationLayer* make_classification_layer(ostream& out, const string& name, size_t numSeqDims, const vector& labels, WeightContainer *weight, DataExportHandler *deh) 43 | { 44 | assert(labels.size() >= 2); 45 | if (labels.size() == 2) 46 | { 47 | return new BinaryClassificationLayer(out, name, numSeqDims, labels, weight, deh); 48 | } 49 | return new MulticlassClassificationLayer(out, name, numSeqDims, labels, weight, deh); 50 | } 51 | -------------------------------------------------------------------------------- /rnnlib4seshat/CollapseLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_CollapseLayer_h 41 | #define _INCLUDED_CollapseLayer_h 42 | 43 | #include "Layer.hpp" 44 | 45 | struct CollapseLayer: public Layer 46 | { 47 | //data 48 | vector activeDims; 49 | vector outSeqShape; 50 | 51 | //functions 52 | CollapseLayer(Layer* src, Layer* des, WeightContainer *weight, DataExportHandler *deh, const vector& activDims = empty_list_of()): 53 | Layer(des->name + "_collapse", des->directions, des->input_size(), des->input_size(), weight, deh, src), 54 | activeDims(activDims) 55 | { 56 | activeDims.resize(src->num_seq_dims(), false); 57 | assert(count(activDims, true) == des->num_seq_dims()); 58 | //DISPLAY(inputActivations); 59 | //DISPLAY(inputErrors); 60 | //DISPLAY(outputActivations); 61 | //DISPLAY(outputErrors); 62 | } 63 | virtual void start_sequence() 64 | { 65 | outSeqShape.clear(); 66 | for (int i = 0; i < activeDims.size(); ++i) 67 | { 68 | if (activeDims[i]) 69 | { 70 | outSeqShape += source->output_seq_shape()[i]; 71 | } 72 | } 73 | assert(outSeqShape.size() == 
num_seq_dims()); 74 | inputActivations.reshape(source->output_seq_shape(), 0); 75 | outputActivations.reshape(outSeqShape, 0); 76 | reshape_errors(); 77 | } 78 | vector get_out_coords(const vector& inCoords) 79 | { 80 | vector outCoords; 81 | assert(inCoords.size() == activeDims.size()); 82 | for (int i = 0; i < inCoords.size(); ++i) 83 | { 84 | if (activeDims[i]) 85 | { 86 | outCoords += inCoords[i]; 87 | } 88 | } 89 | assert(outCoords.size() == num_seq_dims()); 90 | return outCoords; 91 | } 92 | void feed_forward(const vector& coords) 93 | { 94 | range_plus_equals(this->outputActivations[get_out_coords(coords)], inputActivations[coords]); 95 | } 96 | void feed_back(const vector& coords) 97 | { 98 | copy(outputErrors[get_out_coords(coords)], inputErrors[coords]); 99 | } 100 | }; 101 | 102 | #endif 103 | 104 | -------------------------------------------------------------------------------- /rnnlib4seshat/ConfigFile.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
If not, see .*/ 17 | 18 | #ifndef _INCLUDED_ConfigFile_h 19 | #define _INCLUDED_ConfigFile_h 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include "Helpers.hpp" 28 | #include "String.hpp" 29 | 30 | using namespace std; 31 | 32 | struct ConfigFile { 33 | //data 34 | set used; 35 | map params; 36 | string filename; 37 | 38 | //functions 39 | ConfigFile(const string& fname, char readLineChar = '_'): filename(fname) { 40 | ifstream instream(filename.c_str()); 41 | check(instream.is_open(), "could not open config file \"" + filename + "\""); 42 | string name; 43 | string val; 44 | while(instream >> name && instream >> val) { 45 | string line; 46 | getline(instream, line); 47 | if(name[0] != '#') { 48 | if (in(name, readLineChar) && line.size() > 1) { 49 | val += line; 50 | } 51 | params[name] = val; 52 | } 53 | } 54 | } 55 | bool contains(const string& name) const { 56 | return in(params, name); 57 | } 58 | bool remove(const string& name) { 59 | if (contains(name)) { 60 | params.erase(name); 61 | used.erase(name); 62 | return true; 63 | } 64 | return false; 65 | } 66 | template const T& set_val( 67 | const string& name, const T& val, bool valUsed = true) { 68 | stringstream ss; 69 | ss << boolalpha << val; 70 | params[name] = ss.str(); 71 | if (valUsed) { 72 | used.insert(name); 73 | } 74 | return val; 75 | } 76 | template T get(const string& name, const T& defaultVal) { 77 | MSSI it = params.find(name); 78 | if (it == params.end()) { 79 | set_val(name, defaultVal); 80 | return defaultVal; 81 | } 82 | return get(name); 83 | } 84 | template T get(const string& name) { 85 | MSSCI it = params.find(name); 86 | check(it != params.end(), "param '" + name + "' not found in config file '" + filename); 87 | used.insert(name); 88 | return read(it->second); 89 | } 90 | template Vector get_list( 91 | const string& name, const char delim = ',') { 92 | Vector vect; 93 | MSSCI it = params.find(name); 94 | if (it != params.end()) { 95 | 
vect = split_with_repeat(it->second, delim); 96 | used.insert(name); 97 | } 98 | return vect; 99 | } 100 | template Vector get_list( 101 | const string& name, const T& defaultVal, size_t length, 102 | const char delim = ',') { 103 | Vector vect = get_list(name, delim); 104 | vect.resize(length, vect.size() == 1 ? vect.front() : defaultVal); 105 | used.insert(name); 106 | return vect; 107 | } 108 | template Vector > get_array( 109 | const string& name, const char delim1 = ';', const char delim2 = ',') { 110 | Vector > array; 111 | MSSCI it = params.find(name); 112 | if (it != params.end()) { 113 | LOOP(const string& row, split(it->second, delim1)) { 114 | array += split_with_repeat(row, delim2); 115 | } 116 | used.insert(name); 117 | } 118 | return array; 119 | } 120 | template Vector > get_array( 121 | const string& name, const string& defaultStr, size_t length, 122 | const char delim1 = ';', const char delim2 = ',') { 123 | Vector > array = get_array(name, delim1, delim2); 124 | array.resize( 125 | length, array.size() == 1 ? 
126 | array.front() : split_with_repeat(defaultStr, delim2)); 127 | used.insert(name); 128 | return array; 129 | } 130 | void warn_unused(ostream& out, bool removeUnused = true) { 131 | Vector unused; 132 | LOOP(const PSS& p, params) { 133 | if (!in(used, p.first)) { 134 | unused += p.first; 135 | } 136 | } 137 | if (unused.size()) { 138 | LOOP(string& s, unused) { 139 | out << "WARNING: " << s << " in config but never used" << endl; 140 | if (removeUnused) { 141 | params.erase(s); 142 | } 143 | } 144 | out << endl; 145 | } 146 | } 147 | }; 148 | 149 | static ostream& operator << (ostream& out, const ConfigFile& conf) { 150 | out << conf.params; 151 | return out; 152 | } 153 | 154 | #endif 155 | -------------------------------------------------------------------------------- /rnnlib4seshat/Connection.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
If not, see .*/ 17 | 18 | #ifndef _INCLUDED_Connection_h 19 | #define _INCLUDED_Connection_h 20 | 21 | #include 22 | #include "Named.hpp" 23 | 24 | struct Connection: public Named 25 | { 26 | //data 27 | Layer* from; 28 | Layer* to; 29 | 30 | //functions 31 | Connection(const string& name, Layer* f, Layer* t): 32 | Named(name), 33 | from(f), 34 | to(t) 35 | { 36 | assert(from); 37 | assert(to); 38 | assert(from->output_size()); 39 | assert(to->input_size()); 40 | } 41 | virtual ~Connection(){} 42 | virtual size_t num_weights() const {return 0;} 43 | virtual void feed_forward(const vector& coords){} 44 | virtual void feed_back(const vector& coords){} 45 | virtual void update_derivs(const vector& coords){} 46 | virtual void print(ostream& out) const{} 47 | virtual const View weights(){return View();} 48 | }; 49 | static ostream& operator <<(ostream& out, const Connection& c) 50 | { 51 | c.print(out); 52 | return out; 53 | } 54 | 55 | #endif 56 | -------------------------------------------------------------------------------- /rnnlib4seshat/Container.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_Container_h 41 | #define _INCLUDED_Container_h 42 | 43 | template struct Vector; 44 | 45 | template struct View: public sub_range > { 46 | View(pair& p): sub_range >(p) {} 47 | View(T* first = 0, T* second = 0): 48 | sub_range >(make_pair(first, second)) {} 49 | View slice(int first = 0, int last = numeric_limits::max()) { 50 | first = bound(first, 0, (int)this->size()); 51 | if (last < 0) { 52 | last += (int)this->size(); 53 | } 54 | last = bound(last, first, (int)this->size()); 55 | return View(&((*this)[first]), &((*this)[last])); 56 | } 57 | View slice(pair& r) { 58 | return slice(r.first, r.second); 59 | } 60 | const View slice(int first = 0, int last = numeric_limits::max()) const { 61 | return slice(first, last); 62 | } 63 | const View slice(pair& r) const { 64 | return slice(r.first, r.second); 65 | } 66 | T& at(size_t i) { 67 | check(i < this->size(), "at(" + str(i) + ") called for view of size " + 68 | str(this->size())); 69 | return (*this)[i]; 70 | } 71 | const T& at(size_t i) const { 72 | check(i < this->size(), "at(" + str(i) + ") called for view of size " + 73 | str(this->size())); 
74 | return (*this)[i]; 75 | } 76 | template const View& operator =(const R& r) const { 77 | check(boost::size(r) == this->size(), "can't assign range " + str(r) + 78 | " to view " + str(*this)); 79 | copy(r, *this); 80 | return *this; 81 | } 82 | template Vector to() const { 83 | Vector v; 84 | LOOP(const T& t, *this) { 85 | v += lexical_cast(t); 86 | } 87 | return v; 88 | } 89 | }; 90 | 91 | template struct Vector: public vector { 92 | Vector() { } 93 | Vector(const vector& v): vector(v) {} 94 | Vector(const View& v) { 95 | *this = v; 96 | } 97 | Vector(size_t n): vector(n) {} 98 | Vector(size_t n, const T& t): vector(n, t) {} 99 | Vector& grow(size_t length) { 100 | this->resize(this->size() + length); 101 | return *this; 102 | } 103 | Vector& shrink(size_t length) { 104 | this->resize(max((size_t)0, this->size() - length)); 105 | return *this; 106 | } 107 | void push_front(const T& t) { 108 | this->insert(this->begin(), t); 109 | } 110 | T& pop_front() { 111 | T& front = front(); 112 | erase(this->begin()); 113 | return front; 114 | } 115 | View slice(int first = 0, int last = numeric_limits::max()) { 116 | first = bound(first, 0, (int)this->size()); 117 | if (last < 0) { 118 | last += (int)this->size(); 119 | } 120 | last = bound(last, first, (int)this->size()); 121 | return View(&((*this)[first]), &((*this)[last])); 122 | } 123 | View slice(pair& r) { 124 | return slice(r.first, r.second); 125 | } 126 | const View slice( 127 | int first = 0, int last = numeric_limits::max()) const { 128 | return slice(first, last); 129 | } 130 | const View slice(pair& r) const { 131 | return slice(r.first, r.second); 132 | } 133 | template Vector& extend(const R& r) { 134 | size_t oldSize = this->size(); 135 | grow(boost::size(r)); 136 | copy(boost::begin(r), boost::end(r), this->begin() + oldSize); 137 | return *this; 138 | } 139 | Vector replicate(size_t times) const { 140 | Vector v; 141 | REPEAT(times) { 142 | v.extend(*this); 143 | } 144 | return v; 145 | } 146 | 
template Vector& operator =(const R& r) { 147 | vector::resize(boost::size(r)); 148 | copy(r, *this); 149 | return *this; 150 | } 151 | template Vector to() const { 152 | Vector v; 153 | LOOP(const T& t, *this) { 154 | v += lexical_cast(t); 155 | } 156 | return v; 157 | } 158 | }; 159 | 160 | template struct Set: public set { 161 | Set() {} 162 | Set(const vector& v) { 163 | *this = v; 164 | } 165 | Set(const View& v) { 166 | *this = v; 167 | } 168 | Set& operator +=(const T& val) { 169 | this->insert(val); 170 | return *this; 171 | } 172 | template Set& operator =(const R& r) { 173 | this->clear(); 174 | return this->extend(r); 175 | } 176 | template Set& extend(const R& r) { 177 | LOOP(const typename boost::range_value::type& val, r) { 178 | (*this) += val; 179 | } 180 | return *this; 181 | } 182 | }; 183 | 184 | #endif 185 | -------------------------------------------------------------------------------- /rnnlib4seshat/CopyConnection.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 
26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_CopyConnection_h 41 | #define _INCLUDED_CopyConnection_h 42 | 43 | #include "Connection.hpp" 44 | 45 | struct CopyConnection: public Connection 46 | { 47 | //functions 48 | CopyConnection(Layer* f, Layer* t, WeightContainer *weight): 49 | Connection(f->name + "_to_" + t->name, f, t) 50 | { 51 | assert(this->from != this->to); 52 | assert(this->from->output_size() == this->to->input_size()); 53 | assert(this->from->output_size()); 54 | this->to->source = this->from; 55 | weight->link_layers(this->from->name, this->to->name); 56 | } 57 | virtual ~CopyConnection(){} 58 | void feed_forward(const vector& coords) 59 | { 60 | range_plus_equals(this->to->inputActivations[coords], this->from->outputActivations[coords]); 61 | } 62 | void feed_back(const vector& coords) 63 | { 64 | range_plus_equals(this->from->outputErrors[coords], this->to->inputErrors[coords]); 65 | } 66 | void print(ostream& out) const 67 | { 68 | Named::print(out); 69 | out << " (copy)"; 70 | } 71 | }; 72 | 73 | #endif 74 | -------------------------------------------------------------------------------- /rnnlib4seshat/DataExporter.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 
4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. If not, see .*/ 17 | 18 | #include "DataExporter.hpp" 19 | 20 | #include 21 | 22 | void DataExportHandler::save(ostream& out) const { 23 | LOOP(const PSPDE& exp, dataExporters) { 24 | out << *(exp.second); 25 | } 26 | } 27 | 28 | void DataExportHandler::load(ConfigFile& conf, ostream& out) { 29 | LOOP(PSPDE& exp, dataExporters) { 30 | if (!exp.second->load(conf, out)) { 31 | out << " for '" << exp.first << "' in config file " << conf.filename 32 | << ", exiting" << endl; 33 | exit(0); 34 | } 35 | } 36 | } 37 | 38 | void DataExportHandler::display(const string& path) const { 39 | LOOP(const PSPDE& exp, dataExporters) { 40 | LOOP(const PSPV& val, exp.second->displayVals) { 41 | string filename = path + exp.first + "_" + val.first; 42 | ofstream out(filename.c_str()); 43 | check(out.is_open(), "couldn't open display file " + filename + 44 | " for writing"); 45 | out << *(val.second); 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /rnnlib4seshat/DataSequence.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. 
If not, see .*/ 39 | 40 | #ifndef _INCLUDED_DataSequence_h 41 | #define _INCLUDED_DataSequence_h 42 | 43 | #include 44 | #include 45 | #include 46 | #include 47 | #include 48 | #include "Helpers.hpp" 49 | #include "SeqBuffer.hpp" 50 | 51 | template static string label_seq_to_str(const R& labelSeq, const vector& alphabet, const string& delim = " ") 52 | { 53 | stringstream ss; 54 | for (typename range_const_iterator::type it = boost::begin(labelSeq); it != boost::end(labelSeq); ++it) 55 | { 56 | if (in_range(alphabet,*it)) 57 | { 58 | ss << alphabet[*it]; 59 | } 60 | else 61 | { 62 | ss << ""; 63 | } 64 | if (it != --boost::end(labelSeq)) 65 | { 66 | ss << delim; 67 | } 68 | } 69 | return ss.str(); 70 | } 71 | static vector str_to_label_seq(const string& labelSeqString, const vector& alphabet) 72 | { 73 | static vector v; 74 | v.clear(); 75 | stringstream ss(labelSeqString); 76 | string lab; 77 | while(ss >> lab) 78 | { 79 | /* check(in_right(alphabet, lab), lab + " not found in alphabet");*/ 80 | // if (warn_unless(in_right(alphabet, lab), lab + " not found in alphabet")) 81 | int i = index(alphabet, lab); 82 | if (i != alphabet.size()) 83 | { 84 | v += i; 85 | } 86 | } 87 | return v; 88 | } 89 | 90 | struct DataSequence 91 | { 92 | //data 93 | SeqBuffer inputs; 94 | SeqBuffer inputClasses; 95 | SeqBuffer targetPatterns; 96 | SeqBuffer targetClasses; 97 | SeqBuffer importance; 98 | vector targetLabelSeq; 99 | vector targetWordSeq; 100 | string tag; 101 | 102 | //functions 103 | DataSequence(const DataSequence& ds): 104 | inputs(ds.inputs), 105 | inputClasses(ds.inputClasses), 106 | targetPatterns(ds.targetPatterns), 107 | targetClasses(ds.targetClasses), 108 | importance(ds.importance), 109 | targetLabelSeq(ds.targetLabelSeq), 110 | tag(ds.tag) 111 | { 112 | } 113 | DataSequence(size_t inputDepth = 0, size_t targetPattDepth = 0): 114 | inputs(inputDepth), 115 | inputClasses(0), 116 | targetPatterns(targetPattDepth), 117 | targetClasses(0), 118 | importance(0) 119 | 
{ 120 | } 121 | size_t num_timesteps() const 122 | { 123 | return inputs.seq_size(); 124 | } 125 | void print(ostream& out, vector* targetLabels = 0, vector* inputLabels = 0) const 126 | { 127 | PRINT(tag, out); 128 | out << "input shape = (" << inputs.shape << ")" << endl; 129 | out << "timesteps = " << inputs.seq_size() << endl; 130 | if (targetLabelSeq.size() && targetLabels) 131 | { 132 | out << "target label sequence:" << endl; 133 | out << label_seq_to_str(this->targetLabelSeq, *targetLabels) << endl; 134 | } 135 | if (targetPatterns.size()) 136 | { 137 | out << "target shape = (" << targetPatterns.shape << ")" << endl; 138 | } 139 | if (verbose) 140 | { 141 | if(targetClasses.size() && targetLabels) 142 | { 143 | out << label_seq_to_str(this->targetClasses.data, *targetLabels) << endl; 144 | } 145 | if(inputClasses.size() && inputLabels) 146 | { 147 | out << label_seq_to_str(this->inputClasses.data, *inputLabels) << endl; 148 | } 149 | } 150 | } 151 | }; 152 | static ostream& operator <<(ostream& out, const DataSequence& seq) 153 | { 154 | seq.print(out); 155 | return out; 156 | } 157 | 158 | #endif 159 | -------------------------------------------------------------------------------- /rnnlib4seshat/FullConnection.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_FullConnection_h 41 | #define _INCLUDED_FullConnection_h 42 | 43 | #include 44 | #include 45 | #include 46 | #include 47 | #include "Layer.hpp" 48 | #include "WeightContainer.hpp" 49 | #include "Helpers.hpp" 50 | #include "DataExporter.hpp" 51 | #include "Connection.hpp" 52 | #include "Matrix.hpp" 53 | 54 | struct FullConnection: public Connection 55 | { 56 | //data 57 | vector delay; 58 | vector delayedCoords; 59 | FullConnection* source; 60 | pair paramRange; 61 | WeightContainer *wc; 62 | 63 | //functions 64 | FullConnection(Layer* f, Layer* t, WeightContainer *weight, const vector& d = empty_list_of(), FullConnection* s = 0): 65 | Connection(make_name(f, t, d), f, t), 66 | source(s) 67 | //paramRange(source ? 
source->paramRange : wc->new_parameters(this->from->output_size() * this->to->input_size(), this->from->name, this->to->name, name)) 68 | { 69 | wc = weight; 70 | if (source) 71 | { 72 | paramRange = source->paramRange; 73 | wc->link_layers(this->from->name, this->to->name, this->name, paramRange.first, paramRange.second); 74 | } 75 | else 76 | paramRange = wc->new_parameters(this->from->output_size() * this->to->input_size(), this->from->name, this->to->name, name); 77 | set_delay(d); 78 | assert(num_weights() == (this->from->output_size() * this->to->input_size())); 79 | if (this->from->name != "bias" && this->from != this->to && !this->to->source) 80 | { 81 | this->to->source = this->from; 82 | } 83 | } 84 | ~FullConnection(){} 85 | void set_delay(const vector& d) 86 | { 87 | delay = d; 88 | assert(delay.size() == 0 || delay.size() == this->from->num_seq_dims()); 89 | delayedCoords.resize(delay.size()); 90 | } 91 | static const string& make_name(Layer* f, Layer* t, const vector& d) 92 | { 93 | static string name; 94 | name = f->name + "_to_" + t->name; 95 | if (find_if(d.begin(), d.end(), std::bind2nd(not_equal_to(), 0)) != d.end()) 96 | { 97 | stringstream temp; 98 | temp << "_delay_"; 99 | copy(d.begin(), d.end()-1, ostream_iterator(temp, "_")); 100 | temp << d.back(); 101 | name += temp.str(); 102 | } 103 | return name; 104 | } 105 | const View weights() 106 | { 107 | return wc->get_weights(paramRange); 108 | } 109 | const View derivs() 110 | { 111 | return wc->get_derivs(paramRange); 112 | } 113 | size_t num_weights() const 114 | { 115 | return difference(paramRange); 116 | } 117 | const vector* add_delay(const vector& toCoords) 118 | { 119 | if (delay.empty()) 120 | { 121 | return &toCoords; 122 | } 123 | range_plus(delayedCoords, toCoords, delay); 124 | if (this->from->outputActivations.in_range(delayedCoords)) 125 | { 126 | return &delayedCoords; 127 | } 128 | return 0; 129 | } 130 | void feed_forward(const vector& toCoords) 131 | { 132 | const vector* 
fromCoords = add_delay(toCoords); 133 | if (fromCoords) 134 | { 135 | dot(this->from->out_acts(*fromCoords), weights().begin(), this->to->inputActivations[toCoords]); 136 | } 137 | } 138 | void feed_back(const vector& toCoords) 139 | { 140 | const vector* fromCoords = add_delay(toCoords); 141 | if (fromCoords) 142 | { 143 | dot_transpose(this->to->inputErrors[toCoords], weights().begin(), this->from->out_errs(*fromCoords)); 144 | } 145 | } 146 | void update_derivs(const vector& toCoords) 147 | { 148 | const vector* fromCoords = add_delay(toCoords); 149 | if (fromCoords) 150 | { 151 | outer(this->from->out_acts(*fromCoords), derivs().begin(), this->to->inputErrors[toCoords]); 152 | } 153 | } 154 | void print(ostream& out) const 155 | { 156 | Named::print(out); 157 | out << " (" << num_weights() << " wts"; 158 | if (source) 159 | { 160 | out << " shared with " << source->name; 161 | } 162 | out << ")"; 163 | } 164 | }; 165 | 166 | #endif 167 | -------------------------------------------------------------------------------- /rnnlib4seshat/GatherLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_GatherLayer_h 41 | #define _INCLUDED_GatherLayer_h 42 | 43 | #include "Layer.hpp" 44 | 45 | struct GatherLayer: public Layer 46 | { 47 | //data 48 | vector sources; 49 | 50 | //functions 51 | GatherLayer(const string& name, vector& srcs, WeightContainer *weight, DataExportHandler *deh): 52 | Layer(name, srcs.front()->num_seq_dims(), 0, get_size(srcs), weight, deh, srcs.front()), 53 | sources(srcs) 54 | { 55 | source = sources.front(); 56 | wc->new_parameters(0, source->name, name, source->name + "_to_" + name); 57 | display(outputActivations, "activations"); 58 | display(outputErrors, "errors"); 59 | } 60 | int get_size(vector& srcs) 61 | { 62 | int size = 0; 63 | for (int i = 0; i < srcs.size(); ++i) 64 | { 65 | size += srcs[i]->output_size(); 66 | } 67 | return size; 68 | } 69 | void feed_forward(const vector& outCoords) 70 | { 71 | real_t* actBegin = outputActivations[outCoords].begin(); 72 | LOOP(Layer* l, sources) 73 | { 74 | View inActs = l->outputActivations[outCoords]; 75 | copy(inActs.begin(), inActs.end(), actBegin); 76 | actBegin += inActs.size(); 77 | } 78 | } 79 | void feed_back(const 
vector& outCoords) 80 | { 81 | real_t* errBegin = outputErrors[outCoords].begin(); 82 | LOOP(Layer* l, sources) 83 | { 84 | View inErrs = l->outputErrors[outCoords]; 85 | int dist = inErrs.size(); 86 | copy(errBegin, errBegin + dist, inErrs.begin()); 87 | errBegin += dist; 88 | } 89 | } 90 | }; 91 | 92 | #endif 93 | -------------------------------------------------------------------------------- /rnnlib4seshat/IdentityLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 
36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_IdentityLayer_h 41 | #define _INCLUDED_IdentityLayer_h 42 | 43 | #include "Layer.hpp" 44 | #include "Helpers.hpp" 45 | 46 | struct IdentityLayer: public FlatLayer 47 | { 48 | //functions 49 | IdentityLayer(const string& name, size_t numSeqDims, size_t size, WeightContainer *weight, DataExportHandler *deh): 50 | FlatLayer(name, numSeqDims, size, weight, deh) 51 | { 52 | display(this->outputErrors, "errors"); 53 | display(this->outputActivations, "activations"); 54 | } 55 | IdentityLayer(const string& name, const vector& directions, size_t size, WeightContainer *weight, DataExportHandler *deh): 56 | FlatLayer(name, directions, size, weight, deh) 57 | { 58 | display(this->outputErrors, "errors"); 59 | display(this->outputActivations, "activations"); 60 | } 61 | void feed_forward(const vector& coords) 62 | { 63 | copy(this->inputActivations[coords], this->outputActivations[coords]); 64 | } 65 | void feed_back(const vector& coords) 66 | { 67 | copy(this->outputErrors[coords], this->inputErrors[coords]); 68 | } 69 | }; 70 | 71 | #endif 72 | -------------------------------------------------------------------------------- /rnnlib4seshat/InputLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_InputLayer_h 41 | #define _INCLUDED_InputLayer_h 42 | 43 | #include "Layer.hpp" 44 | 45 | struct InputLayer: public Layer 46 | { 47 | //functions 48 | InputLayer(const string& name, size_t numSeqDims, size_t size, const vector& inputLabels, WeightContainer *weight, DataExportHandler *deh): 49 | Layer(name, numSeqDims, 0, size, weight, deh) 50 | { 51 | //const vector* labs = inputLabels.empty() ? 
0 : &inputLabels; 52 | //display(this->outputActivations, "activations", labs); 53 | //display(this->outputErrors, "errors", labs); 54 | } 55 | ~InputLayer(){} 56 | template void copy_inputs(const SeqBuffer& inputs) 57 | { 58 | assert(inputs.depth == this->output_size()); 59 | this->outputActivations = inputs; 60 | this->outputErrors.reshape(this->outputActivations, 0); 61 | } 62 | }; 63 | 64 | #endif 65 | -------------------------------------------------------------------------------- /rnnlib4seshat/Layer.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #include "Layer.hpp" 41 | 42 | ostream& operator <<(ostream& out, const Layer& l) { 43 | l.print(out); 44 | return out; 45 | } 46 | -------------------------------------------------------------------------------- /rnnlib4seshat/Log.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
If not, see .*/ 17 | 18 | #ifndef _INCLUDED_Log_h 19 | #define _INCLUDED_Log_h 20 | 21 | #include 22 | #include 23 | 24 | using namespace std; 25 | 26 | template class Log 27 | { 28 | //data 29 | T expVal; 30 | T logVal; 31 | 32 | public: 33 | 34 | //static data 35 | static const T expMax; 36 | static const T expMin; 37 | static const T expLimit; 38 | static const T logZero; 39 | static const T logInfinity; 40 | 41 | //static functions 42 | static T safe_exp(T x) 43 | { 44 | if (x == logZero) 45 | { 46 | return 0; 47 | } 48 | if (x >= expLimit) 49 | { 50 | return expMax; 51 | } 52 | return std::exp(x); 53 | } 54 | static T safe_log(T x) 55 | { 56 | if (x <= expMin) 57 | { 58 | return logZero; 59 | } 60 | return std::log(x); 61 | } 62 | static T log_add(T x, T y) 63 | { 64 | if (x == logZero) 65 | { 66 | return y; 67 | } 68 | if (y == logZero) 69 | { 70 | return x; 71 | } 72 | if (x < y) 73 | { 74 | swap(x, y); 75 | } 76 | return x + std::log(1.0 + safe_exp(y - x)); 77 | } 78 | static T log_subtract(T x, T y) 79 | { 80 | if (y == logZero) 81 | { 82 | return x; 83 | } 84 | if (y >= x) 85 | { 86 | return logZero; 87 | } 88 | return x + std::log(1.0 - safe_exp(y - x)); 89 | } 90 | static T log_multiply(T x, T y) 91 | { 92 | if (x == logZero || y == logZero) 93 | { 94 | return logZero; 95 | } 96 | return x + y; 97 | } 98 | static T log_divide(T x, T y) 99 | { 100 | if (x == logZero) 101 | { 102 | return logZero; 103 | } 104 | if (y == logZero) 105 | { 106 | return logInfinity; 107 | } 108 | return x - y; 109 | } 110 | 111 | //functions 112 | Log(T v = 0, bool logScale = false): 113 | expVal(logScale ? -1 : v), 114 | logVal(logScale ? 
v : safe_log(v)) 115 | {} 116 | Log& operator =(const Log& l) 117 | { 118 | logVal = l.logVal; 119 | expVal = l.expVal; 120 | return *this; 121 | } 122 | Log& operator +=(const Log& l) 123 | { 124 | logVal = log_add(logVal, l.logVal); 125 | expVal = -1; 126 | return *this; 127 | } 128 | Log& operator -=(const Log& l) 129 | { 130 | logVal = log_subtract(logVal, l.logVal); 131 | expVal = -1; 132 | return *this; 133 | } 134 | Log& operator *=(const Log& l) 135 | { 136 | logVal = log_multiply(logVal, l.logVal); 137 | expVal = -1; 138 | return *this; 139 | } 140 | Log& operator /=(const Log& l) 141 | { 142 | logVal = log_divide(logVal, l.logVal); 143 | expVal = -1; 144 | return *this; 145 | } 146 | T exp() 147 | { 148 | if (expVal < 0) 149 | { 150 | expVal = safe_exp(logVal); 151 | } 152 | return expVal; 153 | } 154 | T log() const 155 | { 156 | return logVal; 157 | } 158 | }; 159 | 160 | //helper functions 161 | template Log operator +(Log log1, Log log2) 162 | { 163 | return Log(Log::log_add(log1.log(), log2.log()), true); 164 | } 165 | template Log operator -(Log log1, Log log2) 166 | { 167 | return Log(Log::log_subtract(log1.log(), log2.log()), true); 168 | } 169 | template Log operator *(Log log1, Log log2) 170 | { 171 | return Log(Log::log_multiply(log1.log(), log2.log()), true); 172 | } 173 | template Log operator /(Log log1, Log log2) 174 | { 175 | return Log(Log::log_divide(log1.log(), log2.log()), true); 176 | } 177 | template bool operator >(Log log1, Log log2) 178 | { 179 | return (log1.log() > log2.log()); 180 | } 181 | template bool operator <(Log log1, Log log2) 182 | { 183 | return (log1.log() < log2.log()); 184 | } 185 | template bool operator ==(Log log1, Log log2) 186 | { 187 | return (log1.log() == log2.log()); 188 | } 189 | template bool operator <=(Log log1, Log log2) 190 | { 191 | return (log1.log() <= log2.log()); 192 | } 193 | template bool operator >=(Log log1, Log log2) 194 | { 195 | return (log1.log() >= log2.log()); 196 | } 197 | template 
ostream& operator <<(ostream& out, const Log& l) 198 | { 199 | out << l.log(); 200 | return out; 201 | } 202 | template istream& operator >>(istream& in, Log& l) 203 | { 204 | T d; 205 | in >> d; 206 | l = Log(d, true); 207 | return in; 208 | } 209 | 210 | template const T Log::expMax = numeric_limits::max(); 211 | template const T Log::expMin = numeric_limits::min(); 212 | template const T Log::expLimit = std::log(expMax); 213 | template const T Log::logInfinity = 1e100; 214 | template const T Log::logZero = -Log::logInfinity; 215 | 216 | #endif 217 | -------------------------------------------------------------------------------- /rnnlib4seshat/Matrix.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
If not, see .*/ 17 | 18 | #ifndef _INCLUDED_Matrix_h 19 | #define _INCLUDED_Matrix_h 20 | 21 | #define OP_TRACKING 22 | 23 | #ifdef OP_TRACKING 24 | static uint64_t matrixOps = 0; 25 | #endif 26 | 27 | // M += a * b 28 | static void outer( 29 | const real_t *aBegin, const real_t *aEnd, real_t *M, const real_t *b, 30 | const real_t *bEnd) { 31 | #ifdef OP_TRACKING 32 | const real_t* mStart = M; 33 | #endif 34 | for (; b != bEnd; ++b) { 35 | real_t input = *b; 36 | for (const real_t *a = aBegin; a != aEnd; ++a, ++M) { 37 | *M += *a * input; 38 | } 39 | } 40 | #ifdef OP_TRACKING 41 | matrixOps += M - mStart; 42 | #endif 43 | } 44 | 45 | // out += M in 46 | static void dot( 47 | const real_t *inBegin, const real_t *inEnd, const real_t *M, real_t *out, 48 | real_t *outEnd) { 49 | #ifdef OP_TRACKING 50 | const real_t* mStart = M; 51 | #endif 52 | for (; out != outEnd; ++out) { 53 | real_t sum = 0; 54 | for (const real_t *in = inBegin; in != inEnd; ++in, ++M) { 55 | sum += *M * (*in); 56 | } 57 | *out += sum; 58 | } 59 | #ifdef OP_TRACKING 60 | matrixOps += M - mStart; 61 | #endif 62 | } 63 | 64 | // out += transpose(M) in 65 | static void dot_transpose( 66 | const real_t *in, const real_t *inEnd, const real_t *M, real_t *outBegin, 67 | real_t *outEnd) { 68 | #ifdef OP_TRACKING 69 | const real_t* mStart = M; 70 | #endif 71 | for (; in != inEnd; ++in) { 72 | real_t input = *in; 73 | for (real_t *out = outBegin; out != outEnd; ++out, ++M) { 74 | *out += *M * input; 75 | } 76 | } 77 | #ifdef OP_TRACKING 78 | matrixOps += M - mStart; 79 | #endif 80 | } 81 | 82 | // out += transpose(M^2) in 83 | static void dot_transpose_m_squared( 84 | const real_t *in, const real_t *inEnd, const real_t *M, real_t *outBegin, 85 | real_t *outEnd) { 86 | #ifdef OP_TRACKING 87 | const real_t* mStart = M; 88 | #endif 89 | for (; in != inEnd; ++in) { 90 | real_t input = *in; 91 | for (real_t *out = outBegin; out != outEnd; ++out, ++M) { 92 | *out += squared(*M) * input; 93 | } 94 | } 95 | #ifdef 
OP_TRACKING 96 | matrixOps += M - mStart; 97 | #endif 98 | } 99 | 100 | // M += a^2 * b 101 | static void outer_a_squared( 102 | const real_t *aBegin, const real_t *aEnd, real_t *M, const real_t *b, 103 | const real_t *bEnd) { 104 | #ifdef OP_TRACKING 105 | const real_t* mStart = M; 106 | #endif 107 | for (; b != bEnd; ++b) { 108 | real_t input = *b; 109 | for (const real_t *a = aBegin; a != aEnd; ++a, ++M) { 110 | *M += squared(*a) * input; 111 | } 112 | } 113 | #ifdef OP_TRACKING 114 | matrixOps += M - mStart; 115 | #endif 116 | } 117 | 118 | template static void outer(const R& a, real_t *M, const R&b) { 119 | outer(boost::begin(a), boost::end(a), M, boost::begin(b), boost::end(b)); 120 | } 121 | 122 | template static void dot(const R& a, const real_t *M, const R& b) { 123 | dot(boost::begin(a), boost::end(a), M, boost::begin(b), boost::end(b)); 124 | } 125 | 126 | template static void dot_transpose( 127 | const R& a, const real_t *M, const R& b) { 128 | dot_transpose( 129 | boost::begin(a), boost::end(a), M, boost::begin(b), boost::end(b)); 130 | } 131 | 132 | template static void outer_a_squared( 133 | const R& a, real_t *M, const R&b) { 134 | outer_a_squared( 135 | boost::begin(a), boost::end(a), M, boost::begin(b), boost::end(b)); 136 | } 137 | 138 | template static void dot_transpose_m_squared( 139 | const R& a, const real_t *M, const R& b) { 140 | dot_transpose_m_squared( 141 | boost::begin(a), boost::end(a), M, boost::begin(b), boost::end(b)); 142 | } 143 | 144 | static real_t& elt(View M, int x, int y, int width) { 145 | return M[(y*width) + x]; 146 | } 147 | #endif 148 | -------------------------------------------------------------------------------- /rnnlib4seshat/Mdrnn.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #include "Mdrnn.hpp" 41 | 42 | ostream& operator << (ostream& out, const Mdrnn& net) { 43 | net.print(out); 44 | return out; 45 | } 46 | -------------------------------------------------------------------------------- /rnnlib4seshat/MultiArray.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 
4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. If not, see .*/ 17 | 18 | #ifndef _INCLUDED_MultiArray_h 19 | #define _INCLUDED_MultiArray_h 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include "Helpers.hpp" 31 | #include "Container.hpp" 32 | 33 | using namespace std; 34 | using namespace boost; 35 | using namespace boost::assign; 36 | 37 | template struct MultiArray { 38 | //data 39 | Vector data; 40 | Vector shape; 41 | vector strides; 42 | 43 | //functions 44 | MultiArray() { } 45 | MultiArray(const vector& s) { 46 | reshape(s); 47 | } 48 | MultiArray(const vector& s, const T& fillval) { 49 | reshape(s, fillval); 50 | } 51 | virtual ~MultiArray() { 52 | } 53 | virtual size_t size() const { 54 | return data.size(); 55 | } 56 | virtual size_t num_dims() const { 57 | return shape.size(); 58 | } 59 | virtual bool empty() const { 60 | return data.empty(); 61 | } 62 | virtual void resize_data() { 63 | data.resize(product(shape)); 64 | strides.resize(shape.size()); 65 | strides.back() = 1; 66 | for (int i = shape.size()-2; i >= 0; --i) { 67 | strides.at(i) = strides.at(i+1) * shape.at(i+1); 68 | } 69 | } 70 | template void reshape(const R& newShape) { 71 | assert(newShape.size()); 72 | shape = newShape; 73 | resize_data(); 74 | } 75 | void fill_data(const T& fillVal) { 76 | fill(data, fillVal); 77 | } 78 | template void 
reshape(const R& dims, const T& fillVal) { 79 | reshape(dims); 80 | fill_data(fillVal); 81 | } 82 | bool in_range(const vector& coords) const { 83 | if (coords.size() > shape.size()) { 84 | return false; 85 | } 86 | VSTCI shapeIt = shape.begin(); 87 | for (VICI coordIt = coords.begin(); coordIt != coords.end(); 88 | ++coordIt, ++shapeIt) { 89 | int c = *coordIt; 90 | if (c < 0 || c >= *shapeIt) { 91 | return false; 92 | } 93 | } 94 | return true; 95 | } 96 | T& get(const vector& coords) { 97 | check(boost::size(coords) == shape.size(), "get(" + str(coords) + 98 | ") called with shape " + str(shape)); 99 | return *((*this)[coords].begin()); 100 | } 101 | const T& get(const vector& coords) const { 102 | check(boost::size(coords) == shape.size(), "get(" + str(coords) + 103 | ") called with shape " + str(shape)); 104 | return (*this)[coords].front(); 105 | } 106 | size_t offset(const vector& coords) const { 107 | return inner_product(coords, strides); 108 | } 109 | const View operator[](const vector& coords) { 110 | check(coords.size() <= shape.size(), "operator [" + str(coords) + 111 | "] called with shape " + str(shape)); 112 | if (coords.empty()) { 113 | return View(&data.front(), &data.front() + data.size()); 114 | } 115 | T* start = &data.front() + offset(coords); 116 | T* end = start + strides[coords.size() - 1]; 117 | return View(start, end); 118 | } 119 | const View operator[](const vector& coords) const { 120 | check(coords.size() <= shape.size(), "operator [" + str(coords) + 121 | "] called with shape " + str(shape)); 122 | if (coords.empty()) { 123 | return View(&data.front(), &data.front() + data.size()); 124 | } 125 | const T* start = &data.front(); 126 | VSTCI strideIt = strides.begin() + offset(coords); 127 | const T* end = start + strides[coords.size() - 1]; 128 | return View(start, end); 129 | } 130 | const View at(const vector& coords) { 131 | if (in_range(coords)) { 132 | return (*this)[coords]; 133 | } 134 | return View(0, 0); 135 | } 136 | const 
View at(const vector& coords) const { 137 | if (in_range(coords)) { 138 | return (*this)[coords]; 139 | } 140 | return View(0, 0); 141 | } 142 | template void assign(const MultiArray& a) { 143 | reshape(a.shape); 144 | copy(a.data, data); 145 | } 146 | template MultiArray& operator=(const MultiArray& a) { 147 | assign(a); 148 | return *this; 149 | } 150 | }; 151 | 152 | template static bool operator ==( 153 | const MultiArray& a, const MultiArray& b) { 154 | return (a.data == b.data && a.shape == b.shape); 155 | } 156 | 157 | #endif 158 | -------------------------------------------------------------------------------- /rnnlib4seshat/MultilayerNet.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 
31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_MultilayerNet_h 41 | #define _INCLUDED_MultilayerNet_h 42 | 43 | #include "Mdrnn.hpp" 44 | #include "ClassificationLayer.hpp" 45 | #include "TranscriptionLayer.hpp" 46 | 47 | struct MultilayerNet: public Mdrnn { 48 | //functions 49 | MultilayerNet(ostream& out, ConfigFile& conf, const DataHeader& data, WeightContainer *weight, DataExportHandler *deh): 50 | Mdrnn(out, conf, data, weight, deh) { 51 | string task = conf.get("task"); 52 | vector hiddenSizes = conf.get_list("hiddenSize"); 53 | assert(hiddenSizes.size()); 54 | Vector hiddenTypes = 55 | conf.get_list("hiddenType", "lstm", hiddenSizes.size()); 56 | Vector > hiddenBlocks = 57 | conf.get_array("hiddenBlock"); 58 | assert(hiddenBlocks.size() < hiddenSizes.size()); 59 | Vector subsampleSizes = conf.get_list("subsampleSize"); 60 | assert(subsampleSizes.size() < hiddenSizes.size()); 61 | string subsampleType = conf.get("subsampleType", "tanh"); 62 | bool subsampleBias = conf.get("subsampleBias", false); 63 | Vector recurrent = 64 | conf.get_list("recurrent", true, hiddenSizes.size()); 65 | Layer* input = this->get_input_layer(); 66 | LOOP(int i, indices(hiddenSizes)) { 67 | string level_suffix = int_to_sortable_string(i, hiddenSizes.size()); 68 | this->add_hidden_level( 69 | hiddenTypes.at(i), hiddenSizes.at(i), recurrent.at(i), 70 | "hidden_" + level_suffix); 71 | this->connect_to_hidden_level(input, i); 72 | vector blocks; 73 | if (i < hiddenBlocks.size()) { 74 | LOOP(Layer* l, hiddenLevels[i]) { 75 | blocks += this->add_layer(new BlockLayer(l, hiddenBlocks.at(i), wc, deh)); 76 | } 77 | } 78 | vector& topLayers = 
blocks.size() ? blocks : hiddenLevels[i]; 79 | if (i < subsampleSizes.size()) { 80 | input = this->add_layer( 81 | subsampleType, "subsample_" + level_suffix, subsampleSizes.at(i), 82 | empty_list_of().repeat(this->num_seq_dims(), 1), 83 | subsampleBias, false); 84 | LOOP(Layer* l, topLayers) { 85 | this->connect_layers(l, input); 86 | } 87 | } else if (i < last_index(hiddenSizes)) { 88 | input = this->add_layer(new GatherLayer( 89 | "gather_" + level_suffix, topLayers, wc, deh)); 90 | } 91 | } 92 | conf.set_val("inputSize", inputLayer->output_size()); 93 | if (data.targetLabels.size()) { 94 | string labelDelimiters(",.;:|+&_~*%$#^=-<>/?{}[]()"); 95 | LOOP(char c, labelDelimiters) { 96 | bool goodDelim = true; 97 | LOOP(const string& s, data.targetLabels) { 98 | if (in(s, c)) { 99 | goodDelim = false; 100 | break; 101 | } 102 | } 103 | if (goodDelim) { 104 | stringstream ss; 105 | print_range(ss, data.targetLabels, c); 106 | conf.set_val("targetLabels", ss.str()); 107 | conf.set_val("labelDelimiter", c); 108 | break; 109 | } 110 | } 111 | } 112 | string outputName = "output"; 113 | Layer* output = 0; 114 | size_t outSeqDims = (in(task, "sequence_") ? 
0 : num_seq_dims()); 115 | if (in(task, "classification")) { 116 | output = add_output_layer(make_classification_layer( 117 | out, outputName, outSeqDims, data.targetLabels, wc, deh)); 118 | } else if (task == "transcription") { 119 | check(this->num_seq_dims(), "cannot perform transcription wth 0D net"); 120 | output = add_output_layer(new TranscriptionLayer( 121 | out, outputName, data.targetLabels, wc, deh, conf.get( 122 | "confusionMatrix", false))); 123 | if (this->num_seq_dims() > 1) { 124 | output = this->collapse_layer( 125 | hiddenLayers.back(), output, list_of(true)); 126 | } 127 | } else { 128 | check(false, "unknown task '" + task + "'"); 129 | } 130 | if(this->num_seq_dims() && in(task, "sequence_")) { 131 | output = this->collapse_layer(hiddenLayers.back(), output); 132 | } 133 | connect_from_hidden_level(last_index(hiddenLevels), output); 134 | } 135 | }; 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /rnnlib4seshat/Named.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
// Base for objects that carry a text name.
struct Named {
  std::string name;

  Named(const std::string & n)
    : name(n)
  {}

  virtual ~Named()
  {}

  // Writes the name wrapped in double quotes to `out` (std::cout when no
  // stream is given).
  void print(std::ostream& out = std::cout) const {
    out << "\"";
    out << name;
    out << "\"";
  }
};

// Stream insertion for Named: delegates to Named::print and returns the
// stream so insertions can be chained.
static std::ostream& operator << (std::ostream& out, const Named& n) {
  n.print(out);
  return out;
}
31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_NetcdfDataset_h 41 | #define _INCLUDED_NetcdfDataset_h 42 | 43 | #include 44 | #include 45 | #include 46 | #include 47 | #include 48 | //#include 49 | #include 50 | #include "DataSequence.hpp" 51 | #include "Helpers.hpp" 52 | 53 | #define SEQ_IT vector::iterator 54 | #define CONST_SEQ_IT vector::const_iterator 55 | 56 | 57 | struct DataHeader 58 | { 59 | //data 60 | int numDims; 61 | Vector inputLabels; 62 | map inputLabelCounts; 63 | Vector targetLabels; 64 | map targetLabelCounts; 65 | size_t inputSize; 66 | size_t outputSize; 67 | size_t numSequences; 68 | size_t numTimesteps; 69 | size_t totalTargetStringLength; 70 | 71 | //functions 72 | DataHeader(): outputSize(0), 73 | numTimesteps(0), 74 | totalTargetStringLength(0) 75 | {} 76 | }; 77 | 78 | #endif 79 | -------------------------------------------------------------------------------- /rnnlib4seshat/NetworkOutput.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. If not, see .*/ 17 | 18 | #ifndef _INCLUDED_NetworkOutput_h 19 | #define _INCLUDED_NetworkOutput_h 20 | 21 | #include 22 | #include "DataSequence.hpp" 23 | 24 | #define ERR(x) this->errorMap[#x] = x 25 | 26 | struct NetworkOutput 27 | { 28 | //data 29 | map errorMap; 30 | map normFactors; 31 | Vector criteria; 32 | 33 | //functions 34 | NetworkOutput(){} 35 | virtual real_t calculate_errors(const DataSequence& seq){return realMax;} 36 | }; 37 | 38 | #endif 39 | -------------------------------------------------------------------------------- /rnnlib4seshat/NeuronLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 
31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_NeuronLayer_h 41 | #define _INCLUDED_NeuronLayer_h 42 | 43 | #include "Layer.hpp" 44 | 45 | template struct NeuronLayer: public FlatLayer 46 | { 47 | NeuronLayer(const string& name, size_t numDims, size_t size, WeightContainer *weight, DataExportHandler *deh): 48 | FlatLayer(name, numDims, size, weight, deh) 49 | { 50 | init(); 51 | } 52 | NeuronLayer(const string& name, const vector& directions, size_t size, WeightContainer *weight, DataExportHandler *deh): 53 | FlatLayer(name, directions, size, weight, deh) 54 | { 55 | init(); 56 | } 57 | ~NeuronLayer(){} 58 | void init() 59 | { 60 | display(this->inputActivations, "inputActivations"); 61 | display(this->outputActivations, "outputActivations"); 62 | display(this->inputErrors, "inputErrors"); 63 | display(this->outputErrors, "outputErrors"); 64 | } 65 | void feed_forward(const vector& coords) 66 | { 67 | transform(this->inputActivations[coords], this->outputActivations[coords], F::fn); 68 | } 69 | void feed_back(const vector& coords) 70 | { 71 | LOOP(TDDD t, zip(this->inputErrors[coords], this->outputActivations[coords], this->outputErrors[coords])) 72 | { 73 | t.get<0>() = F::deriv(t.get<1>()) * t.get<2>(); 74 | } 75 | } 76 | }; 77 | 78 | #endif 79 | -------------------------------------------------------------------------------- /rnnlib4seshat/Optimiser.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #include "Optimiser.hpp" 41 | 42 | ostream& operator << (ostream& out, const Optimiser& o) { 43 | o.print(out); 44 | return out; 45 | } 46 | -------------------------------------------------------------------------------- /rnnlib4seshat/Optimiser.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see <http://www.gnu.org/licenses/>. 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB.
If not, see .*/ 39 | 40 | #ifndef _INCLUDED_Optimiser_h 41 | #define _INCLUDED_Optimiser_h 42 | 43 | #include 44 | #include 45 | 46 | #include "DataExporter.hpp" 47 | 48 | using namespace std; 49 | 50 | struct Optimiser { 51 | // data 52 | vector& wts; 53 | vector& derivs; 54 | 55 | // functions 56 | Optimiser(vector& weights, vector& derivatives): 57 | wts(weights), derivs(derivatives) { 58 | } 59 | 60 | virtual ~Optimiser() {} 61 | virtual void update_weights() = 0; 62 | virtual void print(ostream& out = cout) const = 0; 63 | virtual void build() = 0; 64 | }; 65 | 66 | ostream& operator << (ostream& out, const Optimiser& o); 67 | 68 | #endif 69 | -------------------------------------------------------------------------------- /rnnlib4seshat/Random.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
If not, see .*/ 17 | 18 | #include "Random.hpp" 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | using namespace std; 30 | using namespace boost; 31 | using namespace boost::random; 32 | using namespace boost::posix_time; 33 | using namespace boost::gregorian; 34 | 35 | typedef mt19937 BaseGenType; 36 | static BaseGenType generator(42u); 37 | 38 | real_t Random::normal() 39 | { 40 | static variate_generator > norm(generator, normal_distribution()); 41 | return norm(); 42 | } 43 | real_t Random::normal(real_t dev, real_t mean) 44 | { 45 | return (normal() * dev) + mean; 46 | } 47 | unsigned int Random::set_seed(unsigned int seed) 48 | { 49 | if (seed == 0) 50 | { 51 | time_period p (ptime(date(1970, Jan, 01)), microsec_clock::local_time()); 52 | seed = (unsigned int)p.length().ticks(); 53 | } 54 | srand(seed); 55 | generator.seed(seed); 56 | return seed; 57 | } 58 | real_t Random::uniform(real_t range) 59 | { 60 | return (uniform()*2*range) - range; 61 | } 62 | real_t Random::uniform() 63 | { 64 | static variate_generator > uni(generator, uniform_real()); 65 | return uni(); 66 | } 67 | -------------------------------------------------------------------------------- /rnnlib4seshat/Random.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. If not, see .*/ 17 | 18 | #ifndef _INCLUDED_Random_h 19 | #define _INCLUDED_Random_h 20 | 21 | #include "Helpers.hpp" 22 | 23 | namespace Random 24 | { 25 | unsigned int set_seed(unsigned int seed = 0); 26 | real_t normal(); //normal distribution with mean 0 std dev 1 27 | real_t normal(real_t dev, real_t mean = 0); //normal distribution with user defined mean, dev 28 | real_t uniform(real_t range); //uniform real in (-range, range) 29 | real_t uniform(); //uniform real in (0,1) 30 | }; 31 | 32 | #endif 33 | -------------------------------------------------------------------------------- /rnnlib4seshat/Rprop.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 
31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_Rprop_h 41 | #define _INCLUDED_Rprop_h 42 | 43 | #include 44 | #include 45 | #include 46 | 47 | #include "DataExporter.hpp" 48 | #include "Helpers.hpp" 49 | #include "Optimiser.hpp" 50 | 51 | using namespace std; 52 | 53 | struct Rprop: public DataExporter, public Optimiser { 54 | // data 55 | ostream& out; 56 | vector deltas; 57 | vector prevDerivs; 58 | real_t etaChange; 59 | real_t etaMin; 60 | real_t etaPlus; 61 | real_t minDelta; 62 | real_t maxDelta; 63 | real_t initDelta; 64 | real_t prevAvgDelta; 65 | bool online; 66 | WeightContainer *wc; 67 | 68 | // functions 69 | Rprop( 70 | const string& name, ostream& o, vector& weights, 71 | vector& derivatives, WeightContainer *weight, DataExportHandler *deh, bool on = false): 72 | DataExporter(name, deh), Optimiser(weights, derivatives), out(o), 73 | etaChange(0.01), etaMin(0.5), etaPlus(1.2), minDelta(1e-9), maxDelta(0.2), 74 | initDelta(0.01), prevAvgDelta(0), online(on), wc(weight) { 75 | if (online) { 76 | SAVE(prevAvgDelta); 77 | SAVE(etaPlus); 78 | } 79 | build(); 80 | } 81 | 82 | void update_weights() { 83 | assert(wts.size() == derivs.size()); 84 | assert(wts.size() == deltas.size()); 85 | assert(wts.size() == prevDerivs.size()); 86 | LOOP(int i, indices(wts)) { 87 | real_t deriv = derivs[i]; 88 | real_t delta = deltas[i]; 89 | real_t derivTimesPrev = deriv * prevDerivs[i]; 90 | if (derivTimesPrev > 0) { 91 | deltas[i] = bound(delta * etaPlus, minDelta, maxDelta); 92 | wts[i] -= sign(deriv) * delta; 93 | prevDerivs[i] = deriv; 94 | } else if (derivTimesPrev < 0) { 95 | deltas[i] = bound(delta * etaMin, minDelta, 
maxDelta); 96 | prevDerivs[i] = 0; 97 | } else { 98 | wts[i] -= sign(deriv) * delta; 99 | prevDerivs[i] = deriv; 100 | } 101 | } 102 | // use eta adaptations for online training (from Mike Schuster's thesis) 103 | if (online) { 104 | real_t avgDelta = mean(deltas); 105 | if (avgDelta > prevAvgDelta) { 106 | etaPlus = max((real_t)1.0, etaPlus - etaChange); 107 | } else { 108 | etaPlus += etaChange; 109 | } 110 | prevAvgDelta = avgDelta; 111 | } 112 | if (verbose) { 113 | PRINT(minmax(wts), out); 114 | PRINT(minmax(derivs), out); 115 | PRINT(minmax(deltas), out); 116 | PRINT(minmax(prevDerivs), out); 117 | } 118 | } 119 | 120 | // NOTE must be called after any change to weightContainer 121 | void build() { 122 | if (deltas.size() != wts.size()) { 123 | deltas.resize(wts.size()); 124 | prevDerivs.resize(wts.size()); 125 | fill(deltas, initDelta); 126 | fill(prevDerivs, 0); 127 | wc->save_by_conns( 128 | deltas, ((name == "optimiser") ? "" : name + "_") + "deltas"); 129 | wc->save_by_conns( 130 | prevDerivs, ((name == "optimiser") ? "" : name + "_") + "prevDerivs"); 131 | } 132 | } 133 | 134 | void print(ostream& out = cout) const { 135 | out << "RPROP" << endl; 136 | PRINT(online, out); 137 | if (online) { 138 | PRINT(prevAvgDelta, out); 139 | PRINT(etaChange, out); 140 | } 141 | PRINT(etaMin, out); 142 | PRINT(etaPlus, out); 143 | PRINT(minDelta, out); 144 | PRINT(maxDelta, out); 145 | PRINT(initDelta, out); 146 | } 147 | }; 148 | 149 | #endif 150 | -------------------------------------------------------------------------------- /rnnlib4seshat/SoftmaxLayer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 

SESHAT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with SESHAT. If not, see <http://www.gnu.org/licenses/>.


This file is a modification of the RNNLIB original software covered by
the following copyright and permission notice:

*/
/*Copyright 2009,2010 Alex Graves

This file is part of RNNLIB.

RNNLIB is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

RNNLIB is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with RNNLIB.
If not, see .*/ 39 | 40 | #ifndef _INCLUDED_SoftmaxLayer_h 41 | #define _INCLUDED_SoftmaxLayer_h 42 | 43 | #include 44 | #include "NetworkOutput.hpp" 45 | #include "Log.hpp" 46 | #include "Layer.hpp" 47 | 48 | struct SoftmaxLayer: public FlatLayer{ 49 | //data 50 | vector targetLabels; 51 | SeqBuffer > logActivations; 52 | SeqBuffer > unnormedlogActivations; 53 | SeqBuffer unnormedActivations; 54 | 55 | //functions 56 | SoftmaxLayer(const string& name, size_t numSeqDims, const vector& labs, WeightContainer *wc, DataExportHandler *deh): 57 | FlatLayer(name, numSeqDims, labs.size(), wc, deh), 58 | targetLabels(labs), 59 | logActivations(this->output_size()), 60 | unnormedlogActivations(this->output_size()), 61 | unnormedActivations(this->output_size()) 62 | { 63 | //display(this->inputErrors, "inputErrors", &targetLabels); 64 | //display(this->outputErrors, "outputErrors", &targetLabels); 65 | //display(this->inputActivations, "inputActivations", &targetLabels); 66 | //display(this->outputActivations, "outputActivations", &targetLabels); 67 | display(this->outputActivations, "outputActivations"); 68 | } 69 | void start_sequence() 70 | { 71 | Layer::start_sequence(); 72 | logActivations.reshape(this->inputActivations); 73 | unnormedlogActivations.reshape(logActivations); 74 | unnormedActivations.reshape(logActivations); 75 | } 76 | void feed_forward(const vector& coords) 77 | { 78 | //transform to log scale and centre inputs on 0 for safer exponentiation 79 | View > unnormedLogActs = unnormedlogActivations[coords]; 80 | real_t offset = pair_mean(minmax(this->inputActivations[coords])); 81 | LOOP(TDL t, zip(this->inputActivations[coords], unnormedLogActs)) 82 | { 83 | t.get<1>() = Log(t.get<0>() - offset, true); 84 | } 85 | 86 | //apply exponential 87 | View unnormedActs = unnormedActivations[coords]; 88 | transform(unnormedLogActs, unnormedActs, mem_fun_ref(&Log::exp)); 89 | 90 | //normalise 91 | real_t Z = sum(unnormedActs); 92 | 
range_divide_val(this->outputActivations[coords], unnormedActs, Z); 93 | range_divide_val(logActivations[coords], unnormedLogActs, Log(Z)); 94 | } 95 | void feed_back(const vector& coords) 96 | { 97 | View outActs = this->outputActivations[coords]; 98 | View outErrs = this->outputErrors[coords]; 99 | real_t Z = inner_product(outActs, outErrs); 100 | LOOP(TDDD t, zip(this->inputErrors[coords], outActs, outErrs)) 101 | { 102 | t.get<0>() = t.get<1>() * (t.get<2>() - Z); 103 | } 104 | } 105 | }; 106 | 107 | #endif 108 | -------------------------------------------------------------------------------- /rnnlib4seshat/SteepestDescent.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 
31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_SteepestDescent_h 41 | #define _INCLUDED_SteepestDescent_h 42 | 43 | #include 44 | #include 45 | #include 46 | 47 | #include "DataExporter.hpp" 48 | #include "Optimiser.hpp" 49 | 50 | 51 | using namespace std; 52 | extern bool verbose; 53 | 54 | struct SteepestDescent: public DataExporter, public Optimiser { 55 | // data 56 | ostream& out; 57 | vector deltas; 58 | real_t learnRate; 59 | real_t momentum; 60 | WeightContainer *wc; 61 | 62 | // functions 63 | SteepestDescent( 64 | const string& name, ostream& o, vector& weights, 65 | vector& derivatives, WeightContainer *weight, DataExportHandler *deh, real_t lr = 1e-4, real_t mom = 0.9): 66 | DataExporter(name, deh), Optimiser(weights, derivatives), out(o), 67 | learnRate(lr), momentum(mom), wc(weight) { 68 | build(); 69 | } 70 | 71 | void update_weights() { 72 | assert(wts.size() == derivs.size()); 73 | assert(wts.size() == deltas.size()); 74 | LOOP(int i, indices(wts)) { 75 | real_t delta = (momentum * deltas[i]) - (learnRate * derivs[i]); 76 | deltas[i] = delta; 77 | wts[i] += delta; 78 | } 79 | if (verbose) { 80 | out << this->name << " weight updates:" << endl; 81 | PRINT(minmax(wts), out); 82 | PRINT(minmax(derivs), out); 83 | PRINT(minmax(deltas), out); 84 | } 85 | } 86 | 87 | // NOTE must be called after any change to weightContainer 88 | void build() { 89 | if (deltas.size() != wts.size()) { 90 | deltas.resize(wts.size()); 91 | fill(deltas, 0); 92 | wc->save_by_conns(deltas, name + "_deltas"); 93 | } 94 | } 95 | 96 | void print(ostream& out = cout) const { 97 | out << "steepest descent" << endl; 98 | PRINT(learnRate, 
out); 99 | PRINT(momentum, out); 100 | } 101 | }; 102 | 103 | #endif 104 | -------------------------------------------------------------------------------- /rnnlib4seshat/String.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. If not, see .*/ 17 | 18 | #ifndef _INCLUDED_String_h 19 | #define _INCLUDED_String_h 20 | 21 | #include "Helpers.hpp" 22 | #include "Container.hpp" 23 | 24 | static string ordinal(size_t n) { 25 | string s = str(n); 26 | if (n < 100) { 27 | char c = nth_last(s); 28 | if(c == '1') { 29 | return s + "st"; 30 | } else if(c == '2') { 31 | return s + "nd"; 32 | } else if(c == '3') { 33 | return s + "rd"; 34 | } 35 | } 36 | return s + "th"; 37 | } 38 | static void trim(string& str) { 39 | size_t startpos = str.find_first_not_of(" \t\n"); 40 | size_t endpos = str.find_last_not_of(" \t\n"); 41 | if(string::npos == startpos || string::npos == endpos) { 42 | str = ""; 43 | } else { 44 | str = str.substr(startpos, endpos-startpos + 1); 45 | } 46 | } 47 | static const string lower(const string& s) { 48 | static string l; 49 | l = s; 50 | algorithm::to_lower(l); 51 | return l; 52 | } 53 | static bool in(const string& str, const string& search) { 54 | return (str.find(search) != string::npos); 55 | } 56 | static bool in(const string& str, const char* 
search) { 57 | return in(str, string(search)); 58 | } 59 | static bool begins(const string& str, const string& search) { 60 | return (str.find(search) == 0); 61 | } 62 | static bool begins(const string& str, const char* search) { 63 | return begins(str, string(search)); 64 | } 65 | static bool ends(const string& str, const string& search) { 66 | return (str.find(search, str.size() - search.size()) != string::npos); 67 | } 68 | static bool ends(const string& str, const char* search) { 69 | return ends(str, string(search)); 70 | } 71 | template static Vector split( 72 | const string& original, char delim = ' ', size_t maxSplits = 0) { 73 | Vector vect; 74 | stringstream ss; 75 | ss << original; 76 | string s; 77 | while (delim == ' ' ? ss >> s : getline(ss, s, delim)) { 78 | vect += read(s); 79 | if (vect.size() == maxSplits-1) { 80 | delim = '\0'; 81 | } 82 | } 83 | return vect; 84 | } 85 | template static Vector split_with_repeat( 86 | const string& original, char delim = ' ', char repeater = '*') { 87 | Vector vect; 88 | LOOP(const string& s1, split(original, delim)) { 89 | vector v = split(s1, repeater); 90 | size_t numRepeats = (v.size() == 1 ? 
1 : natural(v[1])); 91 | T val = read(v[0]); 92 | vect += val, repeat(numRepeats-1, val); 93 | } 94 | return vect; 95 | } 96 | templatestatic string join( 97 | const R& r, const string joinStr = "") { 98 | typename range_iterator::type b = boost::begin(r); 99 | string s = str(*b); 100 | ++b; 101 | for (; b != end(r); ++b) { 102 | s += joinStr + str(*b); 103 | } 104 | return s; 105 | } 106 | 107 | template string left_pad(const T& val, int width, char fill = '0') { 108 | ostringstream ss; 109 | ss << setw(width) << setfill(fill) << val; 110 | return ss.str(); 111 | } 112 | 113 | static string int_to_sortable_string(size_t num, size_t max) { 114 | assert(num < max); 115 | return left_pad(num, str(max-1).size()); 116 | } 117 | 118 | #endif 119 | -------------------------------------------------------------------------------- /rnnlib4seshat/StringAlignment.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2009,2010 Alex Graves 2 | 3 | This file is part of RNNLIB. 4 | 5 | RNNLIB is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | RNNLIB is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with RNNLIB. 
If not, see .*/ 17 | 18 | #ifndef _INCLUDED_StringAlignment_h 19 | #define _INCLUDED_StringAlignment_h 20 | 21 | #include 22 | #include 23 | #include 24 | #include "Helpers.hpp" 25 | 26 | using namespace std; 27 | 28 | template struct StringAlignment 29 | { 30 | //data 31 | map::type, map::type, int> > subsMap; 32 | map::type, int> delsMap; 33 | map::type, int> insMap; 34 | Vector > matrix; 35 | int substitutions; 36 | int deletions; 37 | int insertions; 38 | int distance; 39 | int subPenalty; 40 | int delPenalty; 41 | int insPenalty; 42 | size_t n; 43 | size_t m; 44 | 45 | //functions 46 | StringAlignment (const R1& reference_sequence, const R2& test_sequence, 47 | bool trackErrors = false, bool backtrace = true, int sp = 1, int dp = 1, int ip = 1): 48 | subPenalty(sp), 49 | delPenalty(dp), 50 | insPenalty(ip), 51 | n(reference_sequence.size()), 52 | m(test_sequence.size()) 53 | { 54 | if (n == 0) 55 | { 56 | substitutions = 0; 57 | deletions = 0; 58 | insertions = m; 59 | distance = m; 60 | } 61 | else if (m == 0) 62 | { 63 | substitutions = 0; 64 | deletions = n; 65 | insertions = 0; 66 | distance = n; 67 | } 68 | else 69 | { 70 | //initialise the matrix 71 | matrix.resize(n+1); 72 | LOOP(Vector& v, matrix) 73 | { 74 | v.resize(m+1); 75 | fill(v, 0); 76 | } 77 | LOOP (int i, span(n+1)) 78 | { 79 | matrix[i][0]=i; 80 | } 81 | LOOP (int j, span(m+1)) 82 | { 83 | matrix[0][j]=j; 84 | } 85 | 86 | //calculate the insertions, substitutions and deletions 87 | LOOP (int i, span(1, n + 1)) 88 | { 89 | const typename boost::range_value::type& s_i = reference_sequence[i-1]; 90 | LOOP (int j, span(1, m + 1)) 91 | { 92 | const typename boost::range_value::type& t_j = test_sequence[j-1]; 93 | int cost = ((s_i == t_j) ? 
0 : 1); 94 | const int above = matrix[i-1][j]; 95 | const int left = matrix[i][j-1]; 96 | const int diag = matrix[i-1][j-1]; 97 | const int cell = min(above + 1, // deletion 98 | min(left + 1, // insertion 99 | diag + cost)); // substitution 100 | 101 | matrix[i][j]=cell; 102 | } 103 | } 104 | 105 | //N.B sub,ins and del penalties are all set to 1 if backtrace is ignored 106 | if (backtrace) 107 | { 108 | size_t i = n; 109 | size_t j = m; 110 | substitutions = 0; 111 | deletions = 0; 112 | insertions = 0; 113 | 114 | // Backtracking 115 | while (i != 0 && j != 0) 116 | { 117 | if (matrix[i][j] == matrix[i-1][j-1]) 118 | { 119 | --i; 120 | --j; 121 | } 122 | else if (matrix[i][j] == matrix[i-1][j-1] + 1) 123 | { 124 | if (trackErrors) 125 | { 126 | ++subsMap[reference_sequence[i]][test_sequence[j]]; 127 | } 128 | ++substitutions; 129 | --i; 130 | --j; 131 | } 132 | else if (matrix[i][j] == matrix[i-1][j] + 1) 133 | { 134 | if (trackErrors) 135 | { 136 | ++delsMap[reference_sequence[i]]; 137 | } 138 | ++deletions; 139 | --i; 140 | } 141 | else 142 | { 143 | if (trackErrors) 144 | { 145 | ++insMap[test_sequence[j]]; 146 | } 147 | ++insertions; 148 | --j; 149 | } 150 | } 151 | while (i != 0) 152 | { 153 | if (trackErrors) 154 | { 155 | ++delsMap[reference_sequence[i]]; 156 | } 157 | ++deletions; 158 | --i; 159 | } 160 | while (j != 0) 161 | { 162 | if (trackErrors) 163 | { 164 | ++insMap[test_sequence[j]]; 165 | } 166 | ++insertions; 167 | --j; 168 | } 169 | 170 | // Sanity check: 171 | check((substitutions + deletions + insertions) == matrix[n][m], 172 | "Found path with distance " + str(substitutions + deletions + insertions) + 173 | " but Levenshtein distance is " + str(matrix[n][m])); 174 | 175 | //scale individual errors by penalties 176 | distance = (subPenalty*substitutions) + (delPenalty*deletions) + (insPenalty*insertions); 177 | } 178 | else 179 | { 180 | distance = matrix[n][m]; 181 | } 182 | } 183 | } 184 | ~StringAlignment(){} 185 | }; 186 | 187 | #endif 
188 | -------------------------------------------------------------------------------- /rnnlib4seshat/WeightContainer.cpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #include "WeightContainer.hpp" 41 | 42 | void perturb_weight(real_t& weight, real_t stdDev, bool additive) 43 | { 44 | weight += Random::normal(fabs(additive ? 
stdDev : stdDev * weight)); 45 | } 46 | 47 | 48 | template void perturb_weights(R& weights, real_t stdDev, bool additive = true) 49 | { 50 | LOOP(real_t& w, weights) 51 | { 52 | perturb_weight(w, stdDev, additive); 53 | } 54 | } 55 | template void perturb_weights(R& weights, R& stdDevs, bool additive = true) 56 | { 57 | assert(boost::size(weights) == boost::size(stdDevs)); 58 | LOOP(int i, indices(weights)) 59 | { 60 | perturb_weight(weights[i], stdDevs[i], additive); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /rnnlib4seshat/WeightContainer.hpp: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | 18 | 19 | This file is a modification of the RNNLIB original software covered by 20 | the following copyright and permission notice: 21 | 22 | */ 23 | /*Copyright 2009,2010 Alex Graves 24 | 25 | This file is part of RNNLIB. 26 | 27 | RNNLIB is free software: you can redistribute it and/or modify 28 | it under the terms of the GNU General Public License as published by 29 | the Free Software Foundation, either version 3 of the License, or 30 | (at your option) any later version. 
31 | 32 | RNNLIB is distributed in the hope that it will be useful, 33 | but WITHOUT ANY WARRANTY; without even the implied warranty of 34 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 35 | GNU General Public License for more details. 36 | 37 | You should have received a copy of the GNU General Public License 38 | along with RNNLIB. If not, see .*/ 39 | 40 | #ifndef _INCLUDED_WeightContainer_h 41 | #define _INCLUDED_WeightContainer_h 42 | 43 | #include 44 | #include 45 | #include 46 | #include 47 | #include 48 | #include 49 | #include "Random.hpp" 50 | #include "DataExporter.hpp" 51 | 52 | using namespace std; 53 | 54 | typedef multimap >::iterator WC_CONN_IT; 55 | typedef pair > WC_CONN_PAIR; 56 | 57 | struct WeightContainer: public DataExporter 58 | { 59 | //data 60 | Vector weights; 61 | Vector derivatives; 62 | multimap > connections; 63 | 64 | //functions 65 | WeightContainer(DataExportHandler *deh): 66 | DataExporter("weightContainer", deh) 67 | { 68 | } 69 | 70 | void link_layers(const string& fromName, const string& toName, const string& connName = "", int paramBegin = 0, int paramEnd = 0) { 71 | connections.insert(make_pair(toName, make_tuple(fromName, connName, paramBegin, paramEnd))); 72 | } 73 | 74 | pair new_parameters(size_t numParams, const string& fromName, const string& toName, const string& connName) { 75 | size_t begin = weights.size(); 76 | weights.resize(weights.size() + numParams); 77 | size_t end = weights.size(); 78 | link_layers(fromName, toName, connName, begin, end); 79 | return make_pair(begin, end); 80 | } 81 | 82 | View get_weights(pair range) 83 | { 84 | return weights.slice(range); 85 | } 86 | 87 | View get_derivs(pair range) 88 | { 89 | return derivatives.slice(range); 90 | } 91 | 92 | int randomise(real_t range) 93 | { 94 | int numRandWts = 0; 95 | LOOP(real_t& w, weights) 96 | { 97 | if (w == infinity) 98 | { 99 | w = Random::uniform(range); 100 | ++numRandWts; 101 | } 102 | } 103 | return numRandWts; 104 | } 105 
| 106 | void reset_derivs() 107 | { 108 | fill(derivatives, 0); 109 | } 110 | 111 | void save_by_conns(vector& container, const string& nam) 112 | { 113 | LOOP(const WC_CONN_PAIR& p, connections) 114 | { 115 | VDI begin = container.begin() + p.second.get<2>(); 116 | VDI end = container.begin() + p.second.get<3>(); 117 | if (begin != end) 118 | { 119 | save_range(make_pair(begin, end), p.second.get<1>() + "_" + nam); 120 | } 121 | } 122 | } 123 | 124 | //MUST BE CALLED BEFORE WEIGHT CONTAINER IS USED 125 | void build() 126 | { 127 | fill(weights, infinity); 128 | derivatives.resize(weights.size()); 129 | save_by_conns(weights, "weights"); 130 | reset_derivs(); 131 | } 132 | }; 133 | 134 | void perturb_weight(real_t& weight, real_t stdDev, bool additive = true); 135 | template void perturb_weights(R& weights, real_t stdDev, bool additive = true); 136 | template void perturb_weights(R& weights, R& stdDevs, bool additive = true); 137 | 138 | #endif 139 | -------------------------------------------------------------------------------- /sample.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #ifndef _SAMPLE_ 19 | #define _SAMPLE_ 20 | 21 | class SymRec; 22 | class TableCYK; 23 | 24 | #include 25 | #include 26 | #include 27 | #include "tablecyk.h" 28 | #include "cellcyk.h" 29 | #include "stroke.h" 30 | #include "symrec.h" 31 | #include "grammar.h" 32 | 33 | using namespace std; 34 | 35 | //Segmentation hypothesis 36 | struct SegmentHyp{ 37 | list stks; //List of strokes 38 | 39 | //Bounding box (online coordinates) 40 | int rx, ry; //Top-left 41 | int rs, rt; //Bottom-right 42 | 43 | int cen; 44 | }; 45 | 46 | 47 | class Sample{ 48 | vector dataon; 49 | float **stk_dis; 50 | 51 | int **dataoff; 52 | int X, Y; 53 | int IMGxMIN, IMGyMIN, IMGxMAX, IMGyMAX; 54 | int **pix_stk; 55 | 56 | SymRec *SR; 57 | 58 | //Information to create the output InkML file 59 | char *outinkml, *outdot; 60 | string UItag; 61 | int next_id; 62 | 63 | void loadInkML(char *str); 64 | void loadSCGInk(char *str); 65 | 66 | void linea(int **img, Punto *pa, Punto *pb, int stkid); 67 | void linea_pbm(int **img, Punto *pa, Punto *pb, int stkid); 68 | bool not_visible(int si, int sj, Punto *pi, Punto *pj); 69 | 70 | public: 71 | //Normalized reference symbol size 72 | int RX, RY; 73 | float INF_DIST; //Infinite distance value (visibility) 74 | float NORMF; //Normalization factor for distances 75 | 76 | int ox, oy, os, ot; //Online bounding box 77 | int bx, by, bs, bt; //Offline bounding box 78 | 79 | Sample(char *in); 80 | ~Sample(); 81 | 82 | int dimX(); 83 | int dimY(); 84 | int nStrokes(); 85 | int get(int x, int y); 86 | Stroke *getStroke(int i); 87 | 88 | void getCentroids(CellCYK *cd, int *ce, int *as, int *ds); 89 | void getAVGstroke_size(float *avgw, float *avgh); 90 | 91 | void detRefSymbol(); 92 | void compute_strokes_distances(int rx, int ry); 93 | float stroke_distance(int si, int sj); 94 | float getDist(int si, int sj); 95 | void get_close_strokes(int id, list *L, float dist_th); 96 | 97 | float group_penalty(CellCYK *A, CellCYK *B); 98 | bool visibility(list 
*strokes_list); 99 | 100 | void setSymRec( SymRec *sr ); 101 | 102 | void setRegion(CellCYK *c, int nComp); 103 | void setRegion(CellCYK *c, list *LT); 104 | void setRegion(CellCYK *c, int *v, int size); 105 | 106 | int **render(int *pW, int *pH); 107 | void renderStrokesPBM(list *SL, int ***img, int *rows, int *cols); 108 | 109 | void render_img(char *out); 110 | void set_out_inkml(char *out); 111 | void set_out_dot(char *out); 112 | char *getOutDot(); 113 | 114 | void print(); 115 | void printInkML(Grammar *G, Hypothesis *H); 116 | void printSymRecInkML(Hypothesis *H, FILE *fout); 117 | }; 118 | 119 | #endif 120 | -------------------------------------------------------------------------------- /segmentation.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #include "segmentation.h" 19 | #include 20 | 21 | SegmentationModelGMM::SegmentationModelGMM(char *mod) { 22 | FILE *fd = fopen(mod,"r"); 23 | if( !fd ) { 24 | fprintf(stderr, "Error loading segmentation model '%s'\n", mod); 25 | exit(-1); 26 | } 27 | fclose(fd); 28 | 29 | model = new GMM(mod); 30 | } 31 | 32 | SegmentationModelGMM::~SegmentationModelGMM() { 33 | delete model; 34 | } 35 | 36 | float SegmentationModelGMM::prob(CellCYK *cd, Sample *m) { 37 | int Nstrokes=0, nps=0; 38 | float dist=0, delta=0, sigma=0, mind=0, avgsize=0; 39 | 40 | for(int i=0; inc; i++) 41 | if( cd->ccc[i] ) 42 | Nstrokes++; 43 | 44 | int *strokes_list = new int[Nstrokes]; 45 | Nstrokes = 0; 46 | for(int i=0; inc; i++) 47 | if( cd->ccc[i] ) 48 | strokes_list[Nstrokes++] = i; 49 | 50 | //For every stroke 51 | for(int i=0; igetStroke( strokes_list[i] ); 53 | 54 | float size_i = max(Si->rs - Si->rx, Si->rt - Si->ry); 55 | avgsize += size_i; 56 | 57 | for(int j=i+1; jgetStroke( strokes_list[j] ); 59 | 60 | //distance between stroke Si and Sj 61 | mind += Si->min_dist( Sj ); 62 | 63 | dist += abs( (Si->rs + Si->rx)/2.0 - (Sj->rs + Sj->rx)/2.0 ); 64 | sigma += abs( (Si->rt + Si->ry)/2.0 - (Sj->rt + Sj->ry)/2.0 ); 65 | 66 | float size_j = max( Sj->rt - Sj->ry, Sj->rs - Sj->rx); 67 | delta += abs( size_i - size_j ); 68 | 69 | nps++; 70 | } 71 | 72 | } 73 | 74 | float avgw, avgh, nf; 75 | m->getAVGstroke_size(&avgw, &avgh); 76 | nf = sqrt(avgw*avgw + avgh*avgh); 77 | 78 | mind /= nps*nf; 79 | dist /= nps*nf; 80 | delta /= nps*nf; 81 | sigma /= nps*nf; 82 | 83 | float sample[4]; 84 | float probs[2]; 85 | 86 | sample[0] = mind; 87 | sample[1] = dist; 88 | sample[2] = delta; 89 | sample[3] = sigma; 90 | 91 | model->posterior(sample, probs); 92 | 93 | delete[] strokes_list; 94 | 95 | //Return probability of being a proper segmentation hypothesis 96 | return probs[1]; 97 | } 98 | -------------------------------------------------------------------------------- /segmentation.h: 
-------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #ifndef _SEGMENTATION_MODEL_ 19 | #define _SEGMENTATION_MODEL_ 20 | 21 | #include 22 | #include 23 | #include "cellcyk.h" 24 | #include "sample.h" 25 | #include "gmm.h" 26 | 27 | using namespace std; 28 | 29 | 30 | class SegmentationModelGMM{ 31 | GMM *model; 32 | 33 | public: 34 | SegmentationModelGMM(char *mod); 35 | ~SegmentationModelGMM(); 36 | 37 | float prob(CellCYK *cd, Sample *m); 38 | }; 39 | 40 | #endif 41 | -------------------------------------------------------------------------------- /seshat.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #include 19 | #include 20 | #include 21 | #include "grammar.h" 22 | #include "sample.h" 23 | #include "meparser.h" 24 | 25 | 26 | #define MAXS 4096 27 | 28 | using namespace std; 29 | 30 | void usage(char *str) { 31 | fprintf(stderr, "SESHAT - Handwritten math expression parser\nhttps://github.com/falvaro/seshat\n"); 32 | fprintf(stderr, "Copyright (C) 2014, Francisco Alvaro\n\n"); 33 | fprintf(stderr, "Usage: %s -c config -i input [-o output] [-r render.pgm]\n\n", str); 34 | fprintf(stderr, " -c config: set the configuration file\n"); 35 | fprintf(stderr, " -i input: set the input math expression file\n"); 36 | fprintf(stderr, " -o output: save recognized expression to 'output' file (InkML format)\n"); 37 | fprintf(stderr, " -r render: save in 'render' the image representing the input expression (PGM format)\n"); 38 | fprintf(stderr, " -d graph: save in 'graph' the description of the recognized tree (DOT format)\n"); 39 | } 40 | 41 | int main(int argc, char *argv[]) { 42 | 43 | char input[MAXS], output[MAXS], config[MAXS], render[MAXS], dot[MAXS]; 44 | bool rc=false,ri=false,ro=false,rr=false,rd=false; 45 | input[0] = output[0] = config[0] = 0; 46 | 47 | int option; 48 | while ((option = getopt (argc, argv, "c:i:o:r:d:")) != -1) 49 | switch (option) { 50 | case 'c': strcpy(config, optarg); rc=true; break; 51 | case 'i': strcpy(input, optarg); ri=true; break; 52 | case 'o': strcpy(output, optarg); ro=true; break; 53 | case 'r': strcpy(render, optarg); rr=true; break; 54 | case 'd': strcpy(dot, optarg); rd=true; break; 55 | case '?': usage(argv[0]); return -1; 56 | } 57 | 58 | //Check mandatory args 59 | if( !rc || !ri ) { 60 | usage(argv[0]); 61 | return -1; 62 | } 63 | 64 | //Because some of the feature extraction code uses std::cout/cin 65 | ios_base::sync_with_stdio(true); 66 | 67 | //Load sample and system configuration 68 | Sample m( 
input ); 69 | meParser seshat(config); 70 | 71 | //Render image to file 72 | if( rr ) m.render_img(render); 73 | 74 | //Set output InkML file 75 | if( ro ) m.set_out_inkml( output ); 76 | 77 | //Set output DOT graph file 78 | if( rd ) m.set_out_dot( dot ); 79 | 80 | //Print sample information 81 | m.print(); 82 | printf("\n"); 83 | 84 | //Parse math expression 85 | seshat.parse_me(&m); 86 | 87 | return 0; 88 | } 89 | -------------------------------------------------------------------------------- /sparel.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "sparel.h" 24 | 25 | using namespace std; 26 | 27 | //Aux functions 28 | 29 | Hypothesis *leftmost(Hypothesis *h) { 30 | if( h->pt ) 31 | return h; 32 | 33 | Hypothesis *izq = leftmost(h->hi); 34 | Hypothesis *der = leftmost(h->hd); 35 | 36 | return izq->parent->x < der->parent->x ? izq : der; 37 | } 38 | 39 | Hypothesis *rightmost(Hypothesis *h) { 40 | if( h->pt ) 41 | return h; 42 | 43 | Hypothesis *izq = rightmost(h->hi); 44 | Hypothesis *der = rightmost(h->hd); 45 | 46 | return izq->parent->s > der->parent->s ? 
izq : der; 47 | } 48 | 49 | //Percentage of the area of region A that overlaps with region B 50 | float solape(CellCYK *a, CellCYK *b) { 51 | int x = max(a->x, b->x); 52 | int y = max(a->y, b->y); 53 | int s = min(a->s, b->s); 54 | int t = min(a->t, b->t); 55 | 56 | if( s >= x && t >= y ) { 57 | float aSolap = (s-x+1.0)*(t-y+1.0); 58 | float aTotal = (a->s - a->x+1.0)*(a->t - a->y+1.0); 59 | 60 | return aSolap/aTotal; 61 | } 62 | 63 | return 0.0; 64 | } 65 | 66 | 67 | // 68 | //SpaRel methods 69 | // 70 | 71 | SpaRel::SpaRel(GMM *gmm, Sample *m) { 72 | model = gmm; 73 | mue = m; 74 | } 75 | 76 | void SpaRel::smooth(float *post){ 77 | for(int i=0; iparent->t, b->parent->t) - min(a->parent->y, b->parent->y) + 1; 84 | 85 | sample[0] = (b->parent->t-b->parent->y+1)/F; 86 | sample[1] = (a->rcen - b->lcen)/F; 87 | sample[2] = ((a->parent->s+a->parent->x)/2.0 - (b->parent->s+b->parent->x)/2.0)/F; 88 | sample[3] = (b->parent->x-a->parent->s)/F; 89 | sample[4] = (b->parent->x-a->parent->x)/F; 90 | sample[5] = (b->parent->s-a->parent->s)/F; 91 | sample[6] = (b->parent->y-a->parent->t)/F; 92 | sample[7] = (b->parent->y-a->parent->y)/F; 93 | sample[8] = (b->parent->t-a->parent->t)/F; 94 | } 95 | 96 | double SpaRel::compute_prob(Hypothesis *h1, Hypothesis *h2, int k) { 97 | 98 | //Set probabilities according to spatial constraints 99 | 100 | if( k<=2 ) { 101 | //Check left-to-right order constraint in Hor/Sub/Sup relationships 102 | Hypothesis *rma = rightmost(h1); 103 | Hypothesis *lmb = leftmost(h2); 104 | 105 | if( lmb->parent->x < rma->parent->x || lmb->parent->s <= rma->parent->s ) 106 | return 0.0; 107 | } 108 | 109 | //Compute probabilities 110 | float sample[NFEAT]; 111 | 112 | getFeas(h1,h2,sample,mue->RY); 113 | 114 | //Get spatial relationships probability from the model 115 | model->posterior(sample, probs); 116 | 117 | //Slightly smooth probabilities because GMM classifier can provide 118 | //to biased probabilities. 
Thsi way we give some room to the 119 | //language model (the 2D-SCFG grammar) 120 | smooth(probs); 121 | 122 | return probs[k]; 123 | } 124 | 125 | 126 | SpaRel::~SpaRel() { 127 | } 128 | 129 | 130 | double SpaRel::getHorProb(Hypothesis *ha, Hypothesis *hb) { 131 | return compute_prob(ha,hb,0); 132 | } 133 | double SpaRel::getSubProb(Hypothesis *ha, Hypothesis *hb) { 134 | return compute_prob(ha,hb,1); 135 | } 136 | double SpaRel::getSupProb(Hypothesis *ha, Hypothesis *hb) { 137 | return compute_prob(ha,hb,2); 138 | } 139 | double SpaRel::getVerProb(Hypothesis *ha, Hypothesis *hb, bool strict) { 140 | //Pruning 141 | if( hb->parent->y < (ha->parent->y + ha->parent->t)/2 142 | || abs((ha->parent->x+ha->parent->s)/2 - (hb->parent->x+hb->parent->s)/2) > 2.5*mue->RX 143 | || (hb->parent->x > ha->parent->s || hb->parent->s < ha->parent->x) ) 144 | return 0.0; 145 | 146 | if( !strict ) 147 | return compute_prob(ha,hb,3); 148 | 149 | //Penalty for strict relationships 150 | float penalty = abs(ha->parent->x - hb->parent->x)/(3.0*mue->RX) 151 | + abs(ha->parent->s - hb->parent->s)/(3.0*mue->RX); 152 | 153 | if( penalty > 0.95 ) penalty = 0.95; 154 | 155 | return (1.0 - penalty) * compute_prob(ha,hb,3); 156 | } 157 | 158 | double SpaRel::getInsProb(Hypothesis *ha, Hypothesis *hb) { 159 | if( solape(hb->parent,ha->parent) < 0.5 || 160 | hb->parent->x < ha->parent->x || hb->parent->y < ha->parent->y ) 161 | return 0.0; 162 | 163 | return compute_prob(ha,hb,4); 164 | } 165 | double SpaRel::getMrtProb(Hypothesis *ha, Hypothesis *hb) { 166 | return compute_prob(ha,hb,5); 167 | } 168 | -------------------------------------------------------------------------------- /sparel.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #ifndef _SPAREL_ 19 | #define _SPAREL_ 20 | 21 | class CellCYK; 22 | 23 | #include 24 | #include "hypothesis.h" 25 | #include "cellcyk.h" 26 | #include "gmm.h" 27 | #include "sample.h" 28 | 29 | class SpaRel{ 30 | public: 31 | static const int NRELS = 6; 32 | static const int NFEAT = 9; 33 | 34 | private: 35 | GMM *model; 36 | Sample *mue; 37 | float probs[NRELS]; 38 | 39 | double compute_prob(Hypothesis *h1, Hypothesis *h2, int k); 40 | void smooth(float *post); 41 | 42 | public: 43 | SpaRel(GMM *gmm, Sample *m); 44 | ~SpaRel(); 45 | 46 | void getFeas(Hypothesis *a, Hypothesis *b, float *sample, int ry); 47 | 48 | double getHorProb(Hypothesis *ha, Hypothesis *hb); 49 | double getSubProb(Hypothesis *ha, Hypothesis *hb); 50 | double getSupProb(Hypothesis *ha, Hypothesis *hb); 51 | double getVerProb(Hypothesis *ha, Hypothesis *hb, bool strict=false); 52 | double getInsProb(Hypothesis *ha, Hypothesis *hb); 53 | double getMrtProb(Hypothesis *ha, Hypothesis *hb); 54 | }; 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /stroke.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #include "stroke.h" 19 | #include 20 | #include 21 | 22 | bool esNum(char c){ 23 | return (c >= '0' && c <= '9') || c=='-' || c=='.'; 24 | } 25 | 26 | Stroke::Stroke(int np) { 27 | NP = np; 28 | pseq = new Punto[NP]; 29 | 30 | cx = cy = 0; 31 | rx = ry = INT_MAX; 32 | rs = rt = -INT_MAX; 33 | for(int i=0; i rs ) rs = pseq[i].x; 48 | if( pseq[i].y > rt ) rt = pseq[i].y; 49 | } 50 | } 51 | 52 | 53 | Stroke::Stroke(char *str, int inkml_id) { 54 | char aux[512]; 55 | int iaux; 56 | 57 | id = inkml_id; 58 | 59 | vector data; 60 | 61 | //Remove broken lines 62 | for(int i=0; str[i]; i++) 63 | if( str[i] == '\n' ) { 64 | for(int j=i; str[j]; j++) 65 | str[j] = str[j+1]; 66 | } 67 | 68 | for(int i=0; str[i]; i++) { 69 | 70 | while( str[i] && !esNum(str[i]) ) i++; 71 | 72 | if( !str[i] ) break; 73 | 74 | float px=0, py=0; 75 | 76 | for(iaux=0; str[i] && esNum(str[i]); iaux++, i++) 77 | aux[iaux] = str[i]; 78 | aux[iaux] = 0; 79 | 80 | if( !str[i] ) break; 81 | 82 | px=atof(aux); 83 | 84 | while( str[i] && !esNum(str[i]) ) i++; 85 | 86 | if( !str[i] ) break; 87 | 88 | for(iaux=0; str[i] && esNum(str[i]); iaux++, i++) 89 | aux[iaux] = str[i]; 90 | aux[iaux] = 0; 91 | 92 | py=atof(aux); 93 | 94 | while( str[i] && str[i] != ',' ) i++; 95 | i--; 96 | 97 | data.push_back(new Punto(px,py)); 98 | } 99 | 100 | NP = (int)data.size(); 101 | pseq = new Punto[NP]; 
102 | for(int i=0; ix; 114 | pseq[idx].y = p->y; 115 | 116 | if( pseq[idx].x < rx ) rx = pseq[idx].x; 117 | if( pseq[idx].y < ry ) ry = pseq[idx].y; 118 | if( pseq[idx].x > rs ) rs = pseq[idx].x; 119 | if( pseq[idx].y > rt ) rt = pseq[idx].y; 120 | } 121 | 122 | Punto *Stroke::get(int idx) { 123 | return &pseq[idx]; 124 | } 125 | 126 | int Stroke::getNpuntos() { 127 | return NP; 128 | } 129 | 130 | int Stroke::getId() { 131 | return id; 132 | } 133 | 134 | void Stroke::print() { 135 | printf("STROKE - %d points\n", NP); 136 | for(int i=0; igetNpuntos(); j++) { 145 | Punto *p = st->get(j); 146 | 147 | float d = (pseq[i].x - p->x)*(pseq[i].x - p->x) 148 | + (pseq[i].y - p->y)*(pseq[i].y - p->y); 149 | 150 | if( d < mind ) mind=d; 151 | } 152 | } 153 | 154 | return sqrt( mind ); 155 | } 156 | -------------------------------------------------------------------------------- /stroke.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
/* (closes the stroke.h license header) */
#ifndef _STROKE_
#define _STROKE_

//NOTE(review): four system #include directives stood here but their targets
//were lost in this copy. <cstdio> is required by the FILE* constructors
//below; the other three are assumed - confirm upstream.
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <vector>

using namespace std;

//A 2D point of a stroke
struct Punto{
  float x,y;

  Punto(float vx, float vy) {
    x = vx;
    y = vy;
  }

  //Starts at the origin so that a default-constructed point (for instance
  //every element of a 'new Punto[n]' array) never holds indeterminate values
  Punto() {
    x = 0;
    y = 0;
  }
};


//Sequence of points together with the bounding box that contains them
class Stroke{
  Punto *pseq; //Array of NP points
  int NP;      //Number of points
  int id;      //InkML information

 public:
  //Coordinates of the region it defines
  int rx, ry, rs, rt;
  int cx, cy; //Centroid

  Stroke(int np);
  Stroke(int np, FILE *fd);
  Stroke(FILE *fd);
  Stroke(char *str, int inkml_id);
  ~Stroke();

  //Stores a copy of *p at position idx and grows the bounding box to cover it
  void set(int idx, Punto *p);
  //Pointer to the point stored at position idx (no bounds check)
  Punto *get(int idx);
  int getNpuntos();
  int getId();
  void print();

  //Minimum distance between a point of this stroke and a point of 'st'
  float min_dist(Stroke *st);
};

#endif
// --------------------------------------------------------------------------------
// /symfeatures.h:
// --------------------------------------------------------------------------------
/*Copyright 2014 Francisco Alvaro

This file is part of SESHAT.

    SESHAT is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    SESHAT is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with SESHAT.  If not, see <http://www.gnu.org/licenses/>.
*/
17 | */ 18 | #ifndef _SYMFEATURES_ 19 | #define _SYMFEATURES_ 20 | 21 | class DataSequence; 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include "online.h" 31 | #include "featureson.h" 32 | #include "sample.h" 33 | 34 | class SymFeatures{ 35 | static const int ON_FEAT = 7; 36 | static const int OFF_FEAT = 9; 37 | double means_on[ON_FEAT], means_off[OFF_FEAT]; 38 | double stds_on[ON_FEAT], stds_off[OFF_FEAT]; 39 | 40 | public: 41 | SymFeatures(char *mav_on, char *mav_off); 42 | ~SymFeatures(); 43 | 44 | DataSequence *getOnline(Sample *M, SegmentHyp *SegHyp); 45 | DataSequence *getOfflineFKI(int **img, int H, int W); 46 | }; 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /symrec.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #ifndef _SYMREC_ 19 | #define _SYMREC_ 20 | 21 | struct SegmentHyp; 22 | class Sample; 23 | class SymFeatures; 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include "rnnlib4seshat/DataSequence.hpp" 31 | #include "rnnlib4seshat/NetcdfDataset.hpp" 32 | #include "rnnlib4seshat/Mdrnn.hpp" 33 | #include "rnnlib4seshat/MultilayerNet.hpp" 34 | #include "rnnlib4seshat/Rprop.hpp" 35 | #include "rnnlib4seshat/SteepestDescent.hpp" 36 | #include "rnnlib4seshat/Trainer.hpp" 37 | #include "rnnlib4seshat/WeightContainer.hpp" 38 | #include "sample.h" 39 | #include "symfeatures.h" 40 | 41 | using namespace std; 42 | 43 | 44 | class SymRec{ 45 | SymFeatures *FEAS; 46 | DataHeader header_on, header_off; 47 | DataExportHandler deh_on, deh_off; 48 | WeightContainer *wc_on, *wc_off; 49 | Mdrnn *blstm_on, *blstm_off; 50 | float RNNalpha; 51 | 52 | //Symbol classes and types information 53 | int *type; 54 | map cl2key; 55 | string *key2cl; 56 | 57 | int C; //Number of classes 58 | 59 | int classify(Sample *M, SegmentHyp *SegHyp, const int NB, int *vclase, float *vpr, int *as, int *ds); 60 | void BLSTMclassification( Mdrnn *net, DataSequence *seq, pair *claspr, const int NB ); 61 | 62 | public: 63 | SymRec(char *path); 64 | ~SymRec(); 65 | 66 | char *strClase(int c); 67 | int keyClase(char *str); 68 | bool checkClase(char *str); 69 | int getNClases(); 70 | int symType(int k); 71 | 72 | int clasificar(Sample *M, int ncomp, const int NB, int *vclase, float *vpr, int *as, int *ds); 73 | int clasificar(Sample *M, list *LT, const int NB, int *vclase, float *vpr, int *as, int *ds); 74 | }; 75 | 76 | 77 | #endif 78 | -------------------------------------------------------------------------------- /tablecyk.cc: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 
4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 17 | */ 18 | #include "tablecyk.h" 19 | #include 20 | #include 21 | 22 | bool operator<(const coo &A, const coo &B) { 23 | if( A.x < B.x ) return true; 24 | if( A.x == B.x ) { 25 | if( A.y < B.y ) return true; 26 | if( A.y == B.y ) { 27 | if( A.s < B.s ) return true; 28 | if( A.s == B.s ) 29 | if( A.t < B.t ) return true; 30 | } 31 | } 32 | return false; 33 | } 34 | 35 | TableCYK::TableCYK(int n, int k) { 36 | N = n; 37 | K = k; 38 | 39 | //Target = NULL; 40 | Target = new Hypothesis(-1, -FLT_MAX, NULL, -1); 41 | pm_comps = 0.0; 42 | 43 | T = new CellCYK *[N]; 44 | for(int i=0; i[N]; 48 | } 49 | 50 | TableCYK::~TableCYK() { 51 | for(int i=0; isig; 54 | delete T[i]; 55 | T[i] = aux; 56 | } 57 | 58 | delete[] T; 59 | delete[] TS; 60 | } 61 | 62 | Hypothesis *TableCYK::getMLH() { 63 | return Target; 64 | } 65 | 66 | CellCYK *TableCYK::get(int n) { 67 | return T[n-1]; 68 | } 69 | 70 | int TableCYK::size(int n) { 71 | return TS[n-1].size(); 72 | } 73 | 74 | void TableCYK::updateTarget(coo *K, Hypothesis *H) { 75 | int pcomps = 0; 76 | 77 | for(int i=0; i < H->parent->nc; i++) 78 | if( H->parent->ccc[i] ) 79 | pcomps++; 80 | 81 | if( pcomps > pm_comps || (pcomps==pm_comps && H->pr > Target->pr) ) { 82 | Target->copy( H ); 83 | pm_comps = pcomps; 84 | } 85 | } 86 | 87 | void TableCYK::add(int n, CellCYK *celda, int noterm_id, bool *esinit) { 88 | coo key(celda->x, 
celda->y, celda->s, celda->t); 89 | map::iterator it=TS[n-1].find( key ); 90 | 91 | celda->talla = n; 92 | 93 | if( it == TS[n-1].end() ) { 94 | //Link as head of size n 95 | celda->sig = T[n-1]; 96 | T[n-1] = celda; 97 | TS[n-1][key] = celda; 98 | 99 | if( noterm_id >= 0 ) { 100 | if( esinit[noterm_id] ) 101 | updateTarget(&key, celda->noterm[noterm_id] ); 102 | } 103 | else { 104 | 105 | for(int nt=0; ntnnt; nt++) 106 | if( celda->noterm[nt] && esinit[nt] ) 107 | updateTarget(&key, celda->noterm[nt] ); 108 | 109 | } 110 | } 111 | else { //Maximize probability avoiding duplicates 112 | 113 | int VA, VB; 114 | if( noterm_id < 0 ) { 115 | VA = 0; 116 | VB = celda->nnt; 117 | } 118 | else { 119 | VA = noterm_id; 120 | VB = VA+1; 121 | } 122 | 123 | CellCYK *r = it->second; 124 | 125 | if( !celda->ccEqual( r ) ) { 126 | //The cells cover the same region with a different set of strokes 127 | 128 | float maxpr_c=-FLT_MAX; 129 | for(int i=VA; inoterm[i] && celda->noterm[i]->pr > maxpr_c ) 131 | maxpr_c = celda->noterm[i]->pr; 132 | 133 | float maxpr_r=-FLT_MAX; 134 | for(int i=0; innt; i++) 135 | if( r->noterm[i] && r->noterm[i]->pr > maxpr_r ) 136 | maxpr_r = r->noterm[i]->pr; 137 | 138 | //If the new cell contains the most likely hypothesis, replace the hypotheses 139 | if( maxpr_c > maxpr_r ) { 140 | 141 | //Copy the new set of strokes 142 | for(int i=0; inc; i++) 143 | r->ccc[i] = celda->ccc[i]; 144 | 145 | //Replace the hypotheses for each non-terminal 146 | for(int i=0; innt; i++) 147 | if( celda->noterm[i] ) { 148 | 149 | if( r->noterm[i] ) { 150 | r->noterm[i]->copy( celda->noterm[i] ); 151 | r->noterm[i]->parent = r; 152 | 153 | if( esinit[i] ) 154 | updateTarget(&key, r->noterm[i] ); 155 | } 156 | else{ 157 | r->noterm[i] = celda->noterm[i]; 158 | r->noterm[i]->parent = r; 159 | 160 | //Set to NULL such that the "delete celda" doesn't delete the hypothesis 161 | celda->noterm[i] = NULL; 162 | 163 | if( esinit[i] ) 164 | updateTarget(&key, r->noterm[i] ); 165 | 
} 166 | 167 | } 168 | else if( r->noterm[i] ) { 169 | delete r->noterm[i]; 170 | r->noterm[i] = NULL; 171 | } 172 | 173 | } 174 | 175 | delete celda; 176 | 177 | //Finished 178 | return; 179 | } 180 | 181 | 182 | for(int i=VA; inoterm[i] ) { 185 | if( r->noterm[i] ) { 186 | 187 | if( celda->noterm[i]->pr > r->noterm[i]->pr ) { 188 | //Maximize probability (replace) 189 | r->noterm[i]->copy( celda->noterm[i] ); 190 | r->noterm[i]->parent = r; 191 | 192 | if( esinit[i] ) 193 | updateTarget(&key, r->noterm[i] ); 194 | } 195 | 196 | } 197 | else { 198 | r->noterm[i] = celda->noterm[i]; 199 | r->noterm[i]->parent = r; 200 | 201 | //Set to NULL such that the "delete celda" doesn't delete the hypothesis 202 | celda->noterm[i] = NULL; 203 | 204 | if( esinit[i] ) 205 | updateTarget(&key, r->noterm[i] ); 206 | } 207 | } 208 | 209 | } 210 | 211 | delete celda; 212 | } 213 | 214 | } 215 | -------------------------------------------------------------------------------- /tablecyk.h: -------------------------------------------------------------------------------- 1 | /*Copyright 2014 Francisco Alvaro 2 | 3 | This file is part of SESHAT. 4 | 5 | SESHAT is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | SESHAT is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with SESHAT. If not, see . 
17 | */ 18 | #ifndef _TABLECYK_ 19 | #define _TABLECYK_ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include "cellcyk.h" 27 | #include "hypothesis.h" 28 | 29 | using namespace std; 30 | 31 | //Structure to handle coordinates 32 | 33 | struct coo{ 34 | int x,y,s,t; 35 | 36 | coo() { 37 | x = y = s = t = -1; 38 | } 39 | 40 | coo(int a, int b, int c, int d) { 41 | x=a; y=b; s=c; t=d; 42 | } 43 | 44 | bool operator==(coo &R) { 45 | return x == R.x && y == R.y && s == R.s && t == R.t; 46 | } 47 | }; 48 | 49 | bool operator<(const coo &A, const coo &B); 50 | 51 | class TableCYK{ 52 | CellCYK **T; 53 | map *TS; 54 | int N, K; 55 | 56 | //Hypothesis that accounts for the target (input) math expression 57 | Hypothesis *Target; 58 | 59 | //Percentage of strokes covered by the most likely hypothesis (target) 60 | int pm_comps; 61 | 62 | public: 63 | TableCYK(int n, int k); 64 | ~TableCYK(); 65 | 66 | Hypothesis *getMLH(); 67 | CellCYK *get(int n); 68 | int size(int n); 69 | void updateTarget(coo *K, Hypothesis *H); 70 | void add(int n, CellCYK *celda, int noterm_id, bool *esinit); 71 | }; 72 | 73 | 74 | #endif 75 | --------------------------------------------------------------------------------