├── .gitignore ├── LICENSE.md ├── README.md ├── build.gradle ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── job.bash ├── job_cbow.bash ├── job_embedding.bash ├── job_tokenize.bash └── src ├── main ├── html │ ├── editor_plugin2.js │ ├── index.html │ ├── setup.js │ └── style.css ├── java │ └── de │ │ └── hhu │ │ └── mabre │ │ └── languagetool │ │ ├── BinarySentenceDatabaseCreator.java │ │ ├── BinarySentencesDict.java │ │ ├── FileTokenizer.java │ │ ├── NGram.java │ │ ├── NGramDatabaseCreator.java │ │ ├── PythonDict.java │ │ ├── PythonGateway.java │ │ ├── SamplingMode.java │ │ ├── SentenceDatabaseCreator.java │ │ ├── SentencesDict.java │ │ ├── SubjectGrepper.java │ │ ├── SubsetType.java │ │ └── transformationrules │ │ ├── Das_Dass.java │ │ ├── Dass_Das.java │ │ ├── KommaDas_Das.java │ │ ├── KommaDass_Das.java │ │ ├── KommaOhneDass_OhneDas.java │ │ ├── TransformationRule.java │ │ ├── ihm_im.java │ │ └── im_ihm.java ├── python │ ├── EvalResult.py │ ├── LayeredScorer.py │ ├── embedding │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cbow.py │ │ ├── cbow_checker.py │ │ ├── common.py │ │ ├── question-words.txt │ │ ├── word2vec.py │ │ ├── word2vec_kernels.cc │ │ ├── word2vec_ops.cc │ │ ├── word2vec_optimized.py │ │ ├── word2vec_optimized_test.py │ │ ├── word2vec_test.py │ │ └── word_to_char_embedding.py │ ├── eval.py │ ├── languagetool │ │ └── languagetool.py │ ├── nn.py │ ├── nn_word_sequence.py │ ├── nn_words.py │ ├── nn_words_correct_C.py │ ├── parallel.py │ ├── random_forest.py │ └── repl.py └── resources │ └── example-corpus.txt └── test ├── java └── de │ └── hhu │ └── mabre │ └── languagetool │ ├── BinarySentenceDatabaseCreatorTest.java │ ├── FileTokenizerTest.java │ ├── NGramDatabaseCreatorTest.java │ ├── NGramTest.java │ ├── SentenceDatabaseCreatorTest.java │ ├── SentencesDictTest.java │ └── SubjectGrepperTest.java └── python ├── embedding ├── .cache │ └── v │ │ └── cache │ │ └── lastfailed ├── test_common.py └── test_word_to_char_embedding.py └── test_eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /classes 3 | /checkouts 4 | /build 5 | checkpoint 6 | events.out.* 7 | *.iml 8 | *.ipr 9 | *.iws 10 | pom.xml 11 | pom.xml.asc 12 | *.jar 13 | *.class 14 | /.lein-* 15 | /.nrepl-port 16 | .hgignore 17 | .hg/ 18 | *.data* 19 | *.index* 20 | *.meta* 21 | *.txt 22 | .idea 23 | *.so 24 | __pycache__ 25 | .ipynb_checkpoints 26 | .gradle 27 | *bash.o* 28 | *bash.e* 29 | job*_*_*bash 30 | *log 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: AGPL v3](https://img.shields.io/badge/License-AGPL%20v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0) 2 | 3 | © 2017 Markus Brenneis 4 | 5 | # For Users of LanguageTool 6 | 7 | * Make sure you are using [LanguageTool](https://languagetool.org) version 4.0 or later. 8 | * Download the language data from [here](https://fscs.hhu.de/languagetool/word2vec.tar.gz) and extract the archive to a directory. 9 | * In the LanguageTool settings, choose that directory (which contains the subfolders “de”, “en” etc.) in “word2vec data directory”. 10 | * If you are using a LanguageTool server, set `word2vecDir` in your `languagetool.cfg` (by default located at `~/.languagetool.cfg` in Linux). 
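  The configuration file uses simple `key=value` lines; a minimal entry might look like this (the path is only an example, point it to wherever you extracted the archive):

```
word2vecDir=/home/user/languagetool/word2vec
```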
11 | 12 | # TL;DR 13 | 14 | In case everything is already set up: 15 | 16 | ```bash 17 | ./gradlew createNGramDatabase -PlanguageCode="en-US" -PcorpusFile="training-corpus.txt" -Ptokens="to too" 18 | python3 src/main/python/nn_words.py dictionary.txt final_embeddings.txt /tmp/to_too_training.py /tmp/to_too_validate.py . 19 | # copy W_fc1.txt and b_fc1.txt to neuralnetwork/en/to_too 20 | # edit neuralnetwork/en/confusion_sets.txt 21 | # calibrate using NeuralNetworkRuleEvaluator language-code word2vec-dir RULE_ID corpus1.xml 22 | ``` 23 | 24 | # Prerequisites 25 | 26 | ## Software 27 | 28 | This README assumes you are using an Ubuntu-based operating system, but the instructions should also work on other operating systems. 29 | 30 | You need Java 8 (probably already installed if you can compile LanguageTool) and python3 with pip (the Python package manager, `sudo apt install python3-pip`). Install the following packages using pip: 31 | 32 | * TensorFlow: machine learning library for training neural networks 33 | * scikit-learn: machine learning library 34 | * NumPy: scientific computing library 35 | 36 | 37 | ``` 38 | pip3 install --user tensorflow scikit-learn numpy 39 | ``` 40 | 41 | Note that TensorFlow officially supports 64-bit systems only. 42 | 43 | If you have an NVIDIA GPU, you might want to use the GPU version of TensorFlow. See [tensorflow.org](https://www.tensorflow.org/install/) for installation instructions. As the CUDA setup can take some time, I recommend proceeding with the CPU version. 44 | 45 | ## Sources 46 | 47 | Neural network rules have been supported by LanguageTool since 15 December 2017, starting with the development version of LanguageTool 4.0. 48 | 49 | The code for learning new rules is not part of LanguageTool. Get it by running 50 | 51 | ```bash 52 | git clone git@github.com:gulp21/languagetool-neural-network.git 53 | ``` 54 | 55 | # Adding support for new languages and confusion pairs 56 | 57 | ## Getting a corpus 58 | 59 | Whether you want to add a new language model or support for a new confusion pair, you first have to get a big corpus which shouldn’t contain any grammar errors. Possible sources are newspaper articles from the [Leipzig Corpora Collection](http://wortschatz.uni-leipzig.de/en/download/), [Wikipedia](http://wiki.languagetool.org/checking-the-complete-wikipedia) or [Tatoeba](https://tatoeba.org/downloads). I prefer using the Leipzig data for training and Wikipedia data for assessing rule performance. Note that newspaper and Wikipedia articles rarely include 1st and 2nd person verb forms; keep that in mind if you want to detect confusion pairs involving those verb forms. 60 | 61 | If you just want to test your setup and don’t have a corpus yet, you can use `src/main/resources/example-corpus.txt`. Note that a good corpus should contain more than 100,000 sentences. 62 | 63 | The training input files are plain text files containing sentences that must not be spread over multiple lines; whether there are multiple sentences in one line doesn’t matter. If you’re using the Leipzig corpus, you can use the *-sentences.txt file, but you have to remove the line numbers first: 64 | 65 | ```bash 66 | sed -E "s/^[0-9]+\W+//" *-sentences.txt > training-corpus.txt 67 | ``` 68 | 69 | You now have a file `training-corpus.txt` containing lots of sentences. 70 | 71 | ## Adding support for a new language 72 | 73 | ### Tokenizing the corpus 74 | 75 | You have to tokenize the training corpus with LanguageTool.
As LanguageTool itself doesn’t contain a file tokenizer for command-line usage, we use a tool included in the languagetool-neural-network repository mentioned above, so `cd` to `languagetool-neural-network`. As tokenizing 1,000,000 sentences might be too much for your system memory (it can require up to 10 GB of RAM), you may decide to train the language model with fewer sentences, let’s say 300,000. 76 | 77 | ```bash 78 | shuf training-corpus.txt | head -n300000 > language-model-corpus.txt 79 | ./gradlew tokenizeFile -PlanguageCode="en-US" -PsentencesFile="language-model-corpus.txt" 80 | ``` 81 | 82 | Don’t forget to change the `languageCode` parameter. 83 | 84 | After having downloaded the whole internet, you should end up with a file called `language-model-corpus.txt-tokens`. The terminal output should look like this: 85 | 86 | ``` 87 | [...] 88 | :tokenizeFile 89 | Reading language-model-corpus.txt 90 | Tokenizing 91 | Tokens written to language-model-corpus.txt-tokens 92 | 93 | BUILD SUCCESSFUL 94 | 95 | Total time: 52.144 secs 96 | ``` 97 | 98 | ### Creating a language model 99 | 100 | If your language does not yet have a neural network rule, you have to train a language model first, which will be shared by all neural network rules of that language. 101 | 102 | First, you have to compile the word2vec C++ files: 103 | 104 | ```bash 105 | TF_INC=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 106 | cd src/main/python/embedding 107 | g++ -std=c++11 -shared word2vec_ops.cc word2vec_kernels.cc -o word2vec_ops.so -fPIC -I $TF_INC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 108 | cd - 109 | ``` 110 | 111 | Then you can train a new language model: 112 | 113 | ```bash 114 | python3 src/main/python/embedding/word2vec.py --train_data language-model-corpus.txt-tokens --eval_data src/main/python/embedding/question-words.txt --save_path . --epochs_to_train 10 115 | ``` 116 | 117 | If you get an error about `word2vec_ops.so`, try compiling with `-D_GLIBCXX_USE_CXX11_ABI=1` instead. 118 | 119 | The process can take a while (on my notebook I get a rate of ~7,000 words/sec; on my university’s high-performance cluster ~28,000 words/sec and a total runtime of ~30 minutes). You should see the loss value decrease over time. 120 | 121 | When the process has finished, you have the files `dictionary.txt` (~1 MB) and `final_embeddings.txt` (~80 MB). Open the directory containing the existing word2vec models (or create a new directory if you haven’t downloaded the [models of other languages](https://fscs.hhu.de/languagetool/word2vec.tar.gz)), create a sub-directory `LANG` (e. g. `en`) and move the two files created to that directory. 122 | 123 | #### What just happened? 124 | 125 | The language model trained here is a 64-dimensional [word embedding](http://colah.github.io/posts/2014-07-NLP-RNNs-Representations/#word-embeddings). All words (or more precisely: tokens, as returned by the LanguageTool tokenizer in the previous step) which appear at least 5 times in the training corpus are mapped to a vector containing 64 numbers. Similar tokens (e. g. “I”, “you”, “he” or “my”, “your”, “her”) will magically end up “close” to each other. This will later allow the neural network to detect errors even if the exact phrase was not part of the training corpus. 126 | 127 | The `--eval_data` parameter is required, but not important for us. If you are curious, have a look at [Analogical Reasoning](https://www.tensorflow.org/tutorials/word2vec#evaluating_embeddings_analogical_reasoning).
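If you want to convince yourself that the embedding is sensible before wiring it into LanguageTool, you can look up the nearest neighbours of a few tokens. The following is a minimal sketch (not part of this repository) which assumes that `dictionary.txt` contains one token per line and that `final_embeddings.txt` contains the corresponding vectors, one whitespace-separated row per token, in the same order:

```python
# Sketch: inspect the trained embedding. Assumes one token per line in
# dictionary.txt and one whitespace-separated vector per line in
# final_embeddings.txt, both in the same order.
import numpy as np

with open("dictionary.txt", encoding="utf-8") as f:
    tokens = [line.strip() for line in f]

embeddings = np.loadtxt("final_embeddings.txt")
# Normalize rows so that a dot product equals the cosine similarity.
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

def nearest_neighbours(token, k=5):
    scores = embeddings @ embeddings[tokens.index(token)]
    best = np.argsort(-scores)[1:k + 1]  # skip the token itself
    return [(tokens[i], round(float(scores[i]), 3)) for i in best]

print(nearest_neighbours("you"))
```

If tokens like “I”, “he” and “we” show up among the neighbours of “you”, the embedding looks reasonable; if the neighbours look random, the corpus was probably too small.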
128 | 129 | 130 | ### Adding methods to Language.java 131 | 132 | Open the Java class for your language and add the following methods: 133 | 134 | ```java 135 | @Override 136 | public synchronized Word2VecModel getWord2VecModel(File indexDir) throws IOException { 137 | return new Word2VecModel(indexDir + File.separator + getShortCode()); 138 | } 139 | 140 | @Override 141 | public List<Rule> getRelevantWord2VecModelRules(ResourceBundle messages, Word2VecModel word2vecModel) throws IOException { 142 | return NeuralNetworkRuleCreator.createRules(messages, this, word2vecModel); 143 | } 144 | ``` 145 | 146 | 147 | ## Adding support for a new confusion pair 148 | 149 | Before you add a new confusion pair, think about whether the neural network actually has a chance to detect an error properly. The neural network gets a context of 2 tokens before and after a token as input, e. g. for the to/too pair and the sentence “I would like too learn more about neural networks.”, the network will get `[would like learn more]` as input. If you as a human can infer from `[would like learn more]` that “to” must be in the middle, the neural network can probably learn that, too. On the other hand, consider the German an/in pair and the sentence “Ich bin in der Universitätsstraße.” If you see the tokens `[Ich bin der Universitätsstraße]`, you cannot really determine whether “an” or “in” should be used, so the an/in pair is probably not a good candidate for a neural network rule. 150 | 151 | NB: 152 | 153 | * As the neural network gets tokens as inputs, and outputs which of two tokens fits best, it cannot be used for pairs like their/they’re, because “they’re” is split into 3 tokens by LanguageTool. 154 | * The language model is case-sensitive. 155 | 156 | ### Training the neural network 157 | 158 | First, you must generate training and validation sets from the corpus. 159 | 160 | ```bash 161 | ./gradlew createNGramDatabase -PlanguageCode="en-US" -PcorpusFile="training-corpus.txt" -Ptokens="to too" 162 | ``` 163 | 164 | Don’t forget to replace “to too” with your confusion pair tokens. The output “sampling to xxx” tells you how many training or validation samples have been created for each token. 165 | 166 | Now you have two files, `/tmp/to_too_training.py` and `/tmp/to_too_validate.py`, and we can train a neural network: 167 | 168 | ```bash 169 | python3 src/main/python/nn_words.py dictionary.txt final_embeddings.txt /tmp/to_too_training.py /tmp/to_too_validate.py . 170 | ``` 171 | 172 | Again, don’t forget to change the file paths to match your system. 173 | 174 | Depending on whether a CUDA-capable GPU is available and how many training samples there are, this process takes several minutes. Accuracy should get closer to 1 over time. 175 | 176 | At the end of the training process, the neural network is validated with the validation set. The figures after “incorrect” should be near zero, and “unclassified” should not be bigger than “correct”; otherwise the learned network is probably unusable. 177 | 178 | You now have the files `W_fc1.txt` and `b_fc1.txt` in your current working directory. Move them to `word2vec/LANG/neuralnetwork/TOKEN1_TOKEN2`, where `word2vec/LANG` is the directory containing the `dictionary.txt` you created or downloaded earlier. Don’t forget to include a `LICENSE` file if needed. 179 | 180 | #### What just happened? 181 | 182 | Let’s take the to/too pair as an example.
You’ve trained a single-layer [neural network](https://www.quora.com/Can-you-explain-neural-nets-in-laymans-terms) which can tell you how confident it is that “to” or “too” fits into a given context. E. g. given the input `[would like help you]`, it will output the “scores” `[3.95 -4.12]`, which means that it prefers the phrase “would like to help you” over “would like too help you”. On the other hand, scores like `[0.04 -0.10]` mean that the network has no preference. The minimum score required to mark the usage of a token as wrong is determined during the calibration of the rules, which is described in the next section. As of now, you also see those scores as part of the error message when a neural network rule detects an error. 183 | 184 | ### Adding the rule 185 | 186 | Add a new line to `word2vec/LANG/neuralnetwork/confusion_sets.txt` which looks like this: 187 | 188 | ``` 189 | to; too; 0.5 190 | ``` 191 | 192 | If you start LanguageTool now, the rule, which has the id `LANG_to_VS_too_NEURALNETWORK`, should work, provided you have specified the word2vec directory in the settings. The new rule might cause more false alarms than necessary, though. 193 | 194 | Now you have to tweak the sensitivity of the rule, which is currently 0.5. Open `org.languagetool.dev.bigdata.NeuralNetworkRuleEvaluator` in your IDE and run the main method with the arguments `language-code word2vec-directory RULE_ID corpus1.xml corpus2.txt etc.`; a corpus can be a Wikipedia XML file or some plain text file; any number of corpora may be given. Do not use the same corpus you used for training! The output will look like this: 195 | 196 | ``` 197 | Results for both tokens 198 | to; too; 0.50; # p=0.985, r=0.900, tp=1320, tn=1446, fp=20, fn=146, 1000+466, 2017-10-15 199 | to; too; 0.75; # p=0.988, r=0.877, tp=1286, tn=1450, fp=16, fn=180, 1000+466, 2017-10-15 200 | to; too; 1.00; # p=0.992, r=0.844, tp=1237, tn=1456, fp=10, fn=229, 1000+466, 2017-10-15 201 | to; too; 1.25; # p=0.993, r=0.815, tp=1195, tn=1457, fp=9, fn=271, 1000+466, 2017-10-15 202 | to; too; 1.50; # p=0.993, r=0.778, tp=1141, tn=1458, fp=8, fn=325, 1000+466, 2017-10-15 203 | to; too; 1.75; # p=0.993, r=0.727, tp=1066, tn=1459, fp=7, fn=400, 1000+466, 2017-10-15 204 | to; too; 2.00; # p=0.994, r=0.681, tp=999, tn=1460, fp=6, fn=467, 1000+466, 2017-10-15 205 | to; too; 2.25; # p=0.996, r=0.621, tp=911, tn=1462, fp=4, fn=555, 1000+466, 2017-10-15 206 | to; too; 2.50; # p=0.996, r=0.559, tp=820, tn=1463, fp=3, fn=646, 1000+466, 2017-10-15 207 | to; too; 2.75; # p=0.996, r=0.489, tp=717, tn=1463, fp=3, fn=749, 1000+466, 2017-10-15 208 | to; too; 3.00; # p=0.998, r=0.428, tp=628, tn=1465, fp=1, fn=838, 1000+466, 2017-10-15 209 | to; too; 3.25; # p=0.998, r=0.369, tp=541, tn=1465, fp=1, fn=925, 1000+466, 2017-10-15 210 | to; too; 3.50; # p=0.998, r=0.315, tp=462, tn=1465, fp=1, fn=1004, 1000+466, 2017-10-15 211 | to; too; 3.75; # p=0.997, r=0.265, tp=389, tn=1465, fp=1, fn=1077, 1000+466, 2017-10-15 212 | to; too; 4.00; # p=1.000, r=0.225, tp=330, tn=1466, fp=0, fn=1136, 1000+466, 2017-10-15 213 | 214 | Time: 133742 ms 215 | Recommended configuration: 216 | to; too; 1.00 # p=0.992, r=0.844, tp=1237, tn=1456, fp=10, fn=229, 1000+466, 2017-10-15 217 | ``` 218 | 219 | The p value is the precision, which tells you how often a detected error was an actual error (i. e. 1−p is the probability of a false alarm). The r value is the recall, which tells you how often the rule could find an error.
As a rule of thumb, the precision should be greater than 0.99, or 0.995 for common words/tokens. Recall should be greater than 0.5, otherwise the rule won’t detect many errors. If you have chosen a good certainty level (which is the same as the score I mentioned earlier), you can update `neuralnetwork/confusion_sets.txt`: 220 | 221 | ``` 222 | to; too; 1.00 # p=0.992, r=0.844, tp=1237, tn=1456, fp=10, fn=229, 1000+466, 2017-10-15 223 | ``` 224 | 225 | You can also pass `ALL` as rule id to evaluate the performance for all confusion sets in `confusion_sets.txt`. 226 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'java' 2 | apply plugin: 'idea' 3 | 4 | task tokenizeFile(type:JavaExec) { 5 | main = "de.hhu.mabre.languagetool.FileTokenizer" 6 | classpath = sourceSets.main.runtimeClasspath 7 | args "${-> languageCode}", "${-> sentencesFile}" 8 | maxHeapSize = "19000m" 9 | } 10 | 11 | task createNGramDatabase(type:JavaExec) { 12 | main = "de.hhu.mabre.languagetool.NGramDatabaseCreator" 13 | classpath = sourceSets.main.runtimeClasspath 14 | if(project.hasProperty("corpusFile")) { 15 | if(project.hasProperty("token3")) { 16 | args "${-> languageCode}", "${-> corpusFile}", "${-> token1}", "${-> token2}", "${-> token3}" 17 | } else if(project.hasProperty("token2")) { 18 | args "${-> languageCode}", "${-> corpusFile}", "${-> token1}", "${-> token2}" 19 | } else { 20 | args "${-> languageCode}", "${-> corpusFile}" 21 | args += tokens.split().toList() 22 | } 23 | } 24 | maxHeapSize = "10000m" 25 | } 26 | 27 | task pythonGateway(type:JavaExec) { 28 | main = "de.hhu.mabre.languagetool.PythonGateway" 29 | classpath = sourceSets.main.runtimeClasspath 30 | maxHeapSize = "10000m" 31 | } 32 | 33 | repositories { 34 | jcenter() 35 | mavenCentral() 36 | } 37 | 38 | dependencies { 39 | compile 'org.languagetool:languagetool-core:3.9' 40 | compile 'org.languagetool:language-all:3.9' 41 | compile 'com.google.code.gson:gson:2.8.2' 42 | compile 'net.sf.py4j:py4j:0.10.6' 43 | testCompile 'junit:junit:3.8.2' 44 | testCompile 'junit:junit:4.12' 45 | } 46 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulp21/languagetool-neural-network/02aec088afe9ecd383904cdd3bd7c6a5ada0c9f0/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Sun Oct 15 13:35:49 CEST 2017 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-2.10-bin.zip 7 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 10 | DEFAULT_JVM_OPTS="" 11 | 12 | APP_NAME="Gradle" 13 | APP_BASE_NAME=`basename "$0"` 14 | 15 | # Use the maximum available, or set MAX_FD != -1 to use that value. 16 | MAX_FD="maximum" 17 | 18 | warn ( ) { 19 | echo "$*" 20 | } 21 | 22 | die ( ) { 23 | echo 24 | echo "$*" 25 | echo 26 | exit 1 27 | } 28 | 29 | # OS specific support (must be 'true' or 'false'). 30 | cygwin=false 31 | msys=false 32 | darwin=false 33 | case "`uname`" in 34 | CYGWIN* ) 35 | cygwin=true 36 | ;; 37 | Darwin* ) 38 | darwin=true 39 | ;; 40 | MINGW* ) 41 | msys=true 42 | ;; 43 | esac 44 | 45 | # Attempt to set APP_HOME 46 | # Resolve links: $0 may be a link 47 | PRG="$0" 48 | # Need this for relative symlinks. 49 | while [ -h "$PRG" ] ; do 50 | ls=`ls -ld "$PRG"` 51 | link=`expr "$ls" : '.*-> \(.*\)$'` 52 | if expr "$link" : '/.*' > /dev/null; then 53 | PRG="$link" 54 | else 55 | PRG=`dirname "$PRG"`"/$link" 56 | fi 57 | done 58 | SAVED="`pwd`" 59 | cd "`dirname \"$PRG\"`/" >/dev/null 60 | APP_HOME="`pwd -P`" 61 | cd "$SAVED" >/dev/null 62 | 63 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 64 | 65 | # Determine the Java command to use to start the JVM. 66 | if [ -n "$JAVA_HOME" ] ; then 67 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 68 | # IBM's JDK on AIX uses strange locations for the executables 69 | JAVACMD="$JAVA_HOME/jre/sh/java" 70 | else 71 | JAVACMD="$JAVA_HOME/bin/java" 72 | fi 73 | if [ ! -x "$JAVACMD" ] ; then 74 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 75 | 76 | Please set the JAVA_HOME variable in your environment to match the 77 | location of your Java installation." 78 | fi 79 | else 80 | JAVACMD="java" 81 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 82 | 83 | Please set the JAVA_HOME variable in your environment to match the 84 | location of your Java installation." 85 | fi 86 | 87 | # Increase the maximum file descriptors if we can. 88 | if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then 89 | MAX_FD_LIMIT=`ulimit -H -n` 90 | if [ $? -eq 0 ] ; then 91 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 92 | MAX_FD="$MAX_FD_LIMIT" 93 | fi 94 | ulimit -n $MAX_FD 95 | if [ $? 
-ne 0 ] ; then 96 | warn "Could not set maximum file descriptor limit: $MAX_FD" 97 | fi 98 | else 99 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 100 | fi 101 | fi 102 | 103 | # For Darwin, add options to specify how the application appears in the dock 104 | if $darwin; then 105 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 106 | fi 107 | 108 | # For Cygwin, switch paths to Windows format before running java 109 | if $cygwin ; then 110 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 111 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 112 | JAVACMD=`cygpath --unix "$JAVACMD"` 113 | 114 | # We build the pattern for arguments to be converted via cygpath 115 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 116 | SEP="" 117 | for dir in $ROOTDIRSRAW ; do 118 | ROOTDIRS="$ROOTDIRS$SEP$dir" 119 | SEP="|" 120 | done 121 | OURCYGPATTERN="(^($ROOTDIRS))" 122 | # Add a user-defined pattern to the cygpath arguments 123 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 124 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 125 | fi 126 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 127 | i=0 128 | for arg in "$@" ; do 129 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 130 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 131 | 132 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 133 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 134 | else 135 | eval `echo args$i`="\"$arg\"" 136 | fi 137 | i=$((i+1)) 138 | done 139 | case $i in 140 | (0) set -- ;; 141 | (1) set -- "$args0" ;; 142 | (2) set -- "$args0" "$args1" ;; 143 | (3) set -- "$args0" "$args1" "$args2" ;; 144 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 145 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 146 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 147 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 148 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 149 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 150 | esac 151 | fi 152 | 153 | # Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules 154 | function splitJvmOpts() { 155 | JVM_OPTS=("$@") 156 | } 157 | eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS 158 | JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" 159 | 160 | exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" 161 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 12 | set DEFAULT_JVM_OPTS= 13 | 14 | set DIRNAME=%~dp0 15 | if "%DIRNAME%" == "" set DIRNAME=. 
16 | set APP_BASE_NAME=%~n0 17 | set APP_HOME=%DIRNAME% 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windowz variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | if "%@eval[2+2]" == "4" goto 4NT_args 53 | 54 | :win9xME_args 55 | @rem Slurp the command line arguments. 56 | set CMD_LINE_ARGS= 57 | set _SKIP=2 58 | 59 | :win9xME_args_slurp 60 | if "x%~1" == "x" goto execute 61 | 62 | set CMD_LINE_ARGS=%* 63 | goto execute 64 | 65 | :4NT_args 66 | @rem Get arguments from the 4NT Shell from JP Software 67 | set CMD_LINE_ARGS=%$ 68 | 69 | :execute 70 | @rem Setup the command line 71 | 72 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 73 | 74 | @rem Execute Gradle 75 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 76 | 77 | :end 78 | @rem End local scope for the variables with windows NT shell 79 | if "%ERRORLEVEL%"=="0" goto mainEnd 80 | 81 | :fail 82 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 83 | rem the _cmd.exe /c_ return code! 
84 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 85 | exit /b 1 86 | 87 | :mainEnd 88 | if "%OS%"=="Windows_NT" endlocal 89 | 90 | :omega 91 | -------------------------------------------------------------------------------- /job.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #PBS -l select=1:ncpus=1:mem=6gb 4 | #PBS -l walltime=06:00:00 5 | #PBS -A "stupsprojmabre" 6 | 7 | lang=eng 8 | subject1=to 9 | subject2=too 10 | 11 | module load TensorFlow/1.1.0 12 | module load Python/3.4.5 13 | module load CUDA/7.5.18 14 | 15 | cd ~/projektarbeit/lt/grammarchecker 16 | 17 | training_file="/tmp/${subject1}_${subject2}_training.py" 18 | validate_file="/tmp/${subject1}_${subject2}_validate.py" 19 | 20 | output_path="res_training/$lang/${subject1}_${subject2}" 21 | mkdir $output_path 22 | 23 | export LOGFILE=$output_path/$PBS_JOBNAME"."$PBS_JOBID".log" 24 | 25 | echo `date` create_training_files >> $LOGFILE 26 | ./gradlew createNGramDatabase -PlanguageCode="en-US" -PcorpusFile="res_training/$lang/${lang}_news_2015_3M-sentences-raw.txt" -Ptokens="$subject1 $subject2" >> $LOGFILE 27 | 28 | echo `date` create_classifier >> $LOGFILE 29 | python3 src/main/python/nn_words.py res_training/$lang/embedding/dictionary.txt res_training/$lang/embedding/final_embeddings.txt $training_file $validate_file $output_path >> $LOGFILE 30 | 31 | echo `date` finished, $output_path >> $LOGFILE 32 | -------------------------------------------------------------------------------- /job_cbow.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #PBS -l select=1:ncpus=8:mem=50gb 4 | #PBS -l walltime=24:00:00 5 | #PBS -A "stupsprojmabre" 6 | 7 | lang=deu 8 | langcode="de-DE" 9 | corpus=res_training/$lang/news_tatoeba_training.txt-tokens 10 | outdir=res_training/$lang/cbow 11 | 12 | mkdir $outdir 13 | 14 | export LOGFILE=$PBS_O_WORKDIR/$PBS_JOBNAME"."$PBS_JOBID".cbow.log" 15 | 16 | module load TensorFlow/1.1.0 17 | module load Python/3.4.5 18 | module load CUDA/7.5.18 19 | 20 | cd ~/projektarbeit/languagetool-neural-network 21 | echo `date` >> $LOGFILE 22 | TF_INC=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 23 | ./gradlew pythonGateway 2>&1 >> $LOGFILE & 24 | sleep 10 25 | cd src/main/python 26 | PYTHONPATH=$PYTHONPATH:. 
27 | export PYTHONPATH 28 | echo $PYTHONPATH 29 | python3 embedding/cbow.py ../../../$corpus $langcode 50000 20000 ../../../$outdir 2>&1 >> $LOGFILE 30 | killall java 2>&1 >> $LOGFILE 31 | -------------------------------------------------------------------------------- /job_embedding.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #PBS -l select=1:ncpus=1:mem=6gb 4 | #PBS -l walltime=06:00:00 5 | #PBS -A "stupsprojmabre" 6 | 7 | lang=eng 8 | tokensFile=res_training/$lang/${lang}_news_2015_3M-sentences-raw.txt_small_tokens 9 | 10 | export LOGFILE=$PBS_O_WORKDIR/$PBS_JOBNAME"."$PBS_JOBID".embedding.log" 11 | 12 | module load TensorFlow/1.1.0 13 | module load Python/3.4.5 14 | module load CUDA/7.5.18 15 | 16 | cd ~/projektarbeit/lt/grammarchecker 17 | echo `date` >> $LOGFILE 18 | TF_INC=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 19 | cd src/main/python/embedding 20 | g++ -std=c++11 -shared word2vec_ops.cc word2vec_kernels.cc -o word2vec_ops.so -fPIC -I $TF_INC -O2 -D_GLIBCXX_USE_CXX11_ABI=1 21 | cd ~/projektarbeit/lt/grammarchecker 22 | python3 src/main/python/embedding/word2vec.py --train_data $tokensFile --eval_data res_training/$lang/question-words.txt --save_path res_training/$lang/embedding --epochs_to_train 10 >> $LOGFILE 23 | -------------------------------------------------------------------------------- /job_tokenize.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #PBS -l select=1:ncpus=8:mem=50gb 4 | #PBS -l walltime=24:00:00 5 | #PBS -A "stupsprojmabre" 6 | 7 | lang=deu 8 | langcode="de-DE" 9 | corpus=res_training/$lang/news_tatoeba_training.txt 10 | 11 | module load Java/1.8.0_151 12 | 13 | cd ~/projektarbeit/languagetool-neural-network 14 | ./gradlew -debug tokenizeFile -PlanguageCode="$langcode" -PsentencesFile="$corpus" 15 | -------------------------------------------------------------------------------- /src/main/html/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Wordfiller 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
[lines 16-69: the HTML markup of index.html was lost in extraction; the recoverable page text is kept below.]
Using Machine Learning: only neural network rules enabled (see below); supported confusion sets for English, German, Portuguese
Using original LanguageTool: all rules enabled, including n-gram rules
The neural network has been trained with data from Projekt Deutscher Wortschatz, licensed under CC BY.
Computational support and infrastructure was provided by the Centre for Information and Media Technology (ZIM) at the University of Düsseldorf (Germany).
Kindly hosted by Fachschaft Informatik, University of Düsseldorf
The development of neural network rules for LanguageTool is a part of Markus Brenneis’ project work at the research group Software Engineering and Programming Languages.
LanguageTool is freely available under the LGPL 2.1 or later.
70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /src/main/html/setup.js: -------------------------------------------------------------------------------- 1 | tinyMCE.init({ 2 | selector : "#checktext", 3 | plugins : "AtD,paste", 4 | paste_text_sticky : true, 5 | setup : function(ed) { 6 | ed.onInit.add(function(ed) { 7 | ed.pasteAsPlainText = true; 8 | }); 9 | ed.onKeyUp.add(function(ed, l) { 10 | const editor_wordfiller = tinymce.get('checktext_wordfiller'); 11 | editor_wordfiller.setContent(ed.getContent()); 12 | }); 13 | }, 14 | languagetool_i18n_no_errors : { 15 | "de-DE": "Keine Fehler gefunden." 16 | }, 17 | languagetool_i18n_explain : { 18 | "de-DE": "Mehr Informationen..." 19 | }, 20 | languagetool_i18n_ignore_once : { 21 | "de-DE": "Hier ignorieren" 22 | }, 23 | languagetool_i18n_ignore_all : { 24 | "de-DE": "Fehler dieses Typs ignorieren" 25 | }, 26 | languagetool_i18n_rule_implementation : { 27 | "de-DE": "Implementierung der Regel" 28 | }, 29 | 30 | languagetool_i18n_current_lang : 31 | function() { return document.checkform.lang.value; }, 32 | languagetool_rpc_url : "https://languagetool.org/api/v2/check", 33 | /* edit this file to customize how LanguageTool shows errors: */ 34 | languagetool_css_url : 35 | "https://www.languagetool.org/online-check/" + 36 | "tiny_mce/plugins/atd-tinymce/css/content.css", 37 | theme : "advanced", 38 | theme_advanced_buttons1 : "", 39 | theme_advanced_buttons2 : "", 40 | theme_advanced_buttons3 : "", 41 | theme_advanced_toolbar_location : "none", 42 | theme_advanced_toolbar_align : "left", 43 | theme_advanced_statusbar_location : "bottom", 44 | theme_advanced_path : false, 45 | theme_advanced_resizing : true, 46 | theme_advanced_resizing_use_cookie : false, 47 | gecko_spellcheck : false 48 | }); 49 | 50 | 51 | tinyMCE.init({ 52 | selector : "#checktext_wordfiller", 53 | plugins : "AtD,paste", 54 | paste_text_sticky : true, 55 | setup : function(ed) { 56 | ed.onInit.add(function(ed) { 57 | ed.pasteAsPlainText = true; 58 | }); 59 | ed.onKeyUp.add(function(ed, l) { 60 | const editor_lt = tinymce.get('checktext'); 61 | editor_lt.setContent(ed.getContent()); 62 | }); 63 | }, 64 | languagetool_i18n_no_errors : { 65 | "de-DE": "Keine Fehler gefunden." 66 | }, 67 | languagetool_i18n_explain : { 68 | "de-DE": "Mehr Informationen..." 
69 | }, 70 | languagetool_i18n_ignore_once : { 71 | "de-DE": "Hier ignorieren" 72 | }, 73 | languagetool_i18n_ignore_all : { 74 | "de-DE": "Fehler dieses Typs ignorieren" 75 | }, 76 | languagetool_i18n_rule_implementation : { 77 | "de-DE": "Implementierung der Regel" 78 | }, 79 | 80 | languagetool_i18n_current_lang : function() { 81 | return document.checkform.lang.value; 82 | }, 83 | languagetool_rpc_url : "http://localhost:8081/v2/check", 84 | languagetool_css_url : 85 | "https://www.languagetool.org/online-check/" + 86 | "tiny_mce/plugins/atd-tinymce/css/content.css", 87 | theme : "advanced", 88 | theme_advanced_buttons1 : "", 89 | theme_advanced_buttons2 : "", 90 | theme_advanced_buttons3 : "", 91 | theme_advanced_toolbar_location : "none", 92 | theme_advanced_toolbar_align : "left", 93 | theme_advanced_statusbar_location : "bottom", 94 | theme_advanced_path : false, 95 | theme_advanced_resizing : true, 96 | theme_advanced_resizing_use_cookie : false, 97 | gecko_spellcheck : false 98 | }); 99 | 100 | function doit() { 101 | const langCode = document.checkform.lang.value; 102 | // if one of them returns before the other one has started checking, the first one might mix up the input text 103 | setTimeout(() => tinymce.get('checktext_wordfiller').execCommand("mceWritingImprovementTool", langCode), 250); 104 | tinymce.get('checktext').execCommand("mceWritingImprovementTool", langCode); 105 | } 106 | 107 | const exampleTexts = { 108 | "de-DE": "Ich glaube, das der Spieleabend gut besucht sein wir, da wir fiel Werbung gemacht haben. Was machst du den da? Ab wann seit ihr in der Uni? Ich bin gerade ihm Copyshop.", 109 | "en-US": "I didn’t no the answer, but he person told me the correct answer. We want too go to the museum no, but Peter isn’t here yet. Sara past the test yesterday. I lent him same money. Please turn of your phones.", 110 | "pt-PT": "Isso junta quanto? Sem dívida, algo inesperado ocorreu. Posso por a mesa? Filipe quer ganhar missa muscular. Eles são parceiros de lança há mais de 60 anos. Chorei a norte toda. Falei pôr telefone. 
[from tatoeba.org]" 111 | } 112 | 113 | function updateExampleText() { 114 | const editor_wordfiller = tinymce.get('checktext_wordfiller'); 115 | const editor_lt = tinymce.get('checktext'); 116 | chosenLanguage = document.getElementById("lang").value; 117 | editor_wordfiller.setContent(exampleTexts[chosenLanguage]); 118 | editor_lt.setContent(exampleTexts[chosenLanguage]); 119 | } 120 | 121 | document.addEventListener("DOMContentLoaded", () => { 122 | document.getElementById("lang").onchange = updateExampleText; 123 | },false); 124 | -------------------------------------------------------------------------------- /src/main/html/style.css: -------------------------------------------------------------------------------- 1 | table { 2 | border-collapse:collapse; 3 | } 4 | 5 | tr:nth-child(odd) { 6 | background: #eeeeee; 7 | 8 | } 9 | 10 | td, th { 11 | padding: 5px; 12 | text-align: center; 13 | } 14 | 15 | [title] { 16 | text-decoration: underline; 17 | text-decoration-style: dotted; 18 | text-decoration-color: rgba(0,0,0,.5); 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/BinarySentenceDatabaseCreator.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import de.hhu.mabre.languagetool.transformationrules.*; 4 | 5 | import java.io.File; 6 | import java.io.IOException; 7 | import java.nio.file.Files; 8 | import java.nio.file.Paths; 9 | import java.util.*; 10 | 11 | import static de.hhu.mabre.languagetool.FileTokenizer.tokenizedSentences; 12 | import static de.hhu.mabre.languagetool.SubsetType.TRAINING; 13 | import static de.hhu.mabre.languagetool.SubsetType.VALIDATION; 14 | 15 | public class BinarySentenceDatabaseCreator { 16 | 17 | public static void main(String[] args) { 18 | if(args.length != 2) { 19 | System.out.println("parameters: language-code corpus"); 20 | System.exit(-1); 21 | } 22 | 23 | String languageCode = args[0]; 24 | String corpusFilename = args[1]; 25 | 26 | List lines = new ArrayList<>(0); 27 | try { 28 | lines = Files.readAllLines(Paths.get(corpusFilename)); 29 | } catch (IOException e) { 30 | e.printStackTrace(); 31 | System.exit(1); 32 | } 33 | List> relevantSentences = tokenizedSentences(languageCode, String.join("\n", lines)); 34 | 35 | EnumMap>> sets = randomlySplit(relevantSentences, 20); 36 | 37 | String basename = System.getProperty("java.io.tmpdir") + File.separator + "A"; 38 | 39 | writeDatabase(createDatabase(sets.get(TRAINING)), basename + "_training.json"); 40 | writeDatabase(createDatabase(sets.get(VALIDATION)), basename +"_validate.json"); 41 | } 42 | 43 | static EnumMap> randomlySplit(List items, int validatePercentage) { 44 | Collections.shuffle(items); 45 | int totalLines = items.size(); 46 | int firstTrainingIndex = validatePercentage * totalLines / 100; 47 | List trainingLines = items.subList(firstTrainingIndex, totalLines); 48 | List validationLines = items.subList(0, firstTrainingIndex); 49 | EnumMap> sets = new EnumMap<>(SubsetType.class); 50 | sets.put(TRAINING, trainingLines); 51 | sets.put(VALIDATION, validationLines); 52 | return sets; 53 | } 54 | 55 | private static void writeDatabase(BinarySentencesDict dict, String filename) { 56 | try { 57 | Files.write(Paths.get(filename), Collections.singletonList(dict.toJson())); 58 | System.out.println(filename + " created"); 59 | } catch (IOException e) { 60 | e.printStackTrace(); 61 | } 62 | } 63 | 64 | static BinarySentencesDict 
createDatabase(List> tokenizedSentences) { 65 | BinarySentencesDict db = new BinarySentencesDict(); 66 | 67 | List transformationRules = Arrays.asList( 68 | new KommaDass_Das(), 69 | new KommaOhneDass_OhneDas(), 70 | new KommaDas_Das(), 71 | new Das_Dass(), 72 | new Dass_Das(), 73 | new im_ihm(), 74 | new ihm_im() 75 | ); 76 | 77 | for (List sentence: tokenizedSentences) { 78 | Collections.shuffle(transformationRules); 79 | Optional transformationRule = transformationRules.stream().filter(tr -> tr.applicable(sentence)).findFirst(); 80 | if (transformationRule.isPresent()) { 81 | db.add(sentence, true); 82 | db.add(transformationRule.get().apply(sentence), false); 83 | } 84 | } 85 | 86 | return db; 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/BinarySentencesDict.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import com.google.gson.Gson; 4 | 5 | import java.util.LinkedList; 6 | import java.util.List; 7 | 8 | class BinarySentencesDict { 9 | private final List> tokens = new LinkedList<>(); 10 | private final List groundTruths = new LinkedList<>(); 11 | 12 | void add(List tokens, boolean correct) { 13 | this.tokens.add(tokens); 14 | groundTruths.add(correct ? 1 : -1); 15 | } 16 | 17 | public String toJson() { 18 | return new Gson().toJson(this); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/FileTokenizer.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import org.languagetool.Language; 4 | import org.languagetool.tokenizers.SentenceTokenizer; 5 | import org.languagetool.tokenizers.Tokenizer; 6 | 7 | import java.io.IOException; 8 | import java.nio.file.Files; 9 | import java.nio.file.Paths; 10 | import java.util.Collections; 11 | import java.util.List; 12 | import java.util.stream.Collectors; 13 | 14 | import static org.languagetool.Languages.getLanguageForShortCode; 15 | 16 | /* 17 | * Tokenize an input file containing sentences, producing a file containing all tokens separated by spaces. 
18 | */ 19 | public class FileTokenizer { 20 | public static void main(String[] args) { 21 | if (args.length != 2) { 22 | System.out.println("Parameters: language-code sentences-file"); 23 | System.exit(-1); 24 | } 25 | 26 | String languageCode = args[0]; 27 | String sentencesFile = args[1]; 28 | 29 | String text = readText(sentencesFile); 30 | String tokens = String.join(" ", tokenize(languageCode, text)); 31 | createTokensFile(sentencesFile+"-tokens", tokens); 32 | } 33 | 34 | static String readText(String sentencesFile) { 35 | System.out.println("Reading " + sentencesFile); 36 | String text = ""; 37 | try { 38 | text = String.join("\n", Files.readAllLines(Paths.get(sentencesFile))); 39 | } catch (IOException e) { 40 | e.printStackTrace(); 41 | } 42 | return text; 43 | } 44 | 45 | private static void createTokensFile(String tokensFile, String tokens) { 46 | try { 47 | Files.write(Paths.get(tokensFile), Collections.singletonList(tokens)); 48 | System.out.println("Tokens written to " + tokensFile); 49 | } catch (IOException e) { 50 | e.printStackTrace(); 51 | } 52 | } 53 | 54 | public static List tokenize(String languageCode, String text) { 55 | System.out.println("Tokenizing"); 56 | Tokenizer tokenizer = getTokenizer(languageCode); 57 | return tokenize(tokenizer, text); 58 | } 59 | 60 | private static List tokenize(Tokenizer tokenizer, String text) { 61 | List tokenizedText = tokenizer.tokenize(text); 62 | return tokenizedText.stream().filter(token -> !token.trim().isEmpty()).collect(Collectors.toList()); 63 | } 64 | 65 | private static Tokenizer getTokenizer(String languageCode) { 66 | Language language = getLanguageForShortCode(languageCode); 67 | return language.getWordTokenizer(); 68 | } 69 | 70 | public static List> tokenizedSentences(String languageCode, String text) { 71 | Language language = getLanguageForShortCode(languageCode); 72 | SentenceTokenizer sentenceTokenizer = language.getSentenceTokenizer(); 73 | List sentences = sentenceTokenizer.tokenize(text); 74 | Tokenizer tokenizer = getTokenizer(languageCode); 75 | return sentences.stream().map(sentence -> tokenize(tokenizer, sentence)).collect(Collectors.toList()); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/NGram.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | class NGram { 8 | private List tokens; 9 | 10 | NGram(List tokens) { 11 | this.tokens = tokens; 12 | } 13 | 14 | NGram(String ...tokens) { 15 | this.tokens = Arrays.asList(tokens); 16 | } 17 | 18 | @Override 19 | public boolean equals(Object o) { 20 | if (this == o) return true; 21 | if (o == null || getClass() != o.getClass()) return false; 22 | 23 | NGram nGram = (NGram) o; 24 | 25 | return tokens.equals(nGram.tokens); 26 | } 27 | 28 | @Override 29 | public int hashCode() { 30 | return tokens.hashCode(); 31 | } 32 | 33 | @Override 34 | public String toString() { 35 | return "[" 36 | + tokens.stream().map(token -> "'" + token.replaceAll("'", "\\\\'") + "'") 37 | .collect(Collectors.joining(",")) 38 | + "]"; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/NGramDatabaseCreator.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 
| 3 | import java.io.File; 4 | import java.io.FileReader; 5 | import java.io.IOException; 6 | import java.nio.file.Files; 7 | import java.nio.file.Paths; 8 | import java.util.*; 9 | import java.util.stream.Collectors; 10 | 11 | import static de.hhu.mabre.languagetool.FileTokenizer.tokenize; 12 | import static de.hhu.mabre.languagetool.SamplingMode.NONE; 13 | import static de.hhu.mabre.languagetool.SamplingMode.UNDERSAMPLE; 14 | import static de.hhu.mabre.languagetool.SubjectGrepper.grep; 15 | import static de.hhu.mabre.languagetool.SubsetType.TRAINING; 16 | import static de.hhu.mabre.languagetool.SubsetType.VALIDATION; 17 | 18 | /** 19 | * Create a 5-gram database as input for the neural network. 20 | */ 21 | public class NGramDatabaseCreator { 22 | 23 | private static final int N = 5; 24 | private static final int AVERAGE_WORD_LENGTH = 10; 25 | 26 | public static void main(String[] args) { 27 | if(args.length < 4) { 28 | System.out.println("parameters: language-code corpus subject1 subject2 …"); 29 | System.exit(-1); 30 | } 31 | 32 | String languageCode = args[0]; 33 | String corpusFilename = args[1]; 34 | List subjects = Arrays.asList(args).subList(2, args.length); 35 | 36 | List relevantLines = new ArrayList<>(0); 37 | try { 38 | relevantLines = grep(new FileReader(corpusFilename), subjects.toArray(new String[1])); 39 | } catch (IOException e) { 40 | e.printStackTrace(); 41 | System.exit(1); 42 | } 43 | 44 | HashMap sets = randomlySplit(relevantLines, 20); 45 | 46 | String basename = System.getProperty("java.io.tmpdir") + File.separator + String.join("_", subjects); 47 | 48 | writeDatabase(databaseFromSentences(languageCode, sets.get(TRAINING), subjects, UNDERSAMPLE), basename + "_training.py"); 49 | writeDatabase(databaseFromSentences(languageCode, sets.get(VALIDATION), subjects, NONE), basename +"_validate.py"); 50 | } 51 | 52 | static HashMap randomlySplit(List items, int validatePercentage) { 53 | Collections.shuffle(items); 54 | int totalLines = items.size(); 55 | int firstTrainingIndex = validatePercentage * totalLines / 100; 56 | String trainingLines = String.join("\n", items.subList(firstTrainingIndex, totalLines)); 57 | String validationLines = String.join("\n", items.subList(0, firstTrainingIndex)); 58 | HashMap sets = new HashMap<>(); 59 | sets.put(TRAINING, trainingLines); 60 | sets.put(VALIDATION, validationLines); 61 | return sets; 62 | } 63 | 64 | private static void writeDatabase(PythonDict pythonDict, String filename) { 65 | try { 66 | Files.write(Paths.get(filename), Collections.singletonList(pythonDict.toString())); 67 | System.out.println(filename + " created"); 68 | } catch (IOException e) { 69 | e.printStackTrace(); 70 | } 71 | } 72 | 73 | static PythonDict databaseFromSentences(String languageCode, String sentences, List subjects, SamplingMode samplingMode) { 74 | List tokens = tokenize(languageCode, sentences); 75 | tokens.add(0, "."); 76 | tokens.add(1, "."); 77 | tokens.add("."); 78 | tokens.add("."); 79 | List> tokenizedSubjects = new ArrayList<>(); 80 | for (String subject: subjects) { 81 | tokenizedSubjects.add(tokenize(languageCode, subject)); 82 | } 83 | return createDatabase(tokens, tokenizedSubjects, samplingMode); 84 | } 85 | 86 | static PythonDict databaseFromSentences(String languageCode, String sentences, String subject1, String subject2, SamplingMode samplingMode) { 87 | return databaseFromSentences(languageCode, sentences, Arrays.asList(subject1, subject2), samplingMode); 88 | } 89 | 90 | static PythonDict createDatabase(List tokens, String token1, 
String token2, SamplingMode samplingMode) { 91 | return createDatabase(tokens, Arrays.asList(Collections.singletonList(token1), Collections.singletonList(token2)), samplingMode); 92 | } 93 | 94 | static PythonDict createDatabase(List tokens, List> subjects, SamplingMode samplingMode) { 95 | ArrayList> nGrams = new ArrayList<>(); 96 | 97 | for (List subject: subjects) { 98 | nGrams.add(getRelevantNGrams(tokens, subject)); 99 | } 100 | 101 | PythonDict db = new PythonDict(); 102 | 103 | List numberOfSamples = getNumberOfSamples(nGrams.stream().map(ArrayList::size).collect(Collectors.toList()), samplingMode); 104 | System.out.println("sampling to " + Arrays.toString(numberOfSamples.toArray())); 105 | 106 | for (int n = 0; n < nGrams.size(); n++) { 107 | for (int i = 0; i < numberOfSamples.get(n); i++) { 108 | db.add(nGrams.get(n).get(i % nGrams.get(n).size()), n); 109 | } 110 | } 111 | return db; 112 | } 113 | 114 | private static List getNumberOfSamples(List sampleCounts, SamplingMode samplingMode) { 115 | if (samplingMode != NONE) { 116 | int numberOfSamples = 0; 117 | switch (samplingMode) { 118 | case UNDERSAMPLE: 119 | numberOfSamples = Collections.min(sampleCounts); 120 | break; 121 | case OVERSAMPLE: 122 | numberOfSamples = Collections.max(sampleCounts); 123 | break; 124 | case MODERATE_OVERSAMPLE: 125 | numberOfSamples = Math.min(2 * Collections.min(sampleCounts), Collections.max(sampleCounts)); 126 | break; 127 | } 128 | return Collections.nCopies(sampleCounts.size(), numberOfSamples); 129 | } 130 | return sampleCounts; 131 | } 132 | 133 | static ArrayList getRelevantNGrams(List tokens, List subjectTokens) { 134 | ArrayList nGrams; 135 | nGrams = new ArrayList<>(); 136 | 137 | final int end = tokens.size() - N/2; 138 | final int subjectLength = subjectTokens.size(); 139 | for(int i = N/2 - 1; i <= end - subjectLength; i++) { 140 | if (tokens.subList(i, i+subjectLength).equals(subjectTokens)) { 141 | List ngram = new LinkedList<>(); 142 | ngram.addAll(tokens.subList(i-N/2, i)); 143 | ngram.add(subjectTokens.stream().collect(Collectors.joining(" "))); 144 | ngram.addAll(tokens.subList(i+subjectLength, i+N/2+subjectLength)); 145 | nGrams.add(new NGram(ngram)); 146 | } 147 | } 148 | return nGrams; 149 | } 150 | 151 | static ArrayList getRelevantCharNGrams(List tokens, List subjectTokens) { 152 | ArrayList nGrams; 153 | nGrams = new ArrayList<>(); 154 | 155 | final int end = tokens.size(); 156 | final int subjectLength = subjectTokens.size(); 157 | for(int i = 1; i <= end - subjectLength; i++) { 158 | if (tokens.subList(i, i+subjectLength).equals(subjectTokens)) { 159 | List ngram = new LinkedList<>(); 160 | try { 161 | ngram.addAll(getNCharsBefore(tokens, i, N / 2 * AVERAGE_WORD_LENGTH)); 162 | ngram.add(subjectTokens.stream().collect(Collectors.joining(" "))); 163 | ngram.addAll(getNCharsAfter(tokens, i, N / 2 * AVERAGE_WORD_LENGTH)); 164 | nGrams.add(new NGram(ngram)); 165 | } catch(ArrayIndexOutOfBoundsException e) { 166 | // ok 167 | } 168 | } 169 | } 170 | return nGrams; 171 | } 172 | 173 | private static List getNCharsBefore(List tokens, int idx, int n) { 174 | int i = idx - 1; 175 | String chars = tokens.get(i); 176 | while (chars.length() < n) { 177 | i -= 1; 178 | chars = tokens.get(i) + " " + chars; 179 | } 180 | return chars.chars().mapToObj(c -> String.valueOf((char) c)).skip(chars.length() - n).collect(Collectors.toList()); 181 | } 182 | 183 | private static List getNCharsAfter(List tokens, int idx, int n) { 184 | int i = idx + 1; 185 | String chars = tokens.get(i); 186 | 
while (chars.length() < n) { 187 | i += 1; 188 | chars = chars + " " + tokens.get(i); 189 | } 190 | return chars.chars().mapToObj(c -> String.valueOf((char) c)).limit(n).collect(Collectors.toList()); 191 | } 192 | 193 | } 194 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/PythonDict.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import java.util.Collections; 4 | import java.util.LinkedList; 5 | import java.util.List; 6 | import java.util.stream.Collectors; 7 | import java.util.stream.IntStream; 8 | 9 | class PythonDict { 10 | private List nGrams = new LinkedList<>(); 11 | private List groundTruths = new LinkedList<>(); 12 | 13 | void add(NGram nGram, int groundTruth) { 14 | nGrams.add(nGram); 15 | groundTruths.add(groundTruth); 16 | } 17 | 18 | void addAll(List nGrams, int groundTruth) { 19 | nGrams.forEach(nGram -> add(nGram, groundTruth)); 20 | } 21 | 22 | @Override 23 | public String toString() { 24 | String dict; 25 | dict = "{'ngrams':["; 26 | dict += nGrams.stream().map(NGram::toString).collect(Collectors.joining(",")); 27 | dict += "],\n"; 28 | dict += "'groundtruths':[" + oneHotEncode(groundTruths) + "]}"; 29 | return dict; 30 | } 31 | 32 | private static String oneHotEncode(List groundTruths) { 33 | int categories = Collections.max(groundTruths) + 1; 34 | return groundTruths.stream().map(i -> oneHotEncode(i, categories)).collect(Collectors.joining(",")); 35 | } 36 | 37 | private static String oneHotEncode(int n, int categories) { 38 | String list = IntStream.range(0, categories).mapToObj(i -> n == i ? "1" : "0").collect(Collectors.joining(",")); 39 | return "[" + list + "]"; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/PythonGateway.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import py4j.GatewayServer; 4 | 5 | public class PythonGateway { 6 | 7 | public static String hello(String name) { 8 | return "Hello " + name + "!"; 9 | } 10 | 11 | public static void main(String[] args) { 12 | GatewayServer server = new GatewayServer(new PythonGateway()); 13 | System.out.println("Starting server"); 14 | server.start(); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/SamplingMode.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | public enum SamplingMode { 4 | NONE, 5 | UNDERSAMPLE, 6 | OVERSAMPLE, 7 | MODERATE_OVERSAMPLE 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/SentenceDatabaseCreator.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import java.io.File; 4 | import java.io.FileReader; 5 | import java.io.IOException; 6 | import java.nio.file.Files; 7 | import java.nio.file.Paths; 8 | import java.util.*; 9 | import java.util.stream.Collectors; 10 | 11 | import static de.hhu.mabre.languagetool.FileTokenizer.tokenizedSentences; 12 | import static de.hhu.mabre.languagetool.FileTokenizer.tokenize; 13 | import static de.hhu.mabre.languagetool.SamplingMode.NONE; 14 | import static 
de.hhu.mabre.languagetool.SamplingMode.UNDERSAMPLE; 15 | import static de.hhu.mabre.languagetool.SubjectGrepper.grep; 16 | import static de.hhu.mabre.languagetool.SubsetType.TRAINING; 17 | import static de.hhu.mabre.languagetool.SubsetType.VALIDATION; 18 | 19 | /** 20 | * Create a database with sentence starts for the neural network. 21 | */ 22 | public class SentenceDatabaseCreator { 23 | 24 | public static void main(String[] args) { 25 | if(args.length < 4) { 26 | System.out.println("parameters: language-code corpus subject1 subject2 …"); 27 | System.exit(-1); 28 | } 29 | 30 | String languageCode = args[0]; 31 | String corpusFilename = args[1]; 32 | List subjects = Arrays.asList(args).subList(2, args.length); 33 | 34 | List relevantLines = new ArrayList<>(0); 35 | try { 36 | relevantLines = grep(new FileReader(corpusFilename), subjects.toArray(new String[subjects.size()])); 37 | } catch (IOException e) { 38 | e.printStackTrace(); 39 | System.exit(1); 40 | } 41 | List> relevantSentences = tokenizedSentences(languageCode, String.join("\n", relevantLines)); 42 | 43 | EnumMap>> sets = randomlySplit(relevantSentences, 20); 44 | 45 | String basename = System.getProperty("java.io.tmpdir") + File.separator + String.join("_", subjects); 46 | 47 | writeDatabase(databaseFromSentences(languageCode, sets.get(TRAINING), subjects, UNDERSAMPLE), basename + "_training.json"); 48 | writeDatabase(databaseFromSentences(languageCode, sets.get(VALIDATION), subjects, NONE), basename +"_validate.json"); 49 | } 50 | 51 | static EnumMap> randomlySplit(List items, int validatePercentage) { 52 | Collections.shuffle(items); 53 | int totalLines = items.size(); 54 | int firstTrainingIndex = validatePercentage * totalLines / 100; 55 | List trainingLines = items.subList(firstTrainingIndex, totalLines); 56 | List validationLines = items.subList(0, firstTrainingIndex); 57 | EnumMap> sets = new EnumMap<>(SubsetType.class); 58 | sets.put(TRAINING, trainingLines); 59 | sets.put(VALIDATION, validationLines); 60 | return sets; 61 | } 62 | 63 | private static void writeDatabase(SentencesDict dict, String filename) { 64 | try { 65 | Files.write(Paths.get(filename), Collections.singletonList(dict.toJson())); 66 | System.out.println(filename + " created"); 67 | } catch (IOException e) { 68 | e.printStackTrace(); 69 | } 70 | } 71 | 72 | static SentencesDict databaseFromSentences(String languageCode, List> tokenizedSentences, List subjects, SamplingMode samplingMode) { 73 | List> tokenizedSubjects = new ArrayList<>(); 74 | for (String subject: subjects) { 75 | tokenizedSubjects.add(tokenize(languageCode, subject)); 76 | } 77 | return createDatabase(tokenizedSentences, tokenizedSubjects, samplingMode); 78 | } 79 | 80 | static SentencesDict createDatabase(List> tokenizedSentences, List> subjects, SamplingMode samplingMode) { 81 | List>> sentenceStarts = new ArrayList<>(); 82 | List>> sentenceEndings = new ArrayList<>(); 83 | 84 | for (List subject: subjects) { 85 | sentenceStarts.add(getRelevantSentenceBeginnings(tokenizedSentences, subject)); 86 | sentenceEndings.add(getRelevantSentenceEndings(tokenizedSentences, subject)); 87 | } 88 | 89 | SentencesDict db = new SentencesDict(subjects.size()); 90 | 91 | List numberOfSamples = getNumberOfSamples(sentenceStarts.stream().map(List::size).collect(Collectors.toList()), samplingMode); 92 | System.out.println("sampling to " + Arrays.toString(numberOfSamples.toArray())); 93 | 94 | for (int n = 0; n < sentenceStarts.size(); n++) { 95 | for (int i = 0; i < numberOfSamples.get(n); i++) { 96 | 
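// The index wraps around via modulo, so the oversampling modes can reuse examples from the
// start of a class once its distinct sentence starts and endings are exhausted.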
db.add(sentenceStarts.get(n).get(i % sentenceStarts.get(n).size()), 97 | sentenceEndings.get(n).get(i % sentenceEndings.get(n).size()), 98 | n); 99 | } 100 | } 101 | return db; 102 | } 103 | 104 | private static List getNumberOfSamples(List sampleCounts, SamplingMode samplingMode) { 105 | if (samplingMode != NONE) { 106 | int numberOfSamples = 0; 107 | switch (samplingMode) { 108 | case UNDERSAMPLE: 109 | numberOfSamples = Collections.min(sampleCounts); 110 | break; 111 | case OVERSAMPLE: 112 | numberOfSamples = Collections.max(sampleCounts); 113 | break; 114 | case MODERATE_OVERSAMPLE: 115 | numberOfSamples = Math.min(2 * Collections.min(sampleCounts), Collections.max(sampleCounts)); 116 | break; 117 | } 118 | return Collections.nCopies(sampleCounts.size(), numberOfSamples); 119 | } 120 | return sampleCounts; 121 | } 122 | 123 | static List> getRelevantSentenceBeginnings(List> tokenizedSentences, List subjectTokens) { 124 | List> nGrams = new ArrayList<>(); 125 | final int subjectLength = subjectTokens.size(); 126 | for (List tokens : tokenizedSentences) { 127 | for (int i = 0; i <= tokens.size() - subjectLength; i++) { 128 | if (tokens.subList(i, i + subjectLength).equals(subjectTokens)) { 129 | nGrams.add(tokens.subList(0, i)); 130 | } 131 | } 132 | } 133 | return nGrams; 134 | } 135 | 136 | static List> getRelevantSentenceEndings(List> tokenizedSentences, List subjectTokens) { // parameterize 137 | List> nGrams = new ArrayList<>(); 138 | final int subjectLength = subjectTokens.size(); 139 | for (List tokens : tokenizedSentences) { 140 | for (int i = 0; i <= tokens.size() - subjectLength; i++) { 141 | if (tokens.subList(i, i + subjectLength).equals(subjectTokens)) { 142 | nGrams.add(tokens.subList(i + 1, tokens.size())); 143 | } 144 | } 145 | } 146 | return nGrams; 147 | } 148 | 149 | } 150 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/SentencesDict.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import com.google.gson.Gson; 4 | 5 | import java.util.LinkedList; 6 | import java.util.List; 7 | import java.util.stream.IntStream; 8 | 9 | class SentencesDict { 10 | private final List> tokensBefore = new LinkedList<>(); 11 | private final List> tokensAfter = new LinkedList<>(); 12 | private final List groundTruths = new LinkedList<>(); 13 | private final int nCategories; 14 | 15 | SentencesDict(int nCategories) { 16 | this.nCategories = nCategories; 17 | } 18 | 19 | void add(List tokensBefore, List tokensAfter, int groundTruth) { 20 | this.tokensBefore.add(tokensBefore); 21 | this.tokensAfter.add(tokensAfter); 22 | groundTruths.add(oneHotEncode(groundTruth)); 23 | } 24 | 25 | public String toJson() { 26 | return new Gson().toJson(this); 27 | } 28 | 29 | private long[] oneHotEncode(int n) { 30 | return IntStream.range(0, nCategories).mapToLong(i -> n == i ? 1 : 0).toArray(); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/SubjectGrepper.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.io.Reader; 6 | import java.util.Arrays; 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | 10 | public class SubjectGrepper { 11 | static List grep(Reader fileReader, String... 
subjects) throws IOException { 12 | LinkedList results = new LinkedList<>(); 13 | BufferedReader reader = new BufferedReader(fileReader); 14 | String line; 15 | while((line = reader.readLine()) != null) { 16 | if(Arrays.stream(subjects).filter(line::contains).count() > 0) { 17 | results.add(line); 18 | } 19 | } 20 | return results; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/SubsetType.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | public enum SubsetType { 4 | TRAINING, VALIDATION 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/Das_Dass.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class Das_Dass implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, Arrays.asList("das")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList("das")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.set(i, "dass"); 19 | return erroneousSentence; 20 | } 21 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/Dass_Das.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class Dass_Das implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, Arrays.asList("dass")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList("dass")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.set(i, "das"); 19 | return erroneousSentence; 20 | } 21 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/KommaDas_Das.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class KommaDas_Das implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, 
Arrays.asList(",", "das")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList(",", "das")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.remove(i); 19 | return erroneousSentence; 20 | } 21 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/KommaDass_Das.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class KommaDass_Das implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, Arrays.asList(",", "dass")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList(",", "dass")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.remove(i); 19 | erroneousSentence.set(i, "das"); 20 | return erroneousSentence; 21 | } 22 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/KommaOhneDass_OhneDas.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class KommaOhneDass_OhneDas implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, Arrays.asList(",", "ohne", "dass")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList(",", "ohne", "dass")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.remove(i); 19 | erroneousSentence.set(i + 1, "das"); 20 | return erroneousSentence; 21 | } 22 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/TransformationRule.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | import java.util.OptionalInt; 6 | 7 | public interface TransformationRule { 8 | 9 | boolean applicable(List tokenizedSentence); 10 | 11 | List apply(List tokenizedSentence); 12 | 13 | default OptionalInt randomIndexOfPattern(List 
tokenizedSentence, List pattern) { 14 | ArrayList matchingIndices = new ArrayList<>(); 15 | int patternSize = pattern.size(); 16 | for (int i = 0; i < tokenizedSentence.size() - patternSize; i++) { 17 | if (tokenizedSentence.subList(i, i + patternSize).equals(pattern)) { 18 | matchingIndices.add(i); 19 | } 20 | } 21 | if (matchingIndices.isEmpty()) { 22 | return OptionalInt.empty(); 23 | } else { 24 | return OptionalInt.of(matchingIndices.get((int) (Math.random() * matchingIndices.size()))); 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/ihm_im.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class ihm_im implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, Arrays.asList("ihm")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList("ihm")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.set(i, "im"); 19 | return erroneousSentence; 20 | } 21 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/de/hhu/mabre/languagetool/transformationrules/im_ihm.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool.transformationrules; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import java.util.List; 6 | import java.util.OptionalInt; 7 | 8 | public class im_ihm implements TransformationRule { 9 | public boolean applicable(List tokenizedSentence) { 10 | return randomIndexOfPattern(tokenizedSentence, Arrays.asList("im")).isPresent(); 11 | } 12 | 13 | public List apply(List tokenizedSentence) { 14 | OptionalInt idx = randomIndexOfPattern(tokenizedSentence, Arrays.asList("im")); 15 | if (idx.isPresent()) { 16 | int i = idx.getAsInt(); 17 | List erroneousSentence = new ArrayList<>(tokenizedSentence); 18 | erroneousSentence.set(i, "ihm"); 19 | return erroneousSentence; 20 | } 21 | throw new IllegalArgumentException(this.getClass().toString() + " cannot be applied to " + String.join(" ", tokenizedSentence)); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/python/EvalResult.py: -------------------------------------------------------------------------------- 1 | class EvalResult(): 2 | def __init__(self): 3 | self.tp = 0 4 | self.fp = 0 5 | self.tn = 0 6 | self.fn = 0 7 | 8 | def add_tp(self): 9 | self.tp += 1 10 | 11 | def add_fp(self): 12 | self.fp += 1 13 | 14 | def add_tn(self): 15 | self.tn += 1 16 | 17 | def add_fn(self): 18 | self.fn += 1 19 | 20 | def recall(self): 21 | if (self.tp + self.fn) > 0: 22 | return self.tp / (self.tp + self.fn) 23 | return 0 24 | 25 | def precision(self): 26 | if (self.tp + self.fp) > 0: 27 | return self.tp / (self.tp + self.fp) 28 | return 1 29 | 30 | def __str__(self): 31 | return "\n" % (self.tp, 
self.fp, self.tn, self.fn, self.precision(), self.recall()) 32 | 33 | def __repr__(self): 34 | return self.__str__() 35 | -------------------------------------------------------------------------------- /src/main/python/LayeredScorer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LayeredScorer: 5 | 6 | def __init__(self, weights_path: str): 7 | self.W1 = np.loadtxt(weights_path + "W_fc1.txt") 8 | self.b1 = np.loadtxt(weights_path + "b_fc1.txt") 9 | try: 10 | self.W2 = np.loadtxt(weights_path + "W_fc2.txt") 11 | self.b2 = np.loadtxt(weights_path + "b_fc2.txt") 12 | except FileNotFoundError: 13 | self.W2 = None 14 | self.b2 = None 15 | 16 | def scores(self, m: np.ndarray): 17 | l1 = np.matmul(m, self.W1) + self.b1 18 | if self.W2 is None: 19 | return l1 20 | return np.matmul(l1 * (l1 > 0), self.W2) + self.b2 21 | 22 | def context_size(self, embedding_size): 23 | return self.W1.shape[0]/embedding_size 24 | -------------------------------------------------------------------------------- /src/main/python/embedding/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 The TensorFlow Authors. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 
49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2017, The TensorFlow Authors. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | 205 | -------------------------------------------------------------------------------- /src/main/python/embedding/README.md: -------------------------------------------------------------------------------- 1 | This directory contains models for unsupervised training of word embeddings 2 | using the model described in: 3 | 4 | (Mikolov, et. al.) [Efficient Estimation of Word Representations in Vector Space](http://arxiv.org/abs/1301.3781), 5 | ICLR 2013. 6 | 7 | Detailed instructions on how to get started and use them are available in the 8 | tutorials. 
Brief instructions are below. 9 | 10 | * [Word2Vec Tutorial](http://tensorflow.org/tutorials/word2vec) 11 | 12 | Assuming you have cloned the git repository, navigate into this directory. To download the example text and evaluation data: 13 | 14 | ```shell 15 | curl http://mattmahoney.net/dc/text8.zip > text8.zip 16 | unzip text8.zip 17 | curl https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip > source-archive.zip 18 | unzip -p source-archive.zip word2vec/trunk/questions-words.txt > questions-words.txt 19 | rm text8.zip source-archive.zip 20 | ``` 21 | 22 | You will need to compile the ops as follows: 23 | 24 | ```shell 25 | TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 26 | g++ -std=c++11 -shared word2vec_ops.cc word2vec_kernels.cc -o word2vec_ops.so -fPIC -I $TF_INC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 27 | ``` 28 | 29 | On Mac, add `-undefined dynamic_lookup` to the g++ command. 30 | 31 | (For an explanation of what this is doing, see the tutorial on [Adding a New Op to TensorFlow](https://www.tensorflow.org/how_tos/adding_an_op/#building_the_op_library). The flag `-D_GLIBCXX_USE_CXX11_ABI=0` is included to support newer versions of gcc. However, if you compiled TensorFlow from source using gcc 5 or later, you may need to exclude the flag.) 32 | Then run using: 33 | 34 | ```shell 35 | python word2vec_optimized.py \ 36 | --train_data=text8 \ 37 | --eval_data=questions-words.txt \ 38 | --save_path=/tmp/ 39 | ``` 40 | 41 | Here is a short overview of what is in this directory. 42 | 43 | File | What's in it? 44 | --- | --- 45 | `word2vec.py` | A version of word2vec implemented using TensorFlow ops and minibatching. 46 | `word2vec_test.py` | Integration test for word2vec. 47 | `word2vec_optimized.py` | A version of word2vec implemented using C ops that does no minibatching. 48 | `word2vec_optimized_test.py` | Integration test for word2vec_optimized. 49 | `word2vec_kernels.cc` | Kernels for the custom input and training ops. 50 | `word2vec_ops.cc` | The declarations of the custom ops. 51 | -------------------------------------------------------------------------------- /src/main/python/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Import generated word2vec optimized ops into embedding package.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | -------------------------------------------------------------------------------- /src/main/python/embedding/cbow.py: -------------------------------------------------------------------------------- 1 | # References 2 | # - https://www.tensorflow.org/versions/r0.10/tutorials/word2vec/index.html 3 | # - https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/examples/tutorials/word2vec/word2vec_basic.py 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import collections 10 | import json 11 | import math 12 | import sys 13 | import logging 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | from six.moves import xrange # pylint: disable=redefined-builtin 18 | 19 | from embedding.common import build_dataset, read_data 20 | from languagetool.languagetool import LanguageTool 21 | 22 | logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.DEBUG, stream=sys.stdout) 23 | 24 | logging.info(sys.argv) 25 | 26 | if len(sys.argv) != 6: 27 | print("Parameters: training-file language-code vocabulary-size output-dir") 28 | print("Example: /tmp/c de-DE 50000 20000 /tmp") 29 | sys.exit(1) 30 | 31 | filename = sys.argv[1] 32 | lt = LanguageTool(sys.argv[2]) 33 | max_words_in_vocabulary = int(sys.argv[3]) 34 | num_steps = int(sys.argv[4]) + 1 35 | outdir = sys.argv[5] 36 | 37 | 38 | # Read the data into a list of strings. 39 | words = read_data(filename) 40 | logging.info('number of tokens in input file:', len(words), flush=True) 41 | 42 | # Step 2: Build the dictionary and replace rare words with UNK token. 43 | 44 | data, count, dictionary, reverse_dictionary = build_dataset(words, max_words_in_vocabulary, lt.tag_token) 45 | del words # Hint to reduce memory. 46 | vocabulary_size = len(dictionary) 47 | logging.info('vocabulary size', vocabulary_size) 48 | logging.info('Most common words (+UNK)', count[:5]) 49 | logging.info('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]], flush=True) 50 | 51 | data_index = 0 52 | 53 | 54 | # Step 3: Function to generate a training batch for the skip-gram model. 
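# Note: the wording above is kept from the skip-gram tutorial this script is based on, but the
# batch built here is CBOW-style: every token in the window is used as context (batch) and the
# centre token is the label.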
55 | def generate_batch(batch_size, context_window): 56 | # all context tokens should be used, hence no associated num_skips argument 57 | global data_index 58 | context_size = 2 * context_window 59 | batch = np.ndarray(shape=(batch_size, context_size), dtype=np.int32) 60 | labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 61 | span = 2 * context_window + 1 # [ context_window target context_window ] 62 | buffer = collections.deque(maxlen=span) 63 | for _ in range(span): 64 | buffer.append(data[data_index]) 65 | data_index = (data_index + 1) % len(data) 66 | for i in range(batch_size): 67 | # context tokens are just all the tokens in buffer except the target 68 | batch[i, :] = [token for idx, token in enumerate(buffer) if idx != context_window] 69 | labels[i, 0] = buffer[context_window] 70 | buffer.append(data[data_index]) 71 | data_index = (data_index + 1) % len(data) 72 | return batch, labels 73 | 74 | 75 | batch, labels = generate_batch(batch_size=8, context_window=1) 76 | for i in range(8): 77 | print(batch[i, 0], reverse_dictionary[batch[i, 0]], 78 | batch[i, 1], reverse_dictionary[batch[i, 1]], 79 | '->', labels[i, 0], reverse_dictionary[labels[i, 0]]) 80 | 81 | # Step 4: Build and train a skip-gram model. 82 | 83 | batch_size = 128 84 | embedding_size = 128 # Dimension of the embedding vector. 85 | context_window = 2 # How many words to consider left and right. 86 | context_size = 2 * context_window 87 | 88 | # We pick a random validation set to sample nearest neighbors. Here we limit the 89 | # validation samples to the words that have a low numeric ID, which by 90 | # construction are also the most frequent. 91 | valid_size = 16 # Random set of words to evaluate similarity on. 92 | valid_window = 100 # Only pick dev samples in the head of the distribution. 93 | valid_examples = np.random.choice(valid_window, valid_size, replace=False) 94 | # valid_target_examples = [[dictionary["d"], dictionary["like"], dictionary["go"], dictionary["to"]], 95 | # [dictionary[","], dictionary["and"], dictionary["we"], dictionary["will"]]] 96 | valid_target_examples = [[dictionary["wir"], dictionary["sind"], dictionary["Großen"], dictionary["und"]], 97 | [dictionary["es"], dictionary["ist"], dictionary["zu"], dictionary["sehen"]]] 98 | num_sampled = 64 # Number of negative examples to sample. 99 | 100 | graph = tf.Graph() 101 | 102 | with graph.as_default(): 103 | # Input data. 104 | train_inputs = tf.placeholder(tf.int32, shape=[batch_size, context_size]) 105 | train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) 106 | valid_dataset = tf.constant(valid_examples, dtype=tf.int32) 107 | valid_target_dataset = tf.constant(valid_target_examples, dtype=tf.int32) 108 | 109 | # Ops and variables pinned to the CPU because of missing GPU implementation 110 | with tf.device('/cpu:0'): 111 | # Look up embeddings for inputs. 112 | embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0)) 113 | embed = tf.nn.embedding_lookup(embeddings, train_inputs) 114 | # take mean of embeddings of context words for context embedding 115 | embed_context = tf.reshape(embed, [batch_size, embedding_size * context_size]) 116 | 117 | # Construct the variables for the NCE loss 118 | nce_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size * context_size], stddev=1.0 / math.sqrt(embedding_size))) 119 | nce_biases = tf.Variable(tf.zeros([vocabulary_size])) 120 | 121 | # Compute the average NCE loss for the batch. 
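# Because the full concatenated context embedding (embedding_size * context_size values) is
# scored at once, each vocabulary word gets an NCE weight row of that width rather than the
# usual embedding_size.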
122 | # tf.nce_loss automatically draws a new sample of the negative labels each 123 | # time we evaluate the loss. 124 | loss = tf.reduce_mean(tf.nn.nce_loss(nce_weights, nce_biases, train_labels, embed_context, num_sampled, vocabulary_size)) 125 | 126 | # Construct the SGD optimizer using a learning rate of 1.0. 127 | optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss) 128 | 129 | # Compute the cosine similarity between minibatch examples and all embeddings. 130 | norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True)) 131 | normalized_embeddings = embeddings / norm 132 | valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset) 133 | similarity = tf.matmul(valid_embeddings, normalized_embeddings, transpose_b=True) 134 | 135 | valid_target_probabilities = tf.matmul(tf.reshape(tf.nn.embedding_lookup(embeddings, valid_target_dataset), [len(valid_target_examples), embedding_size * context_size]), nce_weights, transpose_b=True) + nce_biases 136 | 137 | # Add variable initializer. 138 | init = tf.initialize_all_variables() 139 | 140 | # Step 5: Begin training. 141 | logging.info("steps", num_steps, flush=True) 142 | 143 | with tf.Session(graph=graph) as session: 144 | # We must initialize all variables before we use them. 145 | init.run() 146 | logging.info("Initialized") 147 | 148 | average_loss = 0 149 | for step in xrange(num_steps): 150 | batch_inputs, batch_labels = generate_batch(batch_size, context_window) 151 | feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels} 152 | 153 | # We perform one update step by evaluating the optimizer op (including it 154 | # in the list of returned values for session.run() 155 | _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict) 156 | average_loss += loss_val 157 | 158 | if step % 2000 == 0: 159 | if step > 0: 160 | average_loss /= 2000 161 | # The average loss is an estimate of the loss over the last 2000 batches. 162 | print("Average loss at step ", step, ": ", average_loss, flush=True) 163 | average_loss = 0 164 | 165 | # Note that this is expensive (~20% slowdown if computed every 500 steps) 166 | if step % 10000 == 0: 167 | sim = similarity.eval() 168 | for i in xrange(valid_size): 169 | valid_word = reverse_dictionary[valid_examples[i]] 170 | top_k = 8 # number of nearest neighbors 171 | nearest = (-sim[i, :]).argsort()[1:top_k + 1] 172 | log_str = "Nearest to %s:" % valid_word 173 | for k in xrange(top_k): 174 | close_word = reverse_dictionary[nearest[k]] 175 | log_str = "%s %s," % (log_str, close_word) 176 | logging.info(log_str) 177 | 178 | pss = valid_target_probabilities.eval() 179 | for context, ps in zip(valid_target_examples, pss): 180 | top_k = 8 181 | nearest = (-ps).argsort()[1:top_k + 1] 182 | context_words = list(map(lambda w: reverse_dictionary[w], context)) 183 | nearest_words = list(map(lambda w: reverse_dictionary[w], nearest)) 184 | logging.info("Nearest to %s: %s" % (str(context_words), str(nearest_words))) 185 | 186 | 187 | final_embeddings = normalized_embeddings.eval() 188 | 189 | json.dump(dictionary, open(outdir + "/dictionary.json", "w")) 190 | np.savetxt(outdir + "/embeddings.txt", embeddings.eval()) 191 | np.savetxt(outdir + "/nce_weights.txt", nce_weights.eval()) 192 | np.savetxt(outdir + "/nce_biases.txt", nce_biases.eval()) 193 | 194 | # Step 6: Visualize the embeddings. 
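# (The files saved above are what cbow_checker.py later loads from its base_path:
# dictionary.json, embeddings.txt, nce_weights.txt and nce_biases.txt.)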
195 | 196 | def plot_with_labels(low_dim_embs, labels, filename='tsne.png'): 197 | assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings" 198 | plt.figure(figsize=(18, 18)) # in inches 199 | for i, label in enumerate(labels): 200 | x, y = low_dim_embs[i, :] 201 | plt.scatter(x, y) 202 | plt.annotate(label, 203 | xy=(x, y), 204 | xytext=(5, 2), 205 | textcoords='offset points', 206 | ha='right', 207 | va='bottom') 208 | 209 | plt.savefig(filename) 210 | 211 | 212 | try: 213 | from sklearn.manifold import TSNE 214 | import matplotlib.pyplot as plt 215 | 216 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000) 217 | plot_only = 500 218 | low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :]) 219 | labels = [reverse_dictionary[i] for i in xrange(plot_only)] 220 | plot_with_labels(low_dim_embs, labels) 221 | 222 | except ImportError: 223 | logging.warn("Please install sklearn and matplotlib to visualize embeddings.") 224 | -------------------------------------------------------------------------------- /src/main/python/embedding/cbow_checker.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | 5 | from languagetool.languagetool import LanguageTool 6 | 7 | # base_path = "/home/markus/Dokumente/GitHub/projektarbeit/lt/languagetool-neural-network/res/cbow/eng_ordered/" #"/home/markus/Dokumente/GitHub/projektarbeit/lt/languagetool-neural-network/res/cbow/eng/" 8 | # base_path = "/home/markus/Dokumente/GitHub/projektarbeit/lt/languagetool-neural-network/res/cbow/deu_ordered/" #"/home/markus/Dokumente/GitHub/projektarbeit/lt/languagetool-neural-network/res/cbow/eng/" 9 | # fallback = lambda _: "UNK" 10 | base_path = "/home/markus/Dokumente/GitHub/projektarbeit/lt/languagetool-neural-network/res/cbow/deu_pos/" 11 | fallback = lambda token: str(lt.tag_token(token)) 12 | 13 | lt = LanguageTool("de-DE") 14 | 15 | dictionary = json.load(open(base_path + "dictionary.json")) 16 | inverse_dictionary = {i: w for w, i in dictionary.items()} 17 | embeddings = np.loadtxt(base_path + "embeddings.txt") 18 | nce_weights = np.loadtxt(base_path + "nce_weights.txt") 19 | nce_biases = np.loadtxt(base_path + "nce_biases.txt") 20 | context_size = int(nce_weights.shape[1]/embeddings.shape[1]) 21 | context_window = int(context_size / 2) 22 | 23 | 24 | def best_fits(tokenized_sentence: [str], candidates: [str]) -> [(float, str)]: 25 | middle_index = int(len(tokenized_sentence) / 2) 26 | context = tokenized_sentence[:middle_index] + tokenized_sentence[middle_index+1:] 27 | return best_fits_for_context(context, candidates) 28 | 29 | 30 | def best_fits_for_context(context: [str], candidates: [str]) -> [(float, str)]: 31 | ps = np.reshape(embeddings[np.array(get_embedding_indices(context))], [-1]) @ nce_weights.T + nce_biases 32 | candidates_indices = get_embedding_indices(candidates) 33 | order = abs(ps[candidates_indices]).argsort() 34 | return [((ps[candidates_indices])[i], candidates[i]) for i in order] 35 | 36 | 37 | def get_embedding_indices(candidates): 38 | return list(map(lambda token: safe_get(dictionary, token), candidates)) 39 | 40 | 41 | def safe_get(dictionary: dict, token: str) -> str: 42 | if token in dictionary: 43 | return dictionary[token] 44 | else: 45 | fallback_token = fallback(token) 46 | if fallback_token in dictionary: 47 | # print("POS lookup", token) 48 | return dictionary[fallback_token] 49 | else: 50 | # print("UNK", token) 51 | return dictionary["UNK"] 52 | 53 | 54 | def 
best_fits_with_offsetting(tokenized_sentence: [str], candidates: [str]) -> [(float, str)]: 55 | middle_index = int(len(tokenized_sentence) / 2) 56 | scores = [] 57 | for candidate in candidates: 58 | tokenized_candidate_sentence = tokenized_sentence[:middle_index] + [candidate] + tokenized_sentence[middle_index+1:] 59 | score = 1 60 | for offset in range(len(tokenized_sentence) - context_size): 61 | context = tokenized_candidate_sentence[offset:int(context_size/2)+offset] + tokenized_candidate_sentence[int(context_size/2)+1+offset:int(context_size/2)+1+offset+int(context_size/2)] 62 | center = tokenized_candidate_sentence[int(context_size/2)+offset] 63 | [(p, _)] = best_fits_for_context(context, [center]) 64 | # print(context, center, p) 65 | score += abs(p) 66 | scores.append((score, candidate)) 67 | scores.sort(key=lambda pair: pair[0]) 68 | return scores 69 | 70 | 71 | def print_measures(error_detection_tp, error_detection_fp, error_detection_tn, error_detection_fn): 72 | accs = 0 73 | ps = 0 74 | rs = 0 75 | for key in error_detection_tp.keys(): 76 | tp = error_detection_tp[key] 77 | fp = error_detection_fp[key] 78 | tn = error_detection_tn[key] 79 | fn = error_detection_fn[key] 80 | accuracy = (tp + tn) / (tp + fp + tn + fn) 81 | precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan 82 | recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan 83 | print(key, (tp + fn), "acc", accuracy, "p", precision, "r", recall) 84 | accs += accuracy 85 | ps += precision 86 | rs += recall 87 | n_keys = len(error_detection_tp.keys()) 88 | print("all", "acc", accs/n_keys, "ps", ps/n_keys, "rs", rs/n_keys) 89 | 90 | 91 | def evaluate_text(tokenized_sentences: [str], candidates: [str], threshold: float=1) -> float: 92 | tp = 0 93 | fp = 0 94 | error_detection_tp = {token: 0 for token in candidates} 95 | error_detection_fp = {token: 0 for token in candidates} 96 | error_detection_tn = {token: 0 for token in candidates} 97 | error_detection_fn = {token: 0 for token in candidates} 98 | for i, token in enumerate(tokenized_sentences[context_window+1:-context_window-1], context_window+1): 99 | if token in candidates: 100 | sentence = tokenized_sentences[i - context_window : i + context_window + 1] 101 | best_fits = best_fits_with_offsetting(sentence, candidates) 102 | (score_best, best_fit) = best_fits[0] 103 | (score_second, second_fit) = best_fits[1] 104 | if best_fit == token: 105 | tp += 1 106 | error_detection_tn[second_fit] += 1 107 | if score_second - score_best > threshold: 108 | error_detection_tp[best_fit] += 1 109 | else: 110 | fp += 1 111 | error_detection_fn[second_fit] += 1 112 | if score_second - score_best > threshold: 113 | error_detection_fp[best_fit] += 1 114 | print("false positive", best_fit, sentence) 115 | print_measures(error_detection_tp, error_detection_fp, error_detection_tn, error_detection_fn) 116 | print_measures(error_detection_tp, error_detection_fp, error_detection_tn, error_detection_fn) 117 | return tp / (tp + fp) 118 | 119 | 120 | # print(best_fits("would like to go to".split(), ["to", "too"])) 121 | # print(best_fits_with_offsetting("We would like to go to a".split(), ["to", "too"])) 122 | # print(best_fits("allow him to do business".split(), ["to", "too"])) 123 | # print(best_fits_with_offsetting("we allow him to do business with".split(), ["to", "too"])) 124 | # print(best_fits("a bit too heavy ,".split(), ["to", "too"])) 125 | # print(best_fits_with_offsetting("is a bit too heavy , so".split(), ["to", "too"])) 126 | # print(best_fits("are , too , very".split(), ["to", 
"too"])) 127 | # print(best_fits_with_offsetting("They are , too , very interested".split(), ["to", "too"])) 128 | # print(best_fits("I like the lecture about".split(), ["the", "then", "than"])) 129 | # print(best_fits_with_offsetting(". I like the lecture about natural".split(), ["the", "then", "than"])) 130 | # print(best_fits(", is then the processing".split(), ["the", "then", "than"])) 131 | # print(best_fits_with_offsetting("agreement , is then the processing of".split(), ["the", "then", "than"])) 132 | # print(best_fits("at the then current stock".split(), ["the", "then", "than"])) 133 | # print(best_fits_with_offsetting(", at the then current stock exchange".split(), ["the", "then", "than"])) 134 | # print(best_fits("on more than one credit".split(), ["the", "then", "than"])) 135 | # print(best_fits_with_offsetting("delay on more than one credit card".split(), ["the", "then", "than"])) 136 | # print(best_fits("nationals other than those who".split(), ["the", "then", "than"])) 137 | # print(best_fits_with_offsetting("nationals nationals other than those who do".split(), ["the", "then", "than"])) 138 | 139 | print(best_fits_with_offsetting(", wenn sie schon alle Zusammenhänge etwas".split(), ["schon", "schön"])) 140 | print(best_fits_with_offsetting("das eigentlich auch schon , wir haben".split(), ["schon", "schön"])) 141 | print(best_fits_with_offsetting("zu selten , schön und wichtig um".split(), ["schon", "schön"])) 142 | print(best_fits_with_offsetting(", die nicht schön blühen , noch".split(), ["schon", "schön"])) 143 | -------------------------------------------------------------------------------- /src/main/python/embedding/common.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | 4 | from languagetool.languagetool import LanguageTool 5 | 6 | 7 | def read_data(filename: str) -> [str]: 8 | """Read tokens from tokenized file""" 9 | with open(filename) as f: 10 | data = f.read().split() 11 | return data 12 | 13 | 14 | def build_dataset(words, max_words_in_vocabulary: int=20000, pos_tagger=None): 15 | count = [['UNK', -1]] 16 | counter = collections.Counter(words) 17 | logging.info("unique words", len(counter)) 18 | most_common_counter = counter.most_common(max_words_in_vocabulary - 1) 19 | if pos_tagger is not None: 20 | vocabulary = {token for token, _ in most_common_counter} 21 | words_and_tags = [] 22 | logging.info("tagging words") 23 | for word in words: 24 | if word not in vocabulary: 25 | words_and_tags.append(str(pos_tagger(word))) 26 | else: 27 | words_and_tags.append(word) 28 | tagged_counter = collections.Counter(words_and_tags).most_common() 29 | count.extend(tagged_counter) 30 | else: 31 | count.extend(most_common_counter) 32 | words_and_tags = words 33 | dictionary = dict() 34 | logging.info("Buildings dictionary") 35 | for word, _ in count: 36 | dictionary[word] = len(dictionary) 37 | data = list() 38 | unk_count = 0 39 | for word in words_and_tags: 40 | if word in dictionary: 41 | index = dictionary[word] 42 | else: 43 | index = 0 44 | unk_count += 1 45 | data.append(index) 46 | count[0][1] = unk_count 47 | reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 48 | return data, count, dictionary, reverse_dictionary 49 | -------------------------------------------------------------------------------- /src/main/python/embedding/question-words.txt: -------------------------------------------------------------------------------- 1 | I you my your 2 | he she his her 3 | go 
goes say says 4 | go went say said 5 | Berlin Germany Paris France -------------------------------------------------------------------------------- /src/main/python/embedding/word2vec_kernels.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow/core/framework/op.h" 17 | #include "tensorflow/core/framework/op_kernel.h" 18 | #include "tensorflow/core/lib/core/stringpiece.h" 19 | #include "tensorflow/core/lib/gtl/map_util.h" 20 | #include "tensorflow/core/lib/random/distribution_sampler.h" 21 | #include "tensorflow/core/lib/random/philox_random.h" 22 | #include "tensorflow/core/lib/random/simple_philox.h" 23 | #include "tensorflow/core/lib/strings/str_util.h" 24 | #include "tensorflow/core/platform/thread_annotations.h" 25 | #include "tensorflow/core/util/guarded_philox_random.h" 26 | 27 | namespace tensorflow { 28 | 29 | // Number of examples to precalculate. 30 | const int kPrecalc = 3000; 31 | // Number of words to read into a sentence before processing. 32 | const int kSentenceSize = 1000; 33 | 34 | namespace { 35 | 36 | bool ScanWord(StringPiece* input, string* word) { 37 | str_util::RemoveLeadingWhitespace(input); 38 | StringPiece tmp; 39 | if (str_util::ConsumeNonWhitespace(input, &tmp)) { 40 | word->assign(tmp.data(), tmp.size()); 41 | return true; 42 | } else { 43 | return false; 44 | } 45 | } 46 | 47 | } // end namespace 48 | 49 | class SkipgramWord2vecOp : public OpKernel { 50 | public: 51 | explicit SkipgramWord2vecOp(OpKernelConstruction* ctx) 52 | : OpKernel(ctx), rng_(&philox_) { 53 | string filename; 54 | OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename)); 55 | OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_)); 56 | OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_)); 57 | OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_)); 58 | OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_)); 59 | OP_REQUIRES_OK(ctx, Init(ctx->env(), filename)); 60 | 61 | mutex_lock l(mu_); 62 | example_pos_ = corpus_size_; 63 | label_pos_ = corpus_size_; 64 | label_limit_ = corpus_size_; 65 | sentence_index_ = kSentenceSize; 66 | for (int i = 0; i < kPrecalc; ++i) { 67 | NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label); 68 | } 69 | } 70 | 71 | void Compute(OpKernelContext* ctx) override { 72 | Tensor words_per_epoch(DT_INT64, TensorShape({})); 73 | Tensor current_epoch(DT_INT32, TensorShape({})); 74 | Tensor total_words_processed(DT_INT64, TensorShape({})); 75 | Tensor examples(DT_INT32, TensorShape({batch_size_})); 76 | auto Texamples = examples.flat<int32>(); 77 | Tensor labels(DT_INT32, TensorShape({batch_size_})); 78 | auto Tlabels = labels.flat<int32>(); 79 | { 80 | mutex_lock l(mu_); 81 | for (int i = 0; i < batch_size_; ++i) { 82 | Texamples(i) =
precalc_examples_[precalc_index_].input; 83 | Tlabels(i) = precalc_examples_[precalc_index_].label; 84 | precalc_index_++; 85 | if (precalc_index_ >= kPrecalc) { 86 | precalc_index_ = 0; 87 | for (int j = 0; j < kPrecalc; ++j) { 88 | NextExample(&precalc_examples_[j].input, 89 | &precalc_examples_[j].label); 90 | } 91 | } 92 | } 93 | words_per_epoch.scalar<int64>()() = corpus_size_; 94 | current_epoch.scalar<int32>()() = current_epoch_; 95 | total_words_processed.scalar<int64>()() = total_words_processed_; 96 | } 97 | ctx->set_output(0, word_); 98 | ctx->set_output(1, freq_); 99 | ctx->set_output(2, words_per_epoch); 100 | ctx->set_output(3, current_epoch); 101 | ctx->set_output(4, total_words_processed); 102 | ctx->set_output(5, examples); 103 | ctx->set_output(6, labels); 104 | } 105 | 106 | private: 107 | struct Example { 108 | int32 input; 109 | int32 label; 110 | }; 111 | 112 | int32 batch_size_ = 0; 113 | int32 window_size_ = 5; 114 | float subsample_ = 1e-3; 115 | int min_count_ = 5; 116 | int32 vocab_size_ = 0; 117 | Tensor word_; 118 | Tensor freq_; 119 | int64 corpus_size_ = 0; 120 | std::vector<int32> corpus_; 121 | std::vector<Example> precalc_examples_; 122 | int precalc_index_ = 0; 123 | std::vector<int32> sentence_; 124 | int sentence_index_ = 0; 125 | 126 | mutex mu_; 127 | random::PhiloxRandom philox_ GUARDED_BY(mu_); 128 | random::SimplePhilox rng_ GUARDED_BY(mu_); 129 | int32 current_epoch_ GUARDED_BY(mu_) = -1; 130 | int64 total_words_processed_ GUARDED_BY(mu_) = 0; 131 | int32 example_pos_ GUARDED_BY(mu_); 132 | int32 label_pos_ GUARDED_BY(mu_); 133 | int32 label_limit_ GUARDED_BY(mu_); 134 | 135 | // {example_pos_, label_pos_} is the cursor for the next example. 136 | // example_pos_ wraps around at the end of corpus_. For each 137 | // example, we randomly generate [label_pos_, label_limit) for 138 | // labels. 139 | void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 140 | while (true) { 141 | if (label_pos_ >= label_limit_) { 142 | ++total_words_processed_; 143 | ++sentence_index_; 144 | if (sentence_index_ >= kSentenceSize) { 145 | sentence_index_ = 0; 146 | for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) { 147 | if (example_pos_ >= corpus_size_) { 148 | ++current_epoch_; 149 | example_pos_ = 0; 150 | } 151 | if (subsample_ > 0) { 152 | int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]); 153 | // See Eq.
5 in http://arxiv.org/abs/1310.4546 154 | float keep_prob = 155 | (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) * 156 | (subsample_ * corpus_size_) / word_freq; 157 | if (rng_.RandFloat() > keep_prob) { 158 | i--; 159 | continue; 160 | } 161 | } 162 | sentence_[i] = corpus_[example_pos_]; 163 | } 164 | } 165 | const int32 skip = 1 + rng_.Uniform(window_size_); 166 | label_pos_ = std::max(0, sentence_index_ - skip); 167 | label_limit_ = 168 | std::min(kSentenceSize, sentence_index_ + skip + 1); 169 | } 170 | if (sentence_index_ != label_pos_) { 171 | break; 172 | } 173 | ++label_pos_; 174 | } 175 | *example = sentence_[sentence_index_]; 176 | *label = sentence_[label_pos_++]; 177 | } 178 | 179 | Status Init(Env* env, const string& filename) { 180 | string data; 181 | TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data)); 182 | StringPiece input = data; 183 | string w; 184 | corpus_size_ = 0; 185 | std::unordered_map<string, int32> word_freq; 186 | while (ScanWord(&input, &w)) { 187 | ++(word_freq[w]); 188 | ++corpus_size_; 189 | } 190 | if (corpus_size_ < window_size_ * 10) { 191 | return errors::InvalidArgument("The text file ", filename, 192 | " contains too little data: ", 193 | corpus_size_, " words"); 194 | } 195 | typedef std::pair<string, int32> WordFreq; 196 | std::vector<WordFreq> ordered; 197 | for (const auto& p : word_freq) { 198 | if (p.second >= min_count_) ordered.push_back(p); 199 | } 200 | LOG(INFO) << "Data file: " << filename << " contains " << data.size() 201 | << " bytes, " << corpus_size_ << " words, " << word_freq.size() 202 | << " unique words, " << ordered.size() 203 | << " unique frequent words."; 204 | word_freq.clear(); 205 | std::sort(ordered.begin(), ordered.end(), 206 | [](const WordFreq& x, const WordFreq& y) { 207 | return x.second > y.second; 208 | }); 209 | vocab_size_ = static_cast<int32>(1 + ordered.size()); 210 | Tensor word(DT_STRING, TensorShape({vocab_size_})); 211 | Tensor freq(DT_INT32, TensorShape({vocab_size_})); 212 | word.flat<string>()(0) = "UNK"; 213 | static const int32 kUnkId = 0; 214 | std::unordered_map<string, int32> word_id; 215 | int64 total_counted = 0; 216 | for (std::size_t i = 0; i < ordered.size(); ++i) { 217 | const auto& w = ordered[i].first; 218 | auto id = i + 1; 219 | word.flat<string>()(id) = w; 220 | auto word_count = ordered[i].second; 221 | freq.flat<int32>()(id) = word_count; 222 | total_counted += word_count; 223 | word_id[w] = id; 224 | } 225 | freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted; 226 | word_ = word; 227 | freq_ = freq; 228 | corpus_.reserve(corpus_size_); 229 | input = data; 230 | while (ScanWord(&input, &w)) { 231 | corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId)); 232 | } 233 | precalc_examples_.resize(kPrecalc); 234 | sentence_.resize(kSentenceSize); 235 | return Status::OK(); 236 | } 237 | }; 238 | 239 | REGISTER_KERNEL_BUILDER(Name("SkipgramWord2vec").Device(DEVICE_CPU), SkipgramWord2vecOp); 240 | 241 | class NegTrainWord2vecOp : public OpKernel { 242 | public: 243 | explicit NegTrainWord2vecOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 244 | base_.Init(0, 0); 245 | 246 | OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_)); 247 | 248 | std::vector<int32> vocab_count; 249 | OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count)); 250 | 251 | std::vector<float> vocab_weights; 252 | vocab_weights.reserve(vocab_count.size()); 253 | for (const auto& f : vocab_count) { 254 | float r = std::pow(static_cast<float>(f), 0.75f); 255 | vocab_weights.push_back(r); 256 | } 257 | sampler_ = new random::DistributionSampler(vocab_weights); 258 | } 259 | 260 |
~NegTrainWord2vecOp() { delete sampler_; } 261 | 262 | void Compute(OpKernelContext* ctx) override { 263 | Tensor w_in = ctx->mutable_input(0, false); 264 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()), 265 | errors::InvalidArgument("Must be a matrix")); 266 | Tensor w_out = ctx->mutable_input(1, false); 267 | OP_REQUIRES(ctx, w_in.shape() == w_out.shape(), 268 | errors::InvalidArgument("w_in.shape == w_out.shape")); 269 | const Tensor& examples = ctx->input(2); 270 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()), 271 | errors::InvalidArgument("Must be a vector")); 272 | const Tensor& labels = ctx->input(3); 273 | OP_REQUIRES(ctx, examples.shape() == labels.shape(), 274 | errors::InvalidArgument("examples.shape == labels.shape")); 275 | const Tensor& learning_rate = ctx->input(4); 276 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()), 277 | errors::InvalidArgument("Must be a scalar")); 278 | 279 | auto Tw_in = w_in.matrix<float>(); 280 | auto Tw_out = w_out.matrix<float>(); 281 | auto Texamples = examples.flat<int32>(); 282 | auto Tlabels = labels.flat<int32>(); 283 | auto lr = learning_rate.scalar<float>()(); 284 | const int64 vocab_size = w_in.dim_size(0); 285 | const int64 dims = w_in.dim_size(1); 286 | const int64 batch_size = examples.dim_size(0); 287 | OP_REQUIRES(ctx, vocab_size == sampler_->num(), 288 | errors::InvalidArgument("vocab_size mismatches: ", vocab_size, 289 | " vs. ", sampler_->num())); 290 | 291 | // Gradient accumulator for v_in. 292 | Tensor buf(DT_FLOAT, TensorShape({dims})); 293 | auto Tbuf = buf.flat<float>(); 294 | 295 | // Scalar buffer to hold sigmoid(+/- dot). 296 | Tensor g_buf(DT_FLOAT, TensorShape({})); 297 | auto g = g_buf.scalar<float>(); 298 | 299 | // The following loop needs 2 random 32-bit values per negative 300 | // sample. We reserve 8 values per sample just in case the 301 | // underlying implementation changes. 302 | auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8); 303 | random::SimplePhilox srnd(&rnd); 304 | 305 | for (int64 i = 0; i < batch_size; ++i) { 306 | const int32 example = Texamples(i); 307 | DCHECK(0 <= example && example < vocab_size) << example; 308 | const int32 label = Tlabels(i); 309 | DCHECK(0 <= label && label < vocab_size) << label; 310 | auto v_in = Tw_in.chip<0>(example); 311 | 312 | // Positive: example predicts label. 313 | // forward: x = v_in' * v_out 314 | // l = log(sigmoid(x)) 315 | // backward: dl/dx = g = sigmoid(-x) 316 | // dl/d(v_in) = g * v_out' 317 | // dl/d(v_out) = v_in' * g 318 | { 319 | auto v_out = Tw_out.chip<0>(label); 320 | auto dot = (v_in * v_out).sum(); 321 | g = (dot.exp() + 1.f).inverse(); 322 | Tbuf = v_out * (g() * lr); 323 | v_out += v_in * (g() * lr); 324 | } 325 | 326 | // Negative samples: 327 | // forward: x = v_in' * v_sample 328 | // l = log(sigmoid(-x)) 329 | // backward: dl/dx = g = -sigmoid(x) 330 | // dl/d(v_in) = g * v_out' 331 | // dl/d(v_out) = v_in' * g 332 | for (int j = 0; j < num_samples_; ++j) { 333 | const int sample = sampler_->Sample(&srnd); 334 | if (sample == label) continue; // Skip. 335 | auto v_sample = Tw_out.chip<0>(sample); 336 | auto dot = (v_in * v_sample).sum(); 337 | g = -((-dot).exp() + 1.f).inverse(); 338 | Tbuf += v_sample * (g() * lr); 339 | v_sample += v_in * (g() * lr); 340 | } 341 | 342 | // Applies the gradient on v_in.
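// Note: Tbuf was set above to the lr-scaled gradient of the positive pair and
// has since accumulated the lr-scaled gradient of each drawn negative sample
// (samples equal to the true label are skipped), so v_in is read unchanged
// throughout this iteration and is updated exactly once below.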
343 | v_in += Tbuf; 344 | } 345 | } 346 | 347 | private: 348 | int32 num_samples_ = 0; 349 | random::DistributionSampler* sampler_ = nullptr; 350 | GuardedPhiloxRandom base_; 351 | }; 352 | 353 | REGISTER_KERNEL_BUILDER(Name("NegTrainWord2vec").Device(DEVICE_CPU), NegTrainWord2vecOp); 354 | 355 | } // end namespace tensorflow 356 | -------------------------------------------------------------------------------- /src/main/python/embedding/word2vec_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow/core/framework/op.h" 17 | 18 | namespace tensorflow { 19 | 20 | REGISTER_OP("SkipgramWord2vec") 21 | .Output("vocab_word: string") 22 | .Output("vocab_freq: int32") 23 | .Output("words_per_epoch: int64") 24 | .Output("current_epoch: int32") 25 | .Output("total_words_processed: int64") 26 | .Output("examples: int32") 27 | .Output("labels: int32") 28 | .SetIsStateful() 29 | .Attr("filename: string") 30 | .Attr("batch_size: int") 31 | .Attr("window_size: int = 5") 32 | .Attr("min_count: int = 5") 33 | .Attr("subsample: float = 1e-3") 34 | .Doc(R"doc( 35 | Parses a text file and creates a batch of examples. 36 | 37 | vocab_word: A vector of words in the corpus. 38 | vocab_freq: Frequencies of words. Sorted in the non-ascending order. 39 | words_per_epoch: Number of words per epoch in the data file. 40 | current_epoch: The current epoch number. 41 | total_words_processed: The total number of words processed so far. 42 | examples: A vector of word ids. 43 | labels: A vector of word ids. 44 | filename: The corpus's text file name. 45 | batch_size: The size of produced batch. 46 | window_size: The number of words to predict to the left and right of the target. 47 | min_count: The minimum number of word occurrences for it to be included in the 48 | vocabulary. 49 | subsample: Threshold for word occurrence. Words that appear with higher 50 | frequency will be randomly down-sampled. Set to 0 to disable. 51 | )doc"); 52 | 53 | REGISTER_OP("NegTrainWord2vec") 54 | .Input("w_in: Ref(float)") 55 | .Input("w_out: Ref(float)") 56 | .Input("examples: int32") 57 | .Input("labels: int32") 58 | .Input("lr: float") 59 | .SetIsStateful() 60 | .Attr("vocab_count: list(int)") 61 | .Attr("num_negative_samples: int") 62 | .Doc(R"doc( 63 | Training via negative sampling. 64 | 65 | w_in: input word embedding. 66 | w_out: output word embedding. 67 | examples: A vector of word ids. 68 | labels: A vector of word ids. 69 | vocab_count: Count of words in the vocabulary. 70 | num_negative_samples: Number of negative samples per example. 
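lr: The learning rate used to scale each positive and negative-sample update.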
71 | )doc"); 72 | 73 | } // end namespace tensorflow 74 | -------------------------------------------------------------------------------- /src/main/python/embedding/word2vec_optimized_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for word2vec_optimized module.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | import tensorflow as tf 25 | 26 | import word2vec_optimized 27 | 28 | flags = tf.app.flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class Word2VecTest(tf.test.TestCase): 34 | 35 | def setUp(self): 36 | FLAGS.train_data = os.path.join(self.get_temp_dir() + "test-text.txt") 37 | FLAGS.eval_data = os.path.join(self.get_temp_dir() + "eval-text.txt") 38 | FLAGS.save_path = self.get_temp_dir() 39 | with open(FLAGS.train_data, "w") as f: 40 | f.write( 41 | """alice was beginning to get very tired of sitting by her sister on 42 | the bank, and of having nothing to do: once or twice she had peeped 43 | into the book her sister was reading, but it had no pictures or 44 | conversations in it, 'and what is the use of a book,' thought alice 45 | 'without pictures or conversations?' So she was considering in her own 46 | mind (as well as she could, for the hot day made her feel very sleepy 47 | and stupid), whether the pleasure of making a daisy-chain would be 48 | worth the trouble of getting up and picking the daisies, when suddenly 49 | a White rabbit with pink eyes ran close by her.\n""") 50 | with open(FLAGS.eval_data, "w") as f: 51 | f.write("alice she rabbit once\n") 52 | 53 | def testWord2VecOptimized(self): 54 | FLAGS.batch_size = 5 55 | FLAGS.num_neg_samples = 10 56 | FLAGS.epochs_to_train = 1 57 | FLAGS.min_count = 0 58 | word2vec_optimized.main([]) 59 | 60 | 61 | if __name__ == "__main__": 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /src/main/python/embedding/word2vec_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for word2vec module.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | import tensorflow as tf 25 | 26 | import word2vec 27 | 28 | flags = tf.app.flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class Word2VecTest(tf.test.TestCase): 34 | 35 | def setUp(self): 36 | FLAGS.train_data = os.path.join(self.get_temp_dir(), "test-text.txt") 37 | FLAGS.eval_data = os.path.join(self.get_temp_dir(), "eval-text.txt") 38 | FLAGS.save_path = self.get_temp_dir() 39 | with open(FLAGS.train_data, "w") as f: 40 | f.write( 41 | """alice was beginning to get very tired of sitting by her sister on 42 | the bank, and of having nothing to do: once or twice she had peeped 43 | into the book her sister was reading, but it had no pictures or 44 | conversations in it, 'and what is the use of a book,' thought alice 45 | 'without pictures or conversations?' So she was considering in her own 46 | mind (as well as she could, for the hot day made her feel very sleepy 47 | and stupid), whether the pleasure of making a daisy-chain would be 48 | worth the trouble of getting up and picking the daisies, when suddenly 49 | a White rabbit with pink eyes ran close by her.\n""") 50 | with open(FLAGS.eval_data, "w") as f: 51 | f.write("alice she rabbit once\n") 52 | 53 | def testWord2Vec(self): 54 | FLAGS.batch_size = 5 55 | FLAGS.num_neg_samples = 10 56 | FLAGS.epochs_to_train = 1 57 | FLAGS.min_count = 0 58 | word2vec.main([]) 59 | 60 | 61 | if __name__ == "__main__": 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /src/main/python/embedding/word_to_char_embedding.py: -------------------------------------------------------------------------------- 1 | from ast import literal_eval 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | 7 | import sys 8 | 9 | import re 10 | from sklearn.manifold import TSNE 11 | 12 | 13 | def load_embedding(dictionary_path: str, embedding_path: str) -> (dict, np.ndarray): 14 | with open(dictionary_path) as dictionary_file: 15 | dictionary = literal_eval(dictionary_file.read()) 16 | embedding = np.loadtxt(embedding_path) 17 | return dictionary, embedding 18 | 19 | 20 | def to_char_embedding(dictionary: dict, embedding: np.ndarray) -> (dict, np.ndarray): 21 | """ 22 | Calculate a character embedding by averaging an existing word embedding, as shown in 23 | http://minimaxir.com/2017/04/char-embeddings/ 24 | """ 25 | vectors = {"UNK": (np.zeros(embedding[0].shape), 1)} 26 | for token, idx in dictionary.items(): 27 | if token == "UNK": 28 | continue 29 | for char in token: 30 | if char in vectors: 31 | vectors[char] = (vectors[char][0] + embedding[idx], 32 | vectors[char][1] + 1) 33 | else: 34 | vectors[char] = (embedding[idx], 1) 35 | 36 | embedding_items = {c: v[0]/v[1] for c, v in vectors.items()}.items() 37 | char_dictionary = {e[0]: idx for idx, e in enumerate(embedding_items)} 38 | char_embedding = np.array(list(map(lambda ei: ei[1], embedding_items))) 39 | return char_dictionary, char_embedding 40 | 41 | 42 | def save_char_embedding(dictionary: dict, embedding: np.ndarray, path: str): 43 | with open(os.path.join(path, "char_dictionary.txt"), "w") as embedding_file: 44 | embedding_file.write(str(dictionary)) 45 | np.savetxt(os.path.join(path, "char_embeddings.txt"), embedding) 46 | 47 | 48 | def 
get_color(char: str) -> str: 49 | if re.match("[A-ZÄÖÜ]", char) is not None: 50 | return "red" 51 | if re.match("[a-zäöüß]", char) is not None: 52 | return "blue" 53 | if re.match("[0-9]", char) is not None: 54 | return "green" 55 | return "yellow" 56 | 57 | 58 | def plot(dictionary: dict, embedding: np.ndarray, path: str="/tmp/char_tsne.png"): 59 | print("Plotting ...") 60 | tsne = TSNE(perplexity=7, n_components=2, init='pca', n_iter=5000, method='exact') 61 | low_dim_embedding = tsne.fit_transform(embedding) 62 | reverse_dict = {v: k for k, v in dictionary.items()} 63 | labels = [reverse_dict[i] for i in range(len(dictionary))] 64 | colors = list(map(get_color, labels)) 65 | plt.figure(figsize=(32, 32)) 66 | for i, label in enumerate(labels): 67 | x, y = low_dim_embedding[i, :] 68 | plt.scatter(x, y, color=colors[i]) 69 | plt.annotate(label, 70 | xy=(x, y), 71 | xytext=(5, 2), 72 | textcoords='offset points', 73 | ha='right', 74 | va='bottom') 75 | plt.savefig(path) 76 | print("Plot saved to", path) 77 | 78 | 79 | if __name__ == '__main__': 80 | if len(sys.argv) != 4: 81 | print("dictionary_path embedding_path output_dir") 82 | exit(-1) 83 | 84 | dictionary, embedding = load_embedding(sys.argv[1], sys.argv[2]) 85 | char_dictionary, char_embedding = to_char_embedding(dictionary, embedding) 86 | plot(char_dictionary, char_embedding) 87 | save_char_embedding(char_dictionary, char_embedding, sys.argv[3]) 88 | -------------------------------------------------------------------------------- /src/main/python/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from ast import literal_eval 3 | from typing import Callable, List 4 | 5 | import numpy as np 6 | from nltk import edit_distance 7 | 8 | from EvalResult import EvalResult 9 | from LayeredScorer import LayeredScorer 10 | from parallel import parmap 11 | from repl import has_error 12 | 13 | 14 | def evaluate_ngrams(ngrams: [[str]], subjects: [str], 15 | check_function: Callable[[List[str]], bool]) -> EvalResult: 16 | eval_result = EvalResult() 17 | substitutions = create_substitution_dict(subjects) 18 | for ngram in ngrams: 19 | middle = int(len(ngram) / 2) 20 | middle_word = ngram[middle] 21 | for substitution in substitutions[middle_word]: 22 | eval_ngram = ngram[:] 23 | eval_ngram[middle] = substitution 24 | error_detected = check_function(eval_ngram) 25 | if not error_detected and substitution == middle_word: 26 | eval_result.add_tn() 27 | elif error_detected and substitution == middle_word: 28 | eval_result.add_fp() 29 | print("false positive:", eval_ngram) 30 | elif not error_detected and substitution != middle_word: 31 | eval_result.add_fn() 32 | elif error_detected and substitution != middle_word: 33 | eval_result.add_tp() 34 | return eval_result 35 | 36 | 37 | def get_relevant_ngrams(sentences: str, subjects: [str], n: int=5) -> [[str]]: 38 | ngrams = [] 39 | half_context_length = int(n / 2) 40 | tokens = sentences.split(" ") 41 | for i in range(half_context_length, len(tokens) - half_context_length + 1): 42 | if tokens[i] in subjects: 43 | ngrams.append(tokens[i-half_context_length:i+half_context_length+1]) 44 | return ngrams 45 | 46 | 47 | def similar_words(word: str, words: List[str]): 48 | if len(word) < 4: 49 | max_distance = 1 50 | else: 51 | max_distance = 1 52 | return list(filter(lambda w: edit_distance(word, w, substitution_cost=1, transpositions=True) <= max_distance, 53 | words)) 54 | 55 | 56 | def create_substitution_dict(subjects: List[str]) -> dict: 57 | return 
{subject: similar_words(subject, subjects) for subject in subjects} 58 | 59 | 60 | def main(): 61 | if len(sys.argv) != 4: 62 | raise ValueError("Expected dict, finalembedding, W, b") 63 | 64 | dictionary_path = sys.argv[1] 65 | embedding_path = sys.argv[2] 66 | weights_path = sys.argv[3] 67 | 68 | with open(dictionary_path) as dictionary_file: 69 | dictionary = literal_eval(dictionary_file.read()) 70 | embedding = np.loadtxt(embedding_path) 71 | 72 | subjects = ["als", "also", "da", "das", "dass", "de", "den", "denn", "die", "durch", "zur", "ihm", "im", "um", "nach", "noch", "war", "was"] 73 | # subjects = ["and", "end", "as", "at", "is", "do", "for", "four", "form", "from", "he", "if", "is", "its", "it", "no", "now", "on", "one", "same", "some", "than", "that", "then", "their", "there", "them", "the", "they", "to", "was", "way", "were", "where"] 74 | print(subjects) 75 | 76 | # eval_subjects = subjects 77 | eval_subjects = ["das", "dass"] 78 | 79 | with open("/tmp/tokens-de") as file: 80 | sentences = file.read() 81 | 82 | scorer = LayeredScorer(weights_path) 83 | ngrams = get_relevant_ngrams(sentences, eval_subjects, n=scorer.context_size(embedding.shape[1])+1) 84 | 85 | eval_results = parmap(lambda t: evaluate_ngrams(ngrams, eval_subjects, lambda ngram: has_error(dictionary, embedding, scorer, ngram, subjects, error_threshold=t, suggestion_threshold=t)), 86 | np.arange(0, 1, .025)) 87 | # [.3]) 88 | print(eval_results) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /src/main/python/languagetool/languagetool.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Iterable 2 | 3 | from py4j.java_gateway import JavaGateway, JavaObject, GatewayParameters 4 | from py4j.java_collections import JavaList, ListConverter 5 | from py4j.protocol import Py4JNetworkError 6 | 7 | 8 | class LanguageTool: 9 | 10 | tagsets = { 11 | "en": [None, '$', "''", ',', '.', ':', 'CC', 'CD', 'DT', 'EX', 'IN', 'JJ', 'JJR', 'JJS', 'MD', 'NN', 'NN:U', 'NN:UN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``'] 12 | } 13 | 14 | def __init__(self, languageCode): 15 | """ 16 | Parameters 17 | ---------- 18 | languageCode: code like "de-DE" or "en" 19 | """ 20 | try: 21 | self.gateway = JavaGateway(gateway_parameters=GatewayParameters(auto_convert=True)) 22 | self.languageCode = languageCode 23 | self.language = self.gateway.jvm.org.languagetool.Languages.getLanguageForShortCode(languageCode) 24 | self.tagger = self.language.getTagger() 25 | except Py4JNetworkError as e: 26 | raise RuntimeError("Could not connect to JVM. Is ./gradlew pythonGateway running?") from e 27 | 28 | def tokenized_sentences(self, sentences: str) -> JavaList: 29 | """ 30 | Split a string into sentences, and each sentences into tokens using LanguageTool. 31 | """ 32 | return self.gateway.jvm.de.hhu.mabre.languagetool.FileTokenizer.tokenizedSentences(self.languageCode, sentences) 33 | 34 | def tokenize(self, sentences: str) -> [str]: 35 | """ 36 | Tokenize one or several sentences using LanguageTool. 
37 | """ 38 | return list(self.gateway.jvm.de.hhu.mabre.languagetool.FileTokenizer.tokenize(self.languageCode, sentences)) 39 | 40 | @staticmethod 41 | def _get_tags_of_tagged_tokens(taggedToken: JavaObject): 42 | return list(map(lambda reading: reading.getPOSTag(), taggedToken.getReadings())) 43 | 44 | def tag(self, tokenizedSentences: Union[Iterable[str], JavaList]) -> [(str, [str])]: 45 | """ 46 | Tag a tokenized text using the tagger of a language. All valid tags for a each token are returned. 47 | """ 48 | if type(tokenizedSentences) is not JavaList: 49 | tokens = ListConverter().convert(tokenizedSentences, self.gateway._gateway_client) 50 | else: 51 | tokens = tokenizedSentences 52 | return list(zip(tokens, map(LanguageTool._get_tags_of_tagged_tokens, self.tagger.tag(tokens)))) 53 | 54 | def tag_token(self, token: str) -> [str]: 55 | """ 56 | Tag a single token using the tagger of a language. All valid tags for the token are returned. 57 | """ 58 | tokens = ListConverter().convert([token], self.gateway._gateway_client) 59 | tags = list(set(LanguageTool._get_tags_of_tagged_tokens(self.tagger.tag(tokens)[0]))) 60 | tags.sort() # hm, "Zusammenhänge" gets tags more than once 61 | return tags 62 | 63 | 64 | def tagset(self): 65 | return LanguageTool.tagsets[self.language.getShortCode()] 66 | -------------------------------------------------------------------------------- /src/main/python/nn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def weight_variable(shape): 4 | initial = tf.truncated_normal(shape,stddev = 0.1) 5 | return tf.Variable(initial) 6 | 7 | 8 | def bias_variable(shape): 9 | initial = tf.constant(0.1,shape=shape) 10 | return tf.Variable(initial) 11 | 12 | 13 | def conv2d(x, W, padding: str="SAME"): 14 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=padding) 15 | 16 | 17 | def max_pool(x, l): 18 | return tf.nn.max_pool(x, ksize=[1, l, 1, 1], strides=[1, 1, 1, 1], padding='VALID') 19 | 20 | 21 | def write_4dmat(path: str, mat): 22 | with open(path, "w") as f: 23 | for i in range(mat.size): 24 | f.write(str(mat.flatten(order='C')[i])) 25 | if i != mat.size: 26 | f.write(", ") 27 | if (i+1) % mat.shape[-1] == 0: 28 | f.write("&\n") 29 | -------------------------------------------------------------------------------- /src/main/python/nn_word_sequence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import codecs 3 | import json 4 | import math 5 | import sys 6 | from ast import literal_eval 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from sklearn.utils import shuffle 11 | 12 | import nn 13 | from repl import get_probabilities 14 | 15 | 16 | class NeuralNetwork: 17 | def __init__(self, dictionary_path: str, embedding_path: str, training_data_file: str, test_data_file: str, 18 | batch_size: int=1000, epochs: int=1000, use_after: bool=True, keep_prob: float=0.7): 19 | print(locals()) 20 | 21 | self.hidden_layer_size = 8 22 | self.use_after = use_after 23 | self.num_conv_filters = 32 24 | self.max_sequence_length = 30 25 | self.batch_size = batch_size 26 | self.epochs = epochs 27 | self.keep_prob = keep_prob 28 | 29 | with open(dictionary_path) as dictionary_file: 30 | self.dictionary = literal_eval(dictionary_file.read()) 31 | self.embedding = np.loadtxt(embedding_path) 32 | print("embedding shape", np.shape(self.embedding)) 33 | 34 | self._embedding_size = np.shape(self.embedding)[1] 35 | 36 | self._db = 
self.get_db(training_data_file) 37 | self._TRAINING_SAMPLES = len(self._db["groundTruths"]) 38 | self._num_outputs = len(self._db["groundTruths"][0]) 39 | self._current_batch_number = 0 40 | 41 | self._db_validate = self.get_db(test_data_file) 42 | 43 | self.setup_net() 44 | 45 | print("determined parameters: embedding_size=%d" % self._embedding_size) 46 | 47 | def setup_net(self): 48 | input_length = self.max_sequence_length * (self.use_after + 1) 49 | 50 | context_output_size = 4 * self._embedding_size 51 | 52 | with tf.name_scope('input'): 53 | self.x = tf.placeholder(tf.float32, [None, input_length * self._embedding_size]) 54 | self.x_context = tf.placeholder(tf.float32, [None, context_output_size]) 55 | 56 | with tf.name_scope('ground-truth'): 57 | self.y_ = tf.placeholder(tf.float32, shape=[None, self._num_outputs]) 58 | 59 | with tf.name_scope('conv_layer'): 60 | self.dropout = tf.placeholder(tf.float32) 61 | x_image = tf.reshape(tf.nn.dropout(self.x, self.dropout), [-1, input_length, self._embedding_size, 1]) 62 | filter_size = 5 63 | self.W_conv1 = nn.weight_variable([filter_size, self._embedding_size, 1, self.num_conv_filters]) 64 | self.b_conv1 = nn.bias_variable([self.num_conv_filters]) 65 | h_conv1 = tf.nn.relu(nn.conv2d(x_image, self.W_conv1, padding="VALID") + self.b_conv1) 66 | h_pool1 = nn.max_pool(h_conv1, input_length - filter_size + 1) 67 | h_pool1_flat = tf.reshape(h_pool1, [-1, self.num_conv_filters]) 68 | 69 | with tf.name_scope('readout_layer'): 70 | self.W_fc2 = nn.weight_variable([self.num_conv_filters + context_output_size, self._num_outputs]) 71 | self.b_fc2 = nn.bias_variable([self._num_outputs]) 72 | fc2_input = tf.concat([h_pool1_flat, tf.nn.dropout(self.x_context, self.dropout)], axis=1) 73 | self.y = tf.matmul(fc2_input, self.W_fc2) + self.b_fc2 74 | 75 | with tf.name_scope('train'): 76 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.y, labels=self.y_)) 77 | tf.summary.scalar('loss_function', cross_entropy) 78 | self.train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 79 | 80 | with tf.name_scope('accuracy'): 81 | self.correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1)) 82 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) 83 | tf.summary.scalar('accuracy', self.accuracy) 84 | 85 | self.sess = tf.InteractiveSession() 86 | self.sess.run(tf.global_variables_initializer()) 87 | 88 | def get_word_representation(self, word): 89 | if word in self.dictionary: 90 | return self.embedding[self.dictionary[word]] 91 | else: 92 | return self.embedding[self.dictionary["UNK"]] 93 | 94 | @staticmethod 95 | def take(sentence: [str], max_length: int): 96 | """keep first <=max_sequence_length words, fill rest with UNK""" 97 | return (sentence + ["UNK"] * max_length)[:max_length] 98 | 99 | @staticmethod 100 | def takeLast(sentence: [str], max_length: int): 101 | """keep last <=max_sequence_length words, fill rest with UNK""" 102 | return (["UNK"] * max_length + sentence)[-max_length:] 103 | 104 | def embed(self, words: [str]) -> np.ndarray: 105 | return np.array(list(map(lambda w: self.get_word_representation(w), words))) 106 | 107 | def get_db(self, path): 108 | db = dict() 109 | raw_db = json.load(open(path)) 110 | db["tokensBefore"] = np.asarray(list(map(lambda ws: self.embed(self.takeLast(ws, self.max_sequence_length)).flatten(), raw_db["tokensBefore"]))) 111 | db["tokensAfter"] = np.asarray(list(map(lambda ws: self.embed(self.take(ws, self.max_sequence_length)).flatten(), 
raw_db["tokensAfter"]))) 112 | db["context"] = np.asarray(list(map(lambda b, a: self.embed(self.takeLast(b, 2) + self.take(a, 2)).flatten(), raw_db["tokensBefore"], raw_db["tokensAfter"]))) 113 | db["groundTruths"] = raw_db["groundTruths"] 114 | db["context_str"] = list(map(lambda b, a: " ".join(self.takeLast(b, self.max_sequence_length)) + " … " + " ".join(self.take(a, self.max_sequence_length)), raw_db["tokensBefore"], raw_db["tokensAfter"])) 115 | db["tokensBefore"], db["tokensAfter"], db["groundTruths"], db["context_str"] = \ 116 | shuffle(db["tokensBefore"], db["tokensAfter"], db["groundTruths"], db["context_str"]) 117 | print("%s loaded, containing %d entries, class distribution: %s" 118 | % (path, len(db["groundTruths"]), str(np.sum(np.asarray(db["groundTruths"]), axis=0)))) 119 | return db 120 | 121 | def get_batch(self): 122 | if self._current_batch_number * self.batch_size > self._TRAINING_SAMPLES: 123 | self._current_batch_number = 0 124 | start_index = self._current_batch_number * self.batch_size 125 | end_index = (self._current_batch_number + 1) * self.batch_size 126 | batch = dict() 127 | tokens_before = self._db["tokensBefore"][start_index:end_index] 128 | tokens_after = self._db["tokensAfter"][start_index:end_index] 129 | context = self._db["context"][start_index:end_index] 130 | self.assign_sentences_to_batch(batch, tokens_before, tokens_after, context) 131 | batch[self.y_] = self._db["groundTruths"][start_index:end_index] 132 | self._current_batch_number = self._current_batch_number + 1 133 | batch[self.dropout] = self.keep_prob 134 | # print("d" + str(len(batch[self.word1]))) 135 | return batch 136 | 137 | def get_all_training_data(self): 138 | batch = dict() 139 | tokens_before = self._db["tokensBefore"][:] 140 | tokens_after = self._db["tokensAfter"][:] 141 | context = self._db["context"][:] 142 | self.assign_sentences_to_batch(batch, tokens_before, tokens_after, context) 143 | batch[self.y_] = self._db["groundTruths"][:] 144 | batch[self.dropout] = 1 145 | # print("d" + str(len(batch[self.word1]))) 146 | return batch 147 | 148 | def assign_sentences_to_batch(self, batch, tokens_before, tokens_after, context): 149 | if self.use_after: 150 | batch[self.x] = np.concatenate([tokens_before, tokens_after], axis=1) 151 | else: 152 | batch[self.x] = tokens_before 153 | batch[self.x_context] = np.concatenate([context], axis=1) 154 | 155 | def train(self): 156 | steps = math.ceil(self._TRAINING_SAMPLES / self.batch_size) 157 | print("Steps: %d, %d steps in %d epochs" % (steps * self.epochs, steps, self.epochs)) 158 | for e in range(self.epochs): 159 | for i in range(steps): 160 | fd = self.get_batch() 161 | _ = self.sess.run([self.train_step], fd) # train with next batch 162 | if e % 10 == 0: 163 | self._print_accuracy(e) 164 | if e % 1000 == 0 and e > 0: 165 | self.validate() 166 | self.validate_error_detection() 167 | self._print_accuracy(self.epochs) 168 | 169 | def _print_accuracy(self, epoch): 170 | train_acc = self.sess.run([self.accuracy], self.get_all_training_data()) 171 | print("epoch %d, training accuracy %f" % (epoch, train_acc[0])) 172 | sys.stdout.flush() 173 | 174 | def save_weights(self, output_path): 175 | nn.write_4dmat(output_path + "/W_conv1.txt", tf.transpose(self.W_conv1, (3, 0, 1, 2)).eval()) 176 | nn.write_4dmat(output_path + "/b_conv1.txt", self.b_conv1.eval()) 177 | np.savetxt(output_path + "/b_fc2.txt", self.b_fc2.eval()) 178 | np.savetxt(output_path + "/W_fc2.txt", self.W_fc2.eval()) 179 | 180 | def get_suggestion(self, tokens_before, tokens_after, 
context, threshold=.5): 181 | scores = self.get_score(tokens_before, tokens_after, context) 182 | if np.max(scores) > threshold and np.min(scores) < -threshold: 183 | return np.argmax(scores) 184 | else: 185 | return -1 186 | 187 | def get_score(self, tokens_before, tokens_after, context): 188 | fd = {self.y_: [list(np.zeros(self._num_outputs))], 189 | self.x: [np.concatenate([tokens_before, tokens_after])] if self.use_after else [tokens_before], 190 | self.x_context: [context], 191 | self.dropout: 1} 192 | scores = self.y.eval(fd)[0] 193 | return scores 194 | 195 | def validate(self, verbose=False, threshold=.5): 196 | print("--- Validation of word prediction, threshold", threshold) 197 | 198 | correct = list(np.zeros(self._num_outputs)) 199 | incorrect = list(np.zeros(self._num_outputs)) 200 | unclassified = list(np.zeros(self._num_outputs)) 201 | tp = 0 202 | fp = 0 203 | tn = 0 204 | fn = 0 205 | 206 | for i in range(len(self._db_validate["groundTruths"])): 207 | suggestion = self.get_suggestion(self._db_validate["tokensBefore"][i], self._db_validate["tokensAfter"][i], self._db_validate["context"][i], threshold=threshold) 208 | ground_truth = np.argmax(self._db_validate["groundTruths"][i]) 209 | if suggestion == -1: 210 | unclassified[ground_truth] = unclassified[ground_truth] + 1 211 | if verbose: 212 | print("no decision:", self._db_validate["context_str"][i]) 213 | tn = tn + 1 214 | fn = fn + 1 215 | elif suggestion == ground_truth: 216 | correct[ground_truth] = correct[ground_truth] + 1 217 | if verbose: 218 | print("correct suggestion:", self._db_validate["context_str"][i]) 219 | tp = tp + 1 220 | tn = tn + 1 221 | else: 222 | incorrect[ground_truth] = incorrect[ground_truth] + 1 223 | if verbose: 224 | print("possible wrong suggestion:", self._db_validate["context_str"][i]) 225 | fp = fp + 1 226 | fn = fn + 1 227 | 228 | accuracy = list(map(lambda c, i: c/(c+i) if (c+i) > 0 else np.nan, correct, incorrect)) 229 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 230 | 231 | print("correct:", correct) 232 | print("incorrect:", incorrect) 233 | print("accuracy:", accuracy) 234 | print("unclassified:", unclassified) 235 | print("total accuracy:", total_accuracy) 236 | 237 | print("tp", tp) 238 | print("tn", tn) 239 | print("fp", fp) 240 | print("fn", fn) 241 | print("precision:", float(tp)/(tp+fp) if (tp+fp) > 0 else 1) 242 | print("recall:", float(tp)/(tp+fn) if (tp+fn) > 0 else 0) 243 | 244 | def validate_error_detection(self, suggestion_threshold: float=0.5, error_threshold: float=0.2, verbose=False): 245 | print("--- Error Detection Validation: suggestion_threshold %4.2f, error_threshold %4.2f" 246 | % (suggestion_threshold, error_threshold)) 247 | 248 | correct = list(np.zeros(self._num_outputs)) 249 | incorrect = list(np.zeros(self._num_outputs)) 250 | unclassified = list(np.zeros(self._num_outputs)) 251 | tp = 0 252 | fp = 0 253 | fn = 0 254 | 255 | for i in range(len(self._db_validate["groundTruths"])): 256 | scores = self.get_score(self._db_validate["tokensBefore"][i], self._db_validate["tokensAfter"][i], self._db_validate["context"][i]) 257 | probabilities = get_probabilities(scores) 258 | best_match = self.get_suggestion(self._db_validate["tokensBefore"][i], self._db_validate["tokensAfter"][i], self._db_validate["context"][i]) 259 | best_match_score = scores[best_match] 260 | ground_truth = np.argmax(self._db_validate["groundTruths"][i]) 261 | ground_truth_probability = probabilities[ground_truth] 262 | 263 | if best_match_score > 
suggestion_threshold and error_threshold > ground_truth_probability: 264 | # suggest alternative 265 | incorrect[ground_truth] = incorrect[ground_truth] + 1 266 | if verbose: 267 | print("false alarm:", " ".join(self._db_validate["context_str"][i])) 268 | fp = fp + 1 269 | fn = fn + 1 270 | elif ground_truth_probability > suggestion_threshold: 271 | # ground truth will be suggested 272 | correct[ground_truth] = correct[ground_truth] + 1 273 | if verbose: 274 | print("correct suggestion included:", " ".join(self._db_validate["context_str"][i])) 275 | tp = tp + 1 276 | else: 277 | # nothing happens 278 | unclassified[ground_truth] = unclassified[ground_truth] + 1 279 | if verbose: 280 | print("no decision:", " ".join(self._db_validate["context_str"][i])) 281 | fn = fn + 1 282 | 283 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 284 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 285 | 286 | print("correct:", correct) 287 | print("incorrect:", incorrect) 288 | print("accuracy:", accuracy) 289 | print("unclassified:", unclassified) 290 | print("total accuracy:", total_accuracy) 291 | 292 | print("tp", tp) 293 | print("fp", fp) 294 | print("fn", fn) 295 | print("precision:", float(tp)/(tp+fp) if (tp+fp) > 0 else 1) 296 | print("recall:", float(tp)/(tp+fn) if (tp+fn) > 0 else 0) 297 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 298 | print("accuracy:", accuracy) 299 | 300 | micro_accuracy = np.sum(correct)/(np.sum(correct)+np.sum(incorrect)) 301 | print("micro accuracy:", micro_accuracy) 302 | 303 | 304 | def main(): 305 | if len(sys.argv) != 6: 306 | print("dictionary_path embedding_path training_data_file test_data_file output_path") 307 | exit(-1) 308 | dictionary_path = sys.argv[1] 309 | embedding_path = sys.argv[2] 310 | training_data_file = sys.argv[3] 311 | test_data_file = sys.argv[4] 312 | output_path = sys.argv[5] 313 | network = NeuralNetwork(dictionary_path, embedding_path, training_data_file, test_data_file) 314 | network.train() 315 | network.save_weights(output_path) 316 | network.validate(verbose=True, threshold=.5) 317 | network.validate(verbose=False, threshold=1) 318 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.5) 319 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.4) 320 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.3) 321 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.2) 322 | 323 | 324 | if __name__ == '__main__': 325 | main() 326 | -------------------------------------------------------------------------------- /src/main/python/nn_words.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import codecs 3 | import math 4 | import sys 5 | from ast import literal_eval 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from sklearn.utils import shuffle 10 | 11 | import nn 12 | from repl import get_probabilities 13 | 14 | 15 | class NeuralNetwork: 16 | def __init__(self, dictionary_path: str, embedding_path: str, training_data_file: str, test_data_file: str, 17 | batch_size: int=1000, epochs: int=1000, use_hidden_layer: bool=False, use_conv_layer: bool=False): 18 | print(locals()) 19 | 20 | self.use_hidden_layer = use_hidden_layer 21 | self.use_conv_layer = use_conv_layer 22 | self.hidden_layer_size = 32 23 | self.num_conv_filters = 32 24 | self.batch_size 
= batch_size 25 | self.epochs = epochs 26 | 27 | with open(dictionary_path) as dictionary_file: 28 | self.dictionary = literal_eval(dictionary_file.read()) 29 | self.embedding = np.loadtxt(embedding_path) 30 | print("embedding shape", np.shape(self.embedding)) 31 | 32 | self._input_size = np.shape(self.embedding)[1] 33 | 34 | self._db = self.get_db(training_data_file) 35 | self._TRAINING_SAMPLES = len(self._db["groundtruths"]) 36 | self._num_outputs = len(self._db["groundtruths"][0]) 37 | self._current_batch_number = 0 38 | 39 | self._num_inputs = len(self._db["ngrams"][0]) - 1 40 | 41 | self._db_validate = self.get_db(test_data_file) 42 | 43 | self.setup_net() 44 | 45 | print("determined parameters: num_inputs=%d, input_size=%d" % (self._num_inputs, self._input_size)) 46 | 47 | def setup_net(self): 48 | with tf.name_scope('input'): 49 | self.words = [] 50 | for i in range(self._num_inputs): 51 | self.words.append(tf.placeholder(tf.float32, [None, self._input_size])) 52 | x = tf.concat(self.words, 1) 53 | 54 | with tf.name_scope('ground-truth'): 55 | self.y_ = tf.placeholder(tf.float32, shape=[None, self._num_outputs]) 56 | 57 | if self.use_conv_layer: 58 | with tf.name_scope('conv_layer'): 59 | pooling_positions = 3 60 | x_image = tf.reshape(x, [-1, self._input_size, self._num_inputs, 1]) 61 | self.W_conv1 = nn.weight_variable([self._input_size, self._num_inputs - (pooling_positions - 1), 1, self.num_conv_filters]) 62 | self.b_conv1 = nn.bias_variable([self.num_conv_filters]) 63 | h_conv1 = tf.nn.relu(nn.conv2d(x_image, self.W_conv1) + self.b_conv1) 64 | h_pool1 = nn.max_pool(h_conv1, self._input_size, self._num_inputs - (pooling_positions - 1)) 65 | h_pool1_flat = tf.reshape(h_pool1, [-1, self.num_conv_filters * pooling_positions]) 66 | 67 | if self.use_hidden_layer: 68 | with tf.name_scope('hidden_layer'): 69 | self.W_fc1 = nn.weight_variable([self._num_inputs * self._input_size, self._num_inputs * self.hidden_layer_size]) 70 | self.b_fc1 = nn.bias_variable([self._num_inputs * self.hidden_layer_size]) 71 | hidden_layer = tf.nn.relu(tf.matmul(x, self.W_fc1) + self.b_fc1) 72 | 73 | with tf.name_scope('readout_layer'): 74 | if self.use_hidden_layer: 75 | self.W_fc2 = nn.weight_variable([self._num_inputs * self.hidden_layer_size, self._num_outputs]) 76 | self.b_fc2 = nn.bias_variable([self._num_outputs]) 77 | self.y = tf.matmul(hidden_layer, self.W_fc2) + self.b_fc2 78 | elif self.use_conv_layer: 79 | self.W_fc2 = nn.weight_variable([self.num_conv_filters * pooling_positions, self._num_outputs]) 80 | self.b_fc2 = nn.bias_variable([self._num_outputs]) 81 | self.y = tf.matmul(h_pool1_flat, self.W_fc2) + self.b_fc2 82 | else: 83 | self.W_fc1 = nn.weight_variable([self._num_inputs * self._input_size, self._num_outputs]) 84 | self.b_fc1 = nn.bias_variable([self._num_outputs]) 85 | self.y = tf.matmul(x, self.W_fc1) + self.b_fc1 86 | 87 | with tf.name_scope('train'): 88 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.y, labels=self.y_)) 89 | tf.summary.scalar('loss_function', cross_entropy) 90 | self.train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 91 | 92 | with tf.name_scope('accuracy'): 93 | self.correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1)) 94 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) 95 | tf.summary.scalar('accuracy', self.accuracy) 96 | 97 | self.sess = tf.InteractiveSession() 98 | self.sess.run(tf.global_variables_initializer()) 99 | 100 | def get_word_representation(self, 
word): 101 | if word in self.dictionary: 102 | return self.embedding[self.dictionary[word]] 103 | else: 104 | return self.embedding[self.dictionary["UNK"]] 105 | 106 | def get_db(self, path): 107 | db = dict() 108 | raw_db = eval(codecs.open(path, "r", "utf-8").read()) 109 | # raw_db = {'ngrams':[['too','early','to','rule','out'],['park','next','to','the','town'],['has','right','to','destroy','houses'],['ll','have','to','move','them'],['percent','increase','to','its','budget'],['ll','continue','to','improve','in'],['This','applies','to','footwear','too'],['t','appear','to','be','too'],['taking','action','to','prevent','a'],['too','close','to','the','fire'],['It','rates','to','get','a'],['it','was','too','early','to'],['houses',',','too','.','We'],['to','footwear','too','-','can'],['to','be','too','much','money'],['been','waiting','too','long','for'],['leave','rates','too','low','for'],['low','for','too','long',','],['not','going','too','close','to']],'groundtruths':[[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1]]} 110 | db["ngrams"] = np.asarray( 111 | list(map(lambda ws: list(map(lambda w: self.get_word_representation(w), ws)), raw_db["ngrams"]))) 112 | db["groundtruths"] = raw_db["groundtruths"] 113 | db["ngrams"], db["groundtruths"], db["ngrams_raw"] = shuffle(db["ngrams"], db["groundtruths"], raw_db["ngrams"]) 114 | print("%s loaded, containing %d entries, class distribution: %s" 115 | % (path, len(db["groundtruths"]), str(np.sum(np.asarray(db["groundtruths"]), axis=0)))) 116 | return db 117 | 118 | def get_batch(self): 119 | if self._current_batch_number * self.batch_size > self._TRAINING_SAMPLES: 120 | self._current_batch_number = 0 121 | start_index = self._current_batch_number * self.batch_size 122 | end_index = (self._current_batch_number + 1) * self.batch_size 123 | batch = dict() 124 | ngrams = self._db["ngrams"][start_index:end_index] 125 | self.assign_ngram_to_batch(batch, ngrams) 126 | batch[self.y_] = self._db["groundtruths"][start_index:end_index] 127 | self._current_batch_number = self._current_batch_number + 1 128 | # print("d" + str(len(batch[self.word1]))) 129 | return batch 130 | 131 | def get_all_training_data(self): 132 | batch = dict() 133 | ngrams = self._db["ngrams"][:] 134 | self.assign_ngram_to_batch(batch, ngrams) 135 | batch[self.y_] = self._db["groundtruths"][:] 136 | # print("d" + str(len(batch[self.word1]))) 137 | return batch 138 | 139 | def assign_ngram_to_batch(self, batch, ngrams): 140 | for idx, word in enumerate(self.words): 141 | i = self.context_index_for_ngram_position(idx) 142 | batch[word] = ngrams[:, i] 143 | 144 | def train(self): 145 | steps = math.ceil(self._TRAINING_SAMPLES / self.batch_size) 146 | print("Steps: %d, %d steps in %d epochs" % (steps * self.epochs, steps, self.epochs)) 147 | for e in range(self.epochs): 148 | for i in range(steps): 149 | fd = self.get_batch() 150 | _ = self.sess.run([self.train_step], fd) # train with next batch 151 | if e % 1 == 0: 152 | self._print_accuracy(e) 153 | if e % 1000 == 0: 154 | print("--- VALIDATION PERFORMANCE -") 155 | self.validate() 156 | self.validate_error_detection() 157 | print("----------------------------") 158 | self._print_accuracy(self.epochs) 159 | 160 | def _print_accuracy(self, epoch): 161 | train_acc = self.sess.run([self.accuracy], self.get_all_training_data()) 162 | print("epoch %d, training accuracy %f" % (epoch, train_acc[0])) 163 | sys.stdout.flush() 164 | 165 | def save_weights(self, output_path): 166 | if 
self.use_conv_layer: 167 | nn.write_4dmat(output_path + "/W_conv1.txt", tf.transpose(self.W_conv1, (3, 0, 1, 2)).eval()) 168 | nn.write_4dmat(output_path + "/b_conv1.txt", self.b_conv1.eval()) 169 | else: 170 | np.savetxt(output_path + "/W_fc1.txt", self.W_fc1.eval()) 171 | np.savetxt(output_path + "/b_fc1.txt", self.b_fc1.eval()) 172 | if self.use_hidden_layer or self.use_conv_layer: 173 | np.savetxt(output_path + "/b_fc2.txt", self.b_fc2.eval()) 174 | np.savetxt(output_path + "/W_fc2.txt", self.W_fc2.eval()) 175 | 176 | def get_suggestion(self, ngram, threshold=.5): 177 | scores = self.get_score(ngram) 178 | if np.max(scores) > threshold and np.min(scores) < -threshold: 179 | return np.argmax(scores) 180 | else: 181 | return -1 182 | 183 | def get_score(self, ngram): 184 | fd = {self.y_: [list(np.zeros(self._num_outputs))]} 185 | for idx, word in enumerate(self.words): 186 | i = self.context_index_for_ngram_position(idx) 187 | fd[word] = [ngram[i]] 188 | scores = self.y.eval(fd)[0] 189 | return scores 190 | 191 | def context_index_for_ngram_position(self, idx: int) -> int: 192 | """for a position in a ngram, return idx if idx < floor(n/2), idx+1 otherwise""" 193 | if idx < int(self._num_inputs / 2): 194 | i = idx 195 | else: 196 | i = idx + 1 197 | return i 198 | 199 | def validate(self, verbose=False, threshold=.5): 200 | print("--- Validation of word prediction, threshold", threshold) 201 | 202 | correct = list(np.zeros(self._num_outputs)) 203 | incorrect = list(np.zeros(self._num_outputs)) 204 | unclassified = list(np.zeros(self._num_outputs)) 205 | tp = 0 206 | fp = 0 207 | tn = 0 208 | fn = 0 209 | 210 | for i in range(len(self._db_validate["groundtruths"])): 211 | suggestion = self.get_suggestion(self._db_validate["ngrams"][i], threshold=threshold) 212 | ground_truth = np.argmax(self._db_validate["groundtruths"][i]) 213 | if suggestion == -1: 214 | unclassified[ground_truth] = unclassified[ground_truth] + 1 215 | if verbose: 216 | print("no decision:", " ".join(self._db_validate["ngrams_raw"][i])) 217 | tn = tn + 1 218 | fn = fn + 1 219 | elif suggestion == ground_truth: 220 | correct[ground_truth] = correct[ground_truth] + 1 221 | if verbose: 222 | print("correct suggestion:", " ".join(self._db_validate["ngrams_raw"][i])) 223 | tp = tp + 1 224 | tn = tn + 1 225 | else: 226 | incorrect[ground_truth] = incorrect[ground_truth] + 1 227 | if verbose: 228 | print("possible wrong suggestion:", " ".join(self._db_validate["ngrams_raw"][i])) 229 | fp = fp + 1 230 | fn = fn + 1 231 | 232 | accuracy = list(map(lambda c, i: c/(c+i) if (c+i) > 0 else np.nan, correct, incorrect)) 233 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 234 | 235 | print("correct:", correct) 236 | print("incorrect:", incorrect) 237 | print("accuracy:", accuracy) 238 | print("unclassified:", unclassified) 239 | print("total accuracy:", total_accuracy) 240 | 241 | print("tp", tp) 242 | print("tn", tn) 243 | print("fp", fp) 244 | print("fn", fn) 245 | print("precision:", float(tp)/(tp+fp) if (tp+fp) > 0 else 1) 246 | print("recall:", float(tp)/(tp+fn) if (tp+fn) > 0 else 0) 247 | 248 | def validate_error_detection(self, suggestion_threshold: float=0.5, error_threshold: float=0.2, verbose=False): 249 | print("--- Error Detection Validation: suggestion_threshold %4.2f, error_threshold %4.2f" 250 | % (suggestion_threshold, error_threshold)) 251 | 252 | correct = list(np.zeros(self._num_outputs)) 253 | incorrect = list(np.zeros(self._num_outputs)) 254 | unclassified = 
list(np.zeros(self._num_outputs)) 255 | tp = 0 256 | fp = 0 257 | fn = 0 258 | 259 | for i in range(len(self._db_validate["groundtruths"])): 260 | scores = self.get_score(self._db_validate["ngrams"][i]) 261 | probabilities = get_probabilities(scores) 262 | best_match = self.get_suggestion(self._db_validate["ngrams"][i]) 263 | best_match_score = scores[best_match] 264 | ground_truth = np.argmax(self._db_validate["groundtruths"][i]) 265 | ground_truth_probability = probabilities[ground_truth] 266 | 267 | if best_match_score > suggestion_threshold and error_threshold > ground_truth_probability: 268 | # suggest alternative 269 | incorrect[ground_truth] = incorrect[ground_truth] + 1 270 | if verbose: 271 | print("false alarm:", " ".join(self._db_validate["ngrams_raw"][i])) 272 | fp = fp + 1 273 | fn = fn + 1 274 | elif ground_truth_probability > suggestion_threshold: 275 | # ground truth will be suggested 276 | correct[ground_truth] = correct[ground_truth] + 1 277 | if verbose: 278 | print("correct suggestion included:", " ".join(self._db_validate["ngrams_raw"][i])) 279 | tp = tp + 1 280 | else: 281 | # nothing happens 282 | unclassified[ground_truth] = unclassified[ground_truth] + 1 283 | if verbose: 284 | print("no decision:", " ".join(self._db_validate["ngrams_raw"][i])) 285 | fn = fn + 1 286 | 287 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 288 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 289 | 290 | print("correct:", correct) 291 | print("incorrect:", incorrect) 292 | print("accuracy:", accuracy) 293 | print("unclassified:", unclassified) 294 | print("total accuracy:", total_accuracy) 295 | 296 | print("tp", tp) 297 | print("fp", fp) 298 | print("fn", fn) 299 | print("precision:", float(tp)/(tp+fp) if (tp+fp) > 0 else 1) 300 | print("recall:", float(tp)/(tp+fn) if (tp+fn) > 0 else 0) 301 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 302 | print("accuracy:", accuracy) 303 | 304 | micro_accuracy = np.sum(correct)/(np.sum(correct)+np.sum(incorrect)) 305 | print("micro accuracy:", micro_accuracy) 306 | 307 | 308 | def main(): 309 | if len(sys.argv) != 6: 310 | print("dictionary_path embedding_path training_data_file test_data_file output_path") 311 | exit(-1) 312 | dictionary_path = sys.argv[1] 313 | embedding_path = sys.argv[2] 314 | training_data_file = sys.argv[3] 315 | test_data_file = sys.argv[4] 316 | output_path = sys.argv[5] 317 | network = NeuralNetwork(dictionary_path, embedding_path, training_data_file, test_data_file) 318 | network.train() 319 | network.save_weights(output_path) 320 | network.validate(verbose=True, threshold=.5) 321 | network.validate(verbose=False, threshold=1) 322 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.5) 323 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.4) 324 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.3) 325 | network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.2) 326 | 327 | 328 | if __name__ == '__main__': 329 | main() 330 | -------------------------------------------------------------------------------- /src/main/python/nn_words_correct_C.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import codecs 3 | import json 4 | import math 5 | import sys 6 | from ast import literal_eval 7 | 8 | import numpy as np 9 | import tensorflow as 
tf 10 | from sklearn.utils import shuffle 11 | 12 | import nn 13 | from repl import get_probabilities 14 | 15 | 16 | class NeuralNetwork: 17 | def __init__(self, dictionary_path: str, embedding_path: str, training_data_file: str, test_data_file: str, 18 | batch_size: int=1000, epochs: int=200, keep_prob: float=0.7): 19 | print(locals()) 20 | 21 | self.num_conv_filters = 32 22 | self.max_sequence_length = 50 23 | self.batch_size = batch_size 24 | self.epochs = epochs 25 | self.keep_prob = keep_prob 26 | 27 | with open(dictionary_path) as dictionary_file: 28 | self.dictionary = literal_eval(dictionary_file.read()) 29 | self.embedding = np.loadtxt(embedding_path) 30 | print("embedding shape", np.shape(self.embedding)) 31 | 32 | self._embedding_size = np.shape(self.embedding)[1] 33 | 34 | self._db = self.get_db(training_data_file) 35 | self._TRAINING_SAMPLES = len(self._db["groundTruths"]) 36 | self._num_outputs = 2 37 | self._current_batch_number = 0 38 | 39 | self._db_validate = self.get_db(test_data_file) 40 | 41 | self.setup_net() 42 | 43 | print("determined parameters: embedding_size=%d" % self._embedding_size) 44 | 45 | def setup_net(self): 46 | input_length = self.max_sequence_length 47 | 48 | with tf.name_scope('input'): 49 | self.x = tf.placeholder(tf.float32, [None, input_length * self._embedding_size]) 50 | 51 | with tf.name_scope('ground-truth'): 52 | self.y_ = tf.placeholder(tf.float32, shape=[None, self._num_outputs]) 53 | 54 | with tf.name_scope('conv_layer'): 55 | self.dropout = tf.placeholder(tf.float32) 56 | x_image = tf.reshape(tf.nn.dropout(self.x, self.dropout), [-1, input_length, self._embedding_size, 1]) 57 | filter_size = 5 58 | self.W_conv1 = nn.weight_variable([filter_size, self._embedding_size, 1, self.num_conv_filters]) 59 | self.b_conv1 = nn.bias_variable([self.num_conv_filters]) 60 | h_conv1 = tf.nn.relu(nn.conv2d(x_image, self.W_conv1, padding="VALID") + self.b_conv1) 61 | h_pool1 = nn.max_pool(h_conv1, input_length - filter_size + 1) 62 | h_pool1_flat = tf.reshape(h_pool1, [-1, self.num_conv_filters]) 63 | 64 | with tf.name_scope('readout_layer'): 65 | self.W_fc2 = nn.weight_variable([self.num_conv_filters, self._num_outputs]) 66 | self.b_fc2 = nn.bias_variable([self._num_outputs]) 67 | self.y = tf.matmul(h_pool1_flat, self.W_fc2) + self.b_fc2 68 | 69 | with tf.name_scope('train'): 70 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.y, labels=self.y_)) 71 | tf.summary.scalar('loss_function', cross_entropy) 72 | self.train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 73 | 74 | with tf.name_scope('accuracy'): 75 | self.correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1)) 76 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) 77 | tf.summary.scalar('accuracy', self.accuracy) 78 | 79 | self.sess = tf.InteractiveSession() 80 | self.sess.run(tf.global_variables_initializer()) 81 | 82 | def get_word_representation(self, word): 83 | if word in self.dictionary: 84 | return self.embedding[self.dictionary[word]] 85 | else: 86 | return self.embedding[self.dictionary["UNK"]] 87 | 88 | @staticmethod 89 | def take(sentence: [str], max_length: int): 90 | """keep first <=max_sequence_length words, fill rest with UNK""" 91 | return (sentence + ["UNK"] * max_length)[:max_length] 92 | 93 | def embed(self, words: [str]) -> np.ndarray: 94 | return np.array(list(map(lambda w: self.get_word_representation(w), words))) 95 | 96 | def get_db(self, path): 97 | db = dict() 98 | raw_db = 
json.load(open(path)) 99 | db["tokens"] = np.asarray(list(map(lambda ws: self.embed(self.take(ws, self.max_sequence_length)).flatten(), raw_db["tokens"]))) 100 | db["groundTruths"] = np.asarray(list(map(lambda gt: [1, 0] if gt == 1 else [0, 1], raw_db["groundTruths"]))) 101 | db["context_str"] = list(map(lambda ts: " ".join(self.take(ts, self.max_sequence_length)), raw_db["tokens"])) 102 | db["tokens"], db["groundTruths"], db["context_str"] = \ 103 | shuffle(db["tokens"], db["groundTruths"], db["context_str"]) 104 | print("%s loaded, containing %d entries, class distribution: %s" 105 | % (path, len(db["groundTruths"]), str(np.sum(np.asarray(db["groundTruths"]), axis=0)))) 106 | return db 107 | 108 | def get_batch(self): 109 | if self._current_batch_number * self.batch_size > self._TRAINING_SAMPLES: 110 | self._current_batch_number = 0 111 | start_index = self._current_batch_number * self.batch_size 112 | end_index = (self._current_batch_number + 1) * self.batch_size 113 | batch = dict() 114 | tokens = self._db["tokens"][start_index:end_index] 115 | batch[self.x] = tokens 116 | batch[self.y_] = self._db["groundTruths"][start_index:end_index] 117 | self._current_batch_number = self._current_batch_number + 1 118 | batch[self.dropout] = self.keep_prob 119 | # print("d" + str(len(batch[self.word1]))) 120 | return batch 121 | 122 | def get_all_training_data(self): 123 | batch = dict() 124 | tokens = self._db["tokens"][:] 125 | batch[self.x] = tokens 126 | batch[self.y_] = self._db["groundTruths"][:] 127 | batch[self.dropout] = 1 128 | # print("d" + str(len(batch[self.word1]))) 129 | return batch 130 | 131 | def train(self): 132 | steps = math.ceil(self._TRAINING_SAMPLES / self.batch_size) 133 | print("Steps: %d, %d steps in %d epochs" % (steps * self.epochs, steps, self.epochs)) 134 | for e in range(self.epochs): 135 | for i in range(steps): 136 | fd = self.get_batch() 137 | _ = self.sess.run([self.train_step], fd) # train with next batch 138 | if e % 10 == 0: 139 | self._print_accuracy(e) 140 | if e % 1000 == 0 and e > 0: 141 | self.validate() 142 | self.validate_error_detection() 143 | self._print_accuracy(self.epochs) 144 | 145 | def _print_accuracy(self, epoch): 146 | train_acc = self.sess.run([self.accuracy], self.get_all_training_data()) 147 | print("epoch %d, training accuracy %f" % (epoch, train_acc[0])) 148 | sys.stdout.flush() 149 | 150 | def save_weights(self, output_path): 151 | nn.write_4dmat(output_path + "/W_conv1.txt", tf.transpose(self.W_conv1, (3, 0, 1, 2)).eval()) 152 | nn.write_4dmat(output_path + "/b_conv1.txt", self.b_conv1.eval()) 153 | np.savetxt(output_path + "/b_fc2.txt", self.b_fc2.eval()) 154 | np.savetxt(output_path + "/W_fc2.txt", self.W_fc2.eval()) 155 | 156 | def get_suggestion(self, tokens, threshold=.5): 157 | scores = self.get_score(tokens) 158 | if np.max(scores) > threshold and np.min(scores) < -threshold: 159 | return np.argmax(scores) 160 | else: 161 | return -1 162 | 163 | def get_score(self, tokens): 164 | fd = {self.y_: [list(np.zeros(self._num_outputs))], 165 | self.x: [tokens], 166 | self.dropout: 1} 167 | scores = self.y.eval(fd)[0] 168 | return scores 169 | 170 | def validate(self, verbose=False, threshold=.5): 171 | print("--- Validation of correct/incorrect prediction, threshold", threshold) 172 | 173 | correct = list(np.zeros(self._num_outputs)) 174 | incorrect = list(np.zeros(self._num_outputs)) 175 | unclassified = list(np.zeros(self._num_outputs)) 176 | tp = 0 177 | fp = 0 178 | tn = 0 179 | fn = 0 180 | 181 | for i in 
range(len(self._db_validate["groundTruths"])): 182 | suggestion = self.get_suggestion(self._db_validate["tokens"][i], threshold=threshold) 183 | classification = "correct" if suggestion == 0 else "incorrect" 184 | ground_truth = np.argmax(self._db_validate["groundTruths"][i]) 185 | if suggestion == -1: 186 | unclassified[ground_truth] = unclassified[ground_truth] + 1 187 | if verbose: 188 | print("no decision:", self._db_validate["context_str"][i]) 189 | tn = tn + 1 190 | fn = fn + 1 191 | elif suggestion == ground_truth: 192 | correct[ground_truth] = correct[ground_truth] + 1 193 | if verbose: 194 | print("correctly classified as %s:" % classification, self._db_validate["context_str"][i]) 195 | tp = tp + 1 196 | tn = tn + 1 197 | else: 198 | incorrect[ground_truth] = incorrect[ground_truth] + 1 199 | if verbose: 200 | print("incorrectly classified as %s:" % classification, self._db_validate["context_str"][i]) 201 | fp = fp + 1 202 | fn = fn + 1 203 | 204 | accuracy = list(map(lambda c, i: c/(c+i) if (c+i) > 0 else np.nan, correct, incorrect)) 205 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 206 | 207 | print("correct:", correct) 208 | print("incorrect:", incorrect) 209 | print("accuracy:", accuracy) 210 | print("unclassified:", unclassified) 211 | print("total accuracy:", total_accuracy) 212 | 213 | print("tp", tp) 214 | print("tn", tn) 215 | print("fp", fp) 216 | print("fn", fn) 217 | print("precision:", float(tp)/(tp+fp) if (tp+fp) > 0 else 1) 218 | print("recall:", float(tp)/(tp+fn) if (tp+fn) > 0 else 0) 219 | 220 | 221 | def main(): 222 | if len(sys.argv) != 6: 223 | print("dictionary_path embedding_path training_data_file test_data_file output_path") 224 | exit(-1) 225 | dictionary_path = sys.argv[1] 226 | embedding_path = sys.argv[2] 227 | training_data_file = sys.argv[3] 228 | test_data_file = sys.argv[4] 229 | output_path = sys.argv[5] 230 | network = NeuralNetwork(dictionary_path, embedding_path, training_data_file, test_data_file) 231 | network.train() 232 | network.save_weights(output_path) 233 | network.validate(verbose=True, threshold=.25) 234 | network.validate(verbose=False, threshold=.5) 235 | network.validate(verbose=False, threshold=1) 236 | 237 | 238 | if __name__ == '__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /src/main/python/parallel.py: -------------------------------------------------------------------------------- 1 | # copied from https://stackoverflow.com/questions/3288595/multiprocessing-how-to-use-pool-map-on-a-function-defined-in-a-class 2 | import multiprocessing 3 | 4 | 5 | def fun(f, q_in, q_out): 6 | while True: 7 | i, x = q_in.get() 8 | if i is None: 9 | break 10 | q_out.put((i, f(x))) 11 | 12 | 13 | def parmap(f, X, nprocs=multiprocessing.cpu_count()): 14 | q_in = multiprocessing.Queue(1) 15 | q_out = multiprocessing.Queue() 16 | 17 | proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) 18 | for _ in range(nprocs)] 19 | for p in proc: 20 | p.daemon = True 21 | p.start() 22 | 23 | sent = [q_in.put((i, x)) for i, x in enumerate(X)] 24 | [q_in.put((None, None)) for _ in range(nprocs)] 25 | res = [q_out.get() for _ in range(len(sent))] 26 | 27 | [p.join() for p in proc] 28 | 29 | return [x for i, x in sorted(res)] 30 | -------------------------------------------------------------------------------- /src/main/python/random_forest.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/python3 2 | import codecs 3 | import sys 4 | from ast import literal_eval 5 | 6 | import numpy as np 7 | from sklearn.ensemble import RandomForestClassifier 8 | from sklearn.naive_bayes import GaussianNB 9 | from sklearn.svm import SVC 10 | from sklearn.tree import DecisionTreeClassifier 11 | from sklearn.utils import shuffle 12 | 13 | from repl import get_probabilities 14 | 15 | 16 | class RandomForest: 17 | def __init__(self, dictionary_path: str, embedding_path: str, training_data_file: str, test_data_file: str, 18 | n_estimators: int=120): 19 | print(locals()) 20 | 21 | with open(dictionary_path) as dictionary_file: 22 | self.dictionary = literal_eval(dictionary_file.read()) 23 | self.embedding = np.loadtxt(embedding_path) 24 | print("embedding shape", np.shape(self.embedding)) 25 | 26 | self._input_size = np.shape(self.embedding)[1] 27 | 28 | self._db = self.get_db(training_data_file) 29 | self._TRAINING_SAMPLES = len(self._db["groundtruths"]) 30 | self._num_outputs = np.max(self._db["groundtruths"]) + 1 31 | 32 | self._num_inputs = len(self._db["ngrams"][0]) - 1 33 | 34 | self._db_validate = self.get_db(test_data_file) 35 | 36 | self._classifier = RandomForestClassifier(n_estimators=n_estimators, n_jobs=-1) 37 | 38 | print("determined parameters: num_inputs=%d, input_size=%d" % (self._num_inputs, self._input_size)) 39 | 40 | def get_word_representation(self, word): 41 | if word in self.dictionary: 42 | return self.embedding[self.dictionary[word]] 43 | else: 44 | return self.embedding[self.dictionary["UNK"]] 45 | 46 | def get_db(self, path): 47 | db = dict() 48 | raw_db = eval(codecs.open(path, "r", "utf-8").read()) 49 | # raw_db = {'ngrams':[['too','early','to','rule','out'],['park','next','to','the','town'],['has','right','to','destroy','houses'],['ll','have','to','move','them'],['percent','increase','to','its','budget'],['ll','continue','to','improve','in'],['This','applies','to','footwear','too'],['t','appear','to','be','too'],['taking','action','to','prevent','a'],['too','close','to','the','fire'],['It','rates','to','get','a'],['it','was','too','early','to'],['houses',',','too','.','We'],['to','footwear','too','-','can'],['to','be','too','much','money'],['been','waiting','too','long','for'],['leave','rates','too','low','for'],['low','for','too','long',','],['not','going','too','close','to']],'groundtruths':[[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1]]} 50 | db["ngrams"] = np.asarray( 51 | list(map(lambda ws: list(map(lambda w: self.get_word_representation(w), ws)), raw_db["ngrams"]))) 52 | db["groundtruths"] = list(map(np.argmax, raw_db["groundtruths"])) 53 | db["ngrams"], db["groundtruths"], db["ngrams_raw"] = shuffle(db["ngrams"], db["groundtruths"], raw_db["ngrams"]) 54 | print("%s loaded, containing %d entries, class distribution: %s" 55 | % (path, len(db["groundtruths"]), str(np.sum(np.asarray(db["groundtruths"]), axis=0)))) 56 | return db 57 | 58 | def get_all_data(self, db): 59 | batch = dict() 60 | batch["ngrams"] = list(map(lambda ws: np.concatenate(self.context_from_ngram(ws)), db["ngrams"])) 61 | batch["ngrams_raw"] = db["ngrams_raw"][:] 62 | batch["groundtruths"] = db["groundtruths"][:] 63 | return batch 64 | 65 | def train(self): 66 | training_data = self.get_all_data(self._db) 67 | self._classifier.fit(X=training_data["ngrams"], y=training_data["groundtruths"]) 68 | 69 | def save_weights(self, output_path): 70 | print("TODO stub save_weights") 71 | # np.savetxt(output_path + 
"/W_fc1.txt", self.W_fc1.eval()) 72 | # np.savetxt(output_path + "/b_fc1.txt", self.b_fc1.eval()) 73 | # if self.use_hidden_layer: 74 | # np.savetxt(output_path + "/b_fc2.txt", self.b_fc2.eval()) 75 | # np.savetxt(output_path + "/W_fc2.txt", self.W_fc2.eval()) 76 | 77 | def get_suggestion(self, ngram): 78 | scores = self.get_score(ngram) 79 | if np.max(scores) > .5 + 1/(1+np.exp(-0.5))-0.5 and np.min(scores) < .5 - (1/(1+np.exp(-0.5))-0.5): 80 | return np.argmax(scores) 81 | else: 82 | return -1 83 | 84 | def get_score(self, ngram): 85 | scores = self._classifier.predict_proba([ngram]) 86 | return scores[0] 87 | 88 | def context_from_ngram(self, ngram: np.ndarray) -> np.ndarray: 89 | middle = int(self._num_inputs / 2) 90 | return np.concatenate([ngram[:middle], ngram[middle+1:]]) 91 | 92 | def validate(self, verbose=False): 93 | correct = list(np.zeros(self._num_outputs)) 94 | incorrect = list(np.zeros(self._num_outputs)) 95 | unclassified = list(np.zeros(self._num_outputs)) 96 | tp = 0 97 | fp = 0 98 | tn = 0 99 | fn = 0 100 | 101 | validation_data = self.get_all_data(self._db_validate) 102 | for ngram, raw_ngram, ground_truth in zip(validation_data["ngrams"], validation_data["ngrams_raw"], validation_data["groundtruths"]): 103 | suggestion = self.get_suggestion(ngram) 104 | if suggestion == -1: 105 | unclassified[ground_truth] = unclassified[ground_truth] + 1 106 | if verbose: 107 | print("no decision:", " ".join(raw_ngram)) 108 | tn = tn + 1 109 | fn = fn + 1 110 | elif suggestion == ground_truth: 111 | correct[ground_truth] = correct[ground_truth] + 1 112 | if verbose: 113 | print("correct suggestion:", " ".join(raw_ngram)) 114 | tp = tp + 1 115 | tn = tn + 1 116 | else: 117 | incorrect[ground_truth] = incorrect[ground_truth] + 1 118 | if verbose: 119 | print("possible wrong suggestion:", " ".join(raw_ngram)) 120 | fp = fp + 1 121 | fn = fn + 1 122 | 123 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 124 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 125 | 126 | print("correct:", correct) 127 | print("incorrect:", incorrect) 128 | print("accuracy:", accuracy) 129 | print("unclassified:", unclassified) 130 | print("total accuracy:", total_accuracy) 131 | 132 | print("tp", tp) 133 | print("tn", tn) 134 | print("fp", fp) 135 | print("fn", fn) 136 | print("precision:", float(tp)/(tp+fp)) 137 | print("recall:", float(tp)/(tp+fn)) 138 | 139 | def validate_error_detection(self, suggestion_threshold: float=0.5, error_threshold: float=0.2, verbose=False): 140 | correct = list(np.zeros(self._num_outputs)) 141 | incorrect = list(np.zeros(self._num_outputs)) 142 | unclassified = list(np.zeros(self._num_outputs)) 143 | tp = 0 144 | fp = 0 145 | fn = 0 146 | 147 | validation_data = self.get_all_data(self._db_validate) 148 | for ngram, raw_ngram, ground_truth in zip(validation_data["ngrams"], validation_data["ngrams_raw"], validation_data["groundtruths"]): 149 | scores = self.get_score(ngram) 150 | probabilities = get_probabilities(scores) 151 | best_match = self.get_suggestion(ngram) 152 | best_match_score = scores[best_match] 153 | ground_truth_probability = probabilities[ground_truth] 154 | 155 | if best_match_score > suggestion_threshold and error_threshold > ground_truth_probability: 156 | # suggest alternative 157 | incorrect[ground_truth] = incorrect[ground_truth] + 1 158 | if verbose: 159 | print("false alarm:", " ".join(raw_ngram)) 160 | fp = fp + 1 161 | fn = fn + 1 162 | elif ground_truth_probability > suggestion_threshold: 163 | 
# ground truth will be suggested 164 | correct[ground_truth] = correct[ground_truth] + 1 165 | if verbose: 166 | print("correct suggestion included:", " ".join(raw_ngram)) 167 | tp = tp + 1 168 | else: 169 | # nothing happens 170 | unclassified[ground_truth] = unclassified[ground_truth] + 1 171 | if verbose: 172 | print("no decision:", " ".join(raw_ngram)) 173 | fn = fn + 1 174 | 175 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 176 | total_accuracy = list(map(lambda c, i, u: c/(c+i+u), correct, incorrect, unclassified)) 177 | 178 | print("correct:", correct) 179 | print("incorrect:", incorrect) 180 | print("accuracy:", accuracy) 181 | print("unclassified:", unclassified) 182 | print("total accuracy:", total_accuracy) 183 | 184 | print("tp", tp) 185 | print("fp", fp) 186 | print("fn", fn) 187 | print("precision:", float(tp)/(tp+fp)) 188 | print("recall:", float(tp)/(tp+fn)) 189 | accuracy = list(map(lambda c, i: c/(c+i), correct, incorrect)) 190 | print("accuracy:", accuracy) 191 | 192 | micro_accuracy = np.sum(correct)/(np.sum(correct)+np.sum(incorrect)) 193 | print("micro accuracy:", micro_accuracy) 194 | 195 | 196 | def main(): 197 | if len(sys.argv) != 6: 198 | print("dictionary_path embedding_path training_data_file test_data_file output_path") 199 | exit(-1) 200 | dictionary_path = sys.argv[1] 201 | embedding_path = sys.argv[2] 202 | training_data_file = sys.argv[3] 203 | test_data_file = sys.argv[4] 204 | output_path = sys.argv[5] 205 | network = RandomForest(dictionary_path, embedding_path, training_data_file, test_data_file) 206 | network.train() 207 | network.save_weights(output_path) 208 | network.validate(verbose=True) 209 | # print(.5) 210 | # network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.5) 211 | # print(.4) 212 | # network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.4) 213 | # print(.3) 214 | # network.validate_error_detection(verbose=False, suggestion_threshold=.5, error_threshold=.3) 215 | # print(.2) 216 | # network.validate_error_detection(verbose=True, suggestion_threshold=.5, error_threshold=.2) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /src/main/python/repl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from ast import literal_eval 3 | 4 | import numpy as np 5 | 6 | from LayeredScorer import LayeredScorer 7 | 8 | 9 | def get_word_representation(dictionary, embedding, word): # todo static 10 | if word in dictionary: 11 | return embedding[dictionary[word]] 12 | else: 13 | print(" " + word + " is unknown") 14 | return embedding[dictionary["UNK"]] 15 | 16 | 17 | def get_rating(scores, subjects): 18 | probabilities = get_probabilities(scores) 19 | scored_suggestions = list(zip(probabilities, scores, subjects)) 20 | scored_suggestions.sort(key=lambda x: x[0], reverse=True) 21 | return scored_suggestions 22 | 23 | 24 | def get_probabilities(scores): 25 | return 1 / (1 + np.exp(-np.array(scores))) 26 | 27 | 28 | def has_error(dictionary, embedding, scorer: LayeredScorer, ngram, subjects, suggestion_threshold=.5, error_threshold=.2) -> bool: 29 | """ 30 | Parameters 31 | ---------- 32 | suggestion_threshold: 33 | if the probability of another token is higher than this, it is considered as possible suggestion 34 | error_threshold: 35 | if the probability for the used token is less than this, it is considered wrong 36 | 
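    Illustrative example (not part of the original source): it assumes a dictionary/embedding
    pair and a LayeredScorer trained for the listed subjects; the 5-gram and the small
    subjects list below are hypothetical and only sketch how the thresholds interact.

        # the middle token "das" is looked up in `subjects`; the other tokens form the context
        has_error(dictionary, embedding, scorer,
                  ["gesagt", ",", "das", "ich", "komme"], ["das", "dass"],
                  suggestion_threshold=.5, error_threshold=.2)
        # returns True only if some candidate's probability exceeds suggestion_threshold
        # while the probability of the token actually used falls below error_threshold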
""" 37 | middle = int(len(ngram) / 2) 38 | words = np.concatenate(list(map(lambda token: get_word_representation(dictionary, embedding, token), np.delete(ngram, middle)))) 39 | scores = scorer.scores(words) 40 | probabilities = get_probabilities(scores) 41 | best_match_probability = probabilities[np.argmax(probabilities)] 42 | subject_index = subjects.index(ngram[middle]) 43 | subject_probability = probabilities[subject_index] 44 | 45 | print("checked", ngram) 46 | 47 | if best_match_probability > suggestion_threshold and subject_probability < error_threshold: 48 | print("ERROR detected, suggestions:", get_rating(scores, subjects)) 49 | return True 50 | elif subject_probability > suggestion_threshold: 51 | print("ok", get_rating(scores, subjects)) 52 | return False 53 | else: 54 | print("no decision", get_rating(scores, subjects)) 55 | return False 56 | 57 | 58 | def main(): 59 | if len(sys.argv) != 4: 60 | raise ValueError("Expected dict, finalembedding, weights_path") 61 | 62 | dictionary_path = sys.argv[1] 63 | embedding_path = sys.argv[2] 64 | weights_path = sys.argv[3] 65 | 66 | with open(dictionary_path) as dictionary_file: 67 | dictionary = literal_eval(dictionary_file.read()) 68 | embedding = np.loadtxt(embedding_path) 69 | 70 | subjects = ["als", "also", "da", "das", "dass", "de", "den", "denn", "die", "durch", "zur", "ihm", "im", "um", "nach", "noch", "war", "was"] 71 | # subjects = ["and", "end", "as", "at", "is", "do", "for", "four", "form", "from", "he", "if", "is", "its", "it", "no", "now", "on", "one", "same", "some", "than", "that", "then", "their", "there", "them", "the", "they", "to", "was", "way", "were", "where"] 72 | print(subjects) 73 | 74 | while True: 75 | ngram = input("ngram ").split(" ") 76 | try: 77 | scorer = LayeredScorer(weights_path) 78 | has_error(dictionary, embedding, scorer, ngram, subjects, error_threshold=.65, suggestion_threshold=.65) 79 | except ValueError as e: 80 | print(e) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /src/main/resources/example-corpus.txt: -------------------------------------------------------------------------------- 1 | I would like to go to school. 2 | We have a car, too. But I don't like to drive. 3 | I'd like to have a cat. 4 | She likes to go swmming, too. 5 | He is too heavy. 6 | Is he smart? 7 | There is a syntax error in your code. You have to fix it. 8 | I've forgotten to do my homework. 9 | Where is my key? 10 | You have to be careful. 11 | Many thanks to http://fscs.hhu.de/ for hosting the demo. 12 | The tasks aren't too difficult. 13 | You have to install Python in order to use the project. 14 | Haskell and Clojure are great languages, too. 
15 | -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/BinarySentenceDatabaseCreatorTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import org.junit.Test; 4 | 5 | import java.util.Arrays; 6 | import java.util.List; 7 | 8 | import static de.hhu.mabre.languagetool.BinarySentenceDatabaseCreator.introduceKommaOhneDass_OhneDas; 9 | import static org.junit.Assert.assertEquals; 10 | 11 | public class BinarySentenceDatabaseCreatorTest { 12 | 13 | @Test 14 | public void introduceKommaOhneDass_OhneDasTest() throws Exception { 15 | List<String> correct = Arrays.asList("Sie", "können", "sich", "ändern", ",", "ohne", "dass", "man", "es", "merkt", "."); 16 | List<String> expected = Arrays.asList("Sie", "können", "sich", "ändern", "ohne", "das", "man", "es", "merkt", "."); 17 | List<String> actual = introduceKommaOhneDass_OhneDas(correct); 18 | assertEquals(expected, actual); 19 | } 20 | 21 | } -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/FileTokenizerTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import junit.framework.TestCase; 4 | 5 | import java.util.Arrays; 6 | import java.util.List; 7 | 8 | import static de.hhu.mabre.languagetool.FileTokenizer.tokenize; 9 | 10 | public class FileTokenizerTest extends TestCase { 11 | 12 | public void testTokenize() { 13 | List<String> result = tokenize("en", "You’re not here."); 14 | assertEquals(Arrays.asList("You", "’", "re", "not", "here", "."), result); 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/NGramDatabaseCreatorTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import junit.framework.TestCase; 4 | 5 | import java.util.Arrays; 6 | import java.util.Collections; 7 | import java.util.HashMap; 8 | import java.util.List; 9 | 10 | import static de.hhu.mabre.languagetool.NGramDatabaseCreator.*; 11 | import static de.hhu.mabre.languagetool.SamplingMode.NONE; 12 | import static de.hhu.mabre.languagetool.SamplingMode.UNDERSAMPLE; 13 | import static de.hhu.mabre.languagetool.SubsetType.TRAINING; 14 | import static de.hhu.mabre.languagetool.SubsetType.VALIDATION; 15 | 16 | public class NGramDatabaseCreatorTest extends TestCase { 17 | 18 | public void testGetRelevantNGrams() { 19 | List<NGram> nGrams = getRelevantNGrams(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "h", "c", "c", "i"), Collections.singletonList("c")); 20 | List<NGram> expectedNGrams = Arrays.asList( 21 | new NGram("c", "b", "c", "d", "e"), 22 | new NGram("g", "h", "c", "c", "i")); 23 | assertEquals(expectedNGrams, nGrams); 24 | } 25 | 26 | public void testGetRelevantNGramsWithMultiTokens() { 27 | List<NGram> nGrams = getRelevantNGrams(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "h", "c", "c", "i"), Arrays.asList("d", "e")); 28 | List<NGram> expectedNGrams = Collections.singletonList(new NGram("b", "c", "d e", "f", "g")); 29 | assertEquals(expectedNGrams, nGrams); 30 | } 31 | 32 | public void testGetRelevantNGramsWithMultiTokensAtBorder() { 33 | List<NGram> nGrams = getRelevantNGrams(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "c", "d", "c", "i"), Arrays.asList("c", "d")); 34 | List<NGram> expectedNGrams = Arrays.asList( 35 | new NGram("c", "b", "c d", 
"e", "f"), 36 | new NGram("f", "g", "c d", "c", "i")); 37 | assertEquals(expectedNGrams, nGrams); 38 | } 39 | 40 | public void testCreateDatabase() { 41 | PythonDict db = createDatabase(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "h", "c", "c", "i"), "c", "f", UNDERSAMPLE); 42 | String expectedDb = "{'ngrams':[['c','b','c','d','e'],['d','e','f','g','h']],\n'groundtruths':[[1,0],[0,1]]}"; 43 | assertEquals(expectedDb, db.toString()); 44 | } 45 | 46 | public void testCreateDatabaseModerateOversample() { 47 | PythonDict db = createDatabase(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "h", "c", "c", "i", "j"), "c", "f", SamplingMode.MODERATE_OVERSAMPLE); 48 | String expectedDb = "{'ngrams':[['c','b','c','d','e'],['g','h','c','c','i'],['d','e','f','g','h'],['d','e','f','g','h']],\n'groundtruths':[[1,0],[1,0],[0,1],[0,1]]}"; 49 | assertEquals(expectedDb, db.toString()); 50 | } 51 | 52 | public void testCreateDatabaseOversample() { 53 | PythonDict db = createDatabase(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "h", "c", "c", "i", "j"), "c", "f", SamplingMode.OVERSAMPLE); 54 | String expectedDb = "{'ngrams':[['c','b','c','d','e'],['g','h','c','c','i'],['h','c','c','i','j'],['d','e','f','g','h'],['d','e','f','g','h'],['d','e','f','g','h']],\n'groundtruths':[[1,0],[1,0],[1,0],[0,1],[0,1],[0,1]]}"; 55 | assertEquals(expectedDb, db.toString()); 56 | } 57 | 58 | public void testCreateDatabaseNoSampling() { 59 | PythonDict db = createDatabase(Arrays.asList("c", "b", "c", "d", "e", "f", "g", "h", "c", "c", "i", "j"), "c", "f", SamplingMode.NONE); 60 | String expectedDb = "{'ngrams':[['c','b','c','d','e'],['g','h','c','c','i'],['h','c','c','i','j'],['d','e','f','g','h']],\n'groundtruths':[[1,0],[1,0],[1,0],[0,1]]}"; 61 | assertEquals(expectedDb, db.toString()); 62 | } 63 | 64 | public void testDatabaseFromSentences() { 65 | PythonDict db = databaseFromSentences("en", "I like that, too. I would like to go to the museum, too.", "to", "too", UNDERSAMPLE); 66 | String expectedDb = "{'ngrams':[['would','like','to','go','to'],['to','go','to','the','museum'],['that',',','too','.','I'],['museum',',','too','.','.']],\n'groundtruths':[[1,0],[1,0],[0,1],[0,1]]}"; 67 | assertEquals(expectedDb, db.toString()); 68 | } 69 | 70 | public void testDatabaseFromSentencesSingleQuoteEscaping() { 71 | PythonDict db = databaseFromSentences("en", "Whare is 'The Station'? I would like to go to the museum.", "Station", "to", UNDERSAMPLE); 72 | String expectedDb = "{'ngrams':[['\\'','The','Station','\\'','?'],['would','like','to','go','to']],\n'groundtruths':[[1,0],[0,1]]}"; 73 | assertEquals(expectedDb, db.toString()); 74 | } 75 | 76 | public void testMultiTokenSubjects() { 77 | PythonDict db = databaseFromSentences("de", "Ich habe das nicht gesagt. Ich habe gesagt, dass ich zum Karaokeabend kommen werde. Ich mag das nicht.", "das", ", dass", NONE); 78 | String expectedDb = "{'ngrams':[['Ich','habe','das','nicht','gesagt'],['Ich','mag','das','nicht','.'],['habe','gesagt',', dass','ich','zum']],\n'groundtruths':[[1,0],[1,0],[0,1]]}"; 79 | assertEquals(expectedDb, db.toString()); 80 | } 81 | 82 | public void testDatabaseFromSentencesAndThreeTokens() { 83 | PythonDict db = databaseFromSentences("en", "I like that, too. 
I would like to go to two museums, too.", Arrays.asList("to", "too", "two"), UNDERSAMPLE); 84 | String expectedDb = "{'ngrams':[['would','like','to','go','to'],['that',',','too','.','I'],['go','to','two','museums',',']],\n'groundtruths':[[1,0,0],[0,1,0],[0,0,1]]}"; 85 | assertEquals(expectedDb, db.toString()); 86 | } 87 | 88 | public void testRandomlySplit() { 89 | List<String> lines = Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"); 90 | HashMap<SubsetType, String> sets = randomlySplit(lines, 32); 91 | assertEquals(3, sets.get(VALIDATION).split("\n").length); 92 | assertEquals(7, sets.get(TRAINING).split("\n").length); 93 | } 94 | 95 | public void testGetRelevantCharNGrams() { 96 | List<NGram> nGrams = getRelevantCharNGrams(Arrays.asList("Because", "we", "want", "to", "test", "the", "implementation", "to", "test"), Collections.singletonList("to")); 97 | List<NGram> expectedNGrams = Collections.singletonList(new NGram("a", "u", "s", "e", " ", "w", "e", " ", "w", "a", "n", "t", "to", "t", "e", "s", "t", " ", "t", "h", "e", " ", "i", "m", "p")); 98 | assertEquals(expectedNGrams, nGrams); 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/NGramTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import junit.framework.TestCase; 4 | 5 | public class NGramTest extends TestCase { 6 | 7 | public void testEquals() { 8 | NGram nGram1 = new NGram("g", "h", "c", "i", "j"); 9 | NGram nGram2 = new NGram("g", "h", "c", "i", "j"); 10 | assertEquals(nGram1, nGram2); 11 | } 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/SentenceDatabaseCreatorTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import org.junit.Test; 4 | 5 | import java.util.Arrays; 6 | import java.util.Collections; 7 | import java.util.List; 8 | 9 | import static de.hhu.mabre.languagetool.SentenceDatabaseCreator.createDatabase; 10 | import static de.hhu.mabre.languagetool.SentenceDatabaseCreator.getRelevantSentenceBeginnings; 11 | import static de.hhu.mabre.languagetool.SentenceDatabaseCreator.getRelevantSentenceEndings; 12 | import static org.junit.Assert.assertEquals; 13 | 14 | public class SentenceDatabaseCreatorTest { 15 | 16 | @Test 17 | public void testGetRelevantSentenceBeginningsWithSingleTokenSubject() { 18 | List<List<String>> sentences = Arrays.asList( 19 | Arrays.asList("a", "b", "c", "foo"), 20 | Arrays.asList("a", "b", "c", "bar", "d"), 21 | Arrays.asList("a", "b", "foo", "c")); 22 | List<List<String>> result = getRelevantSentenceBeginnings(sentences, Collections.singletonList("foo")); 23 | List<List<String>> expected = Arrays.asList( 24 | Arrays.asList("a", "b", "c"), 25 | Arrays.asList("a", "b")); 26 | assertEquals(expected, result); 27 | } 28 | 29 | @Test 30 | public void testGetRelevantSentenceBeginningsWithMultiTokenSubject() { 31 | List<List<String>> sentences = Arrays.asList( 32 | Arrays.asList("a", "b", "c", "foo", "bar"), 33 | Arrays.asList("a", "b", "b", "foo", "bar", "d"), 34 | Arrays.asList("a", "b", "c", "bar", "d"), 35 | Arrays.asList("a", "b", "foo", "c")); 36 | List<List<String>> result = getRelevantSentenceBeginnings(sentences, Arrays.asList("foo", "bar")); 37 | List<List<String>> expected = Arrays.asList( 38 | Arrays.asList("a", "b", "c"), 39 | Arrays.asList("a", "b", "b")); 40 | assertEquals(expected, result); 41 | } 42 | 43 | @Test 44 | public void 
testGetRelevantSentenceEndingsWithSingleTokenSubject() { 45 | List<List<String>> sentences = Arrays.asList( 46 | Arrays.asList("a", "b", "c", "foo"), 47 | Arrays.asList("a", "b", "c", "bar", "d"), 48 | Arrays.asList("a", "b", "foo", "c")); 49 | List<List<String>> result = getRelevantSentenceEndings(sentences, Collections.singletonList("foo")); 50 | List<List<String>> expected = Arrays.asList( 51 | Collections.emptyList(), 52 | Collections.singletonList("c")); 53 | assertEquals(expected, result); 54 | } 55 | 56 | @Test 57 | public void createDatabaseTest() { 58 | List<List<String>> sentences = Arrays.asList( 59 | Arrays.asList("a", "b", "c", "foo", "bar"), 60 | Arrays.asList("a", "b", "b", "foo", "bar", "d"), 61 | Arrays.asList("a", "b", "c", "bar", "d"), 62 | Arrays.asList("a", "b", "foo", "c")); 63 | SentencesDict result = createDatabase(sentences, Collections.singletonList(Collections.singletonList("foo")), SamplingMode.NONE); 64 | assertEquals("{\"tokensBefore\":[[\"a\",\"b\",\"c\"],[\"a\",\"b\",\"b\"],[\"a\",\"b\"]],\"tokensAfter\":[[\"bar\"],[\"bar\",\"d\"],[\"c\"]],\"groundTruths\":[[1],[1],[1]],\"nCategories\":1}", result.toJson()); 65 | } 66 | 67 | } -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/SentencesDictTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import org.junit.Test; 4 | 5 | import java.util.Arrays; 6 | import java.util.Collections; 7 | 8 | import static org.junit.Assert.assertEquals; 9 | 10 | public class SentencesDictTest { 11 | 12 | @Test 13 | public void toJsonTest() { 14 | SentencesDict dict = new SentencesDict(2); 15 | dict.add(Collections.singletonList("Sie"), Arrays.asList("schlau", "."), 0); 16 | dict.add(Arrays.asList("Die", "beiden"), Arrays.asList("schlau", "."), 1); 17 | assertEquals("{\"tokensBefore\":[[\"Sie\"],[\"Die\",\"beiden\"]],\"tokensAfter\":[[\"schlau\",\".\"],[\"schlau\",\".\"]],\"groundTruths\":[[1,0],[0,1]],\"nCategories\":2}", 18 | dict.toJson()); 19 | } 20 | 21 | } -------------------------------------------------------------------------------- /src/test/java/de/hhu/mabre/languagetool/SubjectGrepperTest.java: -------------------------------------------------------------------------------- 1 | package de.hhu.mabre.languagetool; 2 | 3 | import junit.framework.TestCase; 4 | 5 | import java.io.StringReader; 6 | import java.util.List; 7 | 8 | import static de.hhu.mabre.languagetool.SubjectGrepper.grep; 9 | 10 | public class SubjectGrepperTest extends TestCase { 11 | public void testGrepSingleSubject() throws Exception { 12 | List<String> results = grep(new StringReader("foo bar\nlorem ipsum\nblabfoobdfj\nasdf"), "foo"); 13 | assertEquals(2, results.size()); 14 | assertEquals("foo bar", results.get(0)); 15 | } 16 | 17 | public void testGrepTwoSubjects() throws Exception { 18 | List<String> results = grep(new StringReader("foo bar\nlorem ipsum\nblabfoobdfj\nasdf"), "foo", "l"); 19 | assertEquals(3, results.size()); 20 | assertEquals("foo bar", results.get(0)); 21 | } 22 | } -------------------------------------------------------------------------------- /src/test/python/embedding/.cache/v/cache/lastfailed: -------------------------------------------------------------------------------- 1 | { 2 | "test_word_to_char_embedding.py": true 3 | } -------------------------------------------------------------------------------- /src/test/python/embedding/test_common.py: -------------------------------------------------------------------------------- 1 | from 
embedding.common import build_dataset 2 | 3 | 4 | def test_build_dataset(): 5 | data, count, dictionary, reverse_dictionary = build_dataset(["a", "a", "b", "a", "b", "c", "d"], 3) 6 | assert count == [['UNK', 2], ('a', 3), ('b', 2)] 7 | assert data == [1, 1, 2, 1, 2, 0, 0] 8 | assert dictionary == {'b': 2, 'a': 1, 'UNK': 0} 9 | assert reverse_dictionary == {0: 'UNK', 2: 'b', 1: 'a'} 10 | 11 | 12 | def test_build_dataset_with_tagger(): 13 | tagger = lambda w: "TAG" 14 | data, count, dictionary, reverse_dictionary = build_dataset(["a", "a", "b", "a", "a", "b", "b", "c", "d"], 3, tagger) 15 | assert count == [['UNK', 0], ('a', 4), ('b', 3), ('TAG', 2)] 16 | assert data == [1, 1, 2, 1, 1, 2, 2, 3, 3] 17 | assert dictionary == {'b': 2, 'a': 1, 'UNK': 0, 'TAG': 3} 18 | assert reverse_dictionary == {0: 'UNK', 2: 'b', 1: 'a', 3: 'TAG'} 19 | -------------------------------------------------------------------------------- /src/test/python/embedding/test_word_to_char_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_almost_equal 3 | 4 | from embedding.word_to_char_embedding import to_char_embedding 5 | 6 | 7 | def test_to_char_embedding(): 8 | dictionary = {"ab": 0, "bc": 1} 9 | embedding = np.asarray([[1, 2, 3], [3, 4, 5]]) 10 | char_dictionary, char_embedding = to_char_embedding(dictionary, embedding) 11 | assert_array_almost_equal(char_embedding[char_dictionary["UNK"]], [0, 0, 0]) 12 | assert_array_almost_equal(char_embedding[char_dictionary["a"]], [1, 2, 3]) 13 | assert_array_almost_equal(char_embedding[char_dictionary["b"]], [2, 3, 4]) 14 | assert_array_almost_equal(char_embedding[char_dictionary["c"]], [3, 4, 5]) 15 | -------------------------------------------------------------------------------- /src/test/python/test_eval.py: -------------------------------------------------------------------------------- 1 | from eval import get_relevant_ngrams, evaluate_ngrams, similar_words 2 | 3 | 4 | def test_get_single_relevant_ngram(): 5 | ngrams = get_relevant_ngrams("a b c d e", ["c"]) 6 | assert ngrams == [["a", "b", "c", "d", "e"]] 7 | 8 | 9 | def test_get_two_relevant_ngrams(): 10 | ngrams = get_relevant_ngrams("a b c d e f", ["c", "d"]) 11 | assert ngrams == [["a", "b", "c", "d", "e"], ["b", "c", "d", "e", "f"]] 12 | 13 | 14 | def test_evaluate_one_ngram(): 15 | eval_result = evaluate_ngrams([["a", "b", "c", "d", "e"]], ["c", "d"], lambda _: True) 16 | assert eval_result.tp == 1 17 | assert eval_result.fp == 1 18 | assert eval_result.tn == 0 19 | assert eval_result.fn == 0 20 | 21 | 22 | def test_evaluate_ngrams(): 23 | eval_result = evaluate_ngrams([["a", "b", "c", "d", "e"], ["b", "c", "d", "e", "f"]], ["c", "d", "f"], 24 | lambda ng: not(ng[0] == "a" and ng[2] == "c")) 25 | assert eval_result.tp == 4 26 | assert eval_result.fp == 1 27 | assert eval_result.tn == 1 28 | assert eval_result.fn == 0 29 | 30 | 31 | def test_similar_words(): 32 | selected_words = similar_words("dein", ["sein", "dumme", "dein", "deinen", "dienend"]) 33 | assert selected_words == ["sein", "dein", "deinen"] 34 | 35 | 36 | def test_similar_short_words(): 37 | selected_words = similar_words("da", ["da", "dann", "der", "aus"]) 38 | assert selected_words == ["da", "dann", "der"] 39 | --------------------------------------------------------------------------------