├── .gitignore ├── README.md ├── data ├── KB13 │ ├── test │ │ └── data.txt │ ├── train │ │ └── data.txt │ └── val │ │ └── data.txt ├── NL-RX-Synth │ ├── test │ │ ├── data.txt │ │ ├── data_no_dual.txt │ │ ├── new_data.txt │ │ ├── vocab.source │ │ └── vocab.target │ ├── train │ │ ├── data.txt │ │ ├── data_no_dual.txt │ │ ├── new_data.txt │ │ ├── vocab.source │ │ └── vocab.target │ └── val │ │ ├── data.txt │ │ ├── data_no_dual.txt │ │ ├── new_data.txt │ │ ├── vocab.source │ │ └── vocab.target └── NL-RX-Turk │ ├── test │ ├── data.txt │ ├── data_from_all_synth.txt │ ├── data_kb13.txt │ ├── new_data.txt │ ├── vocab.source │ └── vocab.target │ ├── train │ ├── data.txt │ ├── new_data.txt │ ├── vocab.source │ └── vocab.target │ └── val │ ├── data.txt │ ├── new_data.txt │ ├── vocab.source │ └── vocab.target ├── pair_data ├── data_pairs_test(4_depth).txt ├── data_pairs_test(5_depth).txt ├── test_data.txt └── train_data.txt ├── random_generat.jar ├── readme.md ├── regexDFAEquals.py ├── regex_dfa_equals.jar ├── regex_equal_model ├── compare_regex.ipynb ├── compare_regex_model.pth ├── compare_regex_model_share.pth ├── compare_regex_test.ipynb └── compare_vocab.txt ├── requirements.txt ├── seq2seq ├── __init__.py ├── __init__.pyc ├── dataset │ ├── __init__.py │ ├── __init__.pyc │ ├── fields.py │ └── fields.pyc ├── evaluator │ ├── __init__.py │ ├── __init__.pyc │ ├── evaluator.py │ ├── evaluator.pyc │ ├── predictor.py │ └── predictor.pyc ├── loss │ ├── __init__.py │ ├── __init__.pyc │ ├── loss.py │ └── loss.pyc ├── models │ ├── DecoderRNN.py │ ├── DecoderRNN.pyc │ ├── DecoderRNN_ex.py │ ├── EncoderRNN.py │ ├── EncoderRNN.pyc │ ├── TopKDecoder.py │ ├── TopKDecoder.pyc │ ├── __init__.py │ ├── __init__.pyc │ ├── attention.py │ ├── attention.pyc │ ├── baseRNN.py │ ├── baseRNN.pyc │ ├── seq2seq.py │ └── seq2seq.pyc ├── optim │ ├── __init__.py │ ├── __init__.pyc │ ├── optim.py │ └── optim.pyc ├── trainer │ ├── __init__.py │ ├── __init__.pyc │ ├── evaluate.py │ ├── 
self_critical_trainer.py │ ├── self_critical_trainer.pyc │ ├── supervised_trainer.py │ └── supervised_trainer.pyc └── util │ ├── __init__.py │ ├── __init__.pyc │ ├── checkpoint.py │ └── checkpoint.pyc ├── setup.py ├── softregex-eval.py └── softregex-train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | *.pyc 3 | **/.ipynb_checkpoints/ 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SoftRegex: Generating Regex from Natural Language Descriptions using Softened Regex Equivalence 2 | 3 | This is an implementation of the paper [SoftRegex: Generating Regex from Natural Language Descriptions using Softened Regex Equivalence](https://www.aclweb.org/anthology/D19-1677/) [EMNLP 2019] 4 | 5 | -------------- 6 | 7 | ## Summary 8 | 9 | We continue the study of generating semantically correct regular expressions from natural language descriptions (NL). The current state-of-the-art model SemRegex produces regular expressions from NLs by rewarding the reinforced learning based on the semantic (rather than syntactic) equivalence between two regular expressions. Since the regular expression equivalence problem is PSPACE-complete, we introduce the EQ_Reg model for computing the similarity of two regular expressions using deep neural networks. Our EQ_Reg model essentially softens the equivalence of two regular expressions when used as a reward function. We then propose a new regex generation model, SoftRegex, using the EQ_Reg model, and empirically demonstrate that SoftRegex substantially reduces the training time (by a factor of at least 3.6) and produces state-of-the-art results on three benchmark datasets. 
10 | 11 | -------------- 12 | 13 | ### Install from source 14 | pip install -r requirements.txt 15 | python setup.py install 16 | 17 | ### train 18 | dataname = kb13, NL-RX-Synth, NL-RX-Turk 19 | python softregex-train.py (dataname) 20 | python softregex-eval.py (dataname) 21 | -------------------------------------------------------------------------------- /data/KB13/test/data.txt: -------------------------------------------------------------------------------- 1 | lines containing words that start with . . * \ b [ ] * \ b . * 2 | lines that have at least 2 words that start with . ( . * \ b [ ] * \ b . * ) { 2 } 3 | lines that start with . . * 4 | lines containing , or . * ( | ) . * 5 | lines that contain words ending in . * \ b [ ] * \ b . * 6 | lines that contain at least 5 words . ( . * \ b [ ] [ ] * \ b . * ) { 5 , } 7 | lines using 3 instances of . * ( . * . * ) { 3 } . * 8 | lines that show in the beginning of the word and at the end of the word . . * ( ( \ b [ ] * \ b ) & ( \ b [ ] * \ b ) ) . * 9 | lines that have 2 vowels ( . * [ ] . * ) { 2 } 10 | lines that have the word . . * \ b \ b . * 11 | lines containing the text . * . * 12 | lines that end in a digit . * [ ] 13 | lines starting with followed by a word with . . * ( ( \ b [ ] + \ b ) & ( . * . * ) ) . * 14 | lines using 1 number and 2 letters . * ( . * [ ] . * ) & ( . * [ ] . * ) { 2 } . * 15 | lines that show and . . * ( . * . * ) & ( . * . * ) . * 16 | lines that contain at least 4 words . ( . * \ b [ ] + \ b . * ) { 4 , } 17 | lines which end with . * 18 | lines using before . * . * . * 19 | lines that do not contain the letter ~ ( . * . * ) 20 | lines ending in . * 21 | lines that have before and after . ( . * . * . * ) & ( . * . * . * ) 22 | lines containing and containing too ( . * . * ) & ( . * . * ) 23 | lines that contain the word . . * \ b \ b . * 24 | lines containing a 5 letter word beginning with . * \ b [ ] { 4 } \ b . * 25 | lines using more than 3 characters . * . { 3 , } . 
* 26 | lines with and not . ( . * . * ) & ( ~ ( . * . * ) ) 27 | lines that utilize the number . * . * 28 | lines that contain at least 1 vowel and at least 2 numbers ( ( . * [ ] . * ) & ( . * [ ] . * ) { 2 } ) 29 | lines containing and containing too ( . * . * ) & ( . * . * ) 30 | lines ending with . * 31 | lines containing , but not ( . * . * ) & ( ~ ( . * . * ) ) 32 | lines that are 20 characters or less . { 0 , 2 0 } 33 | lines with an that comes after a . . * . * . * 34 | lines using after or . . * ( | ) . * . * 35 | lines that contain words with . . * ( ( \ b [ ] + \ b ) & ( . * . * ) ) . * 36 | lines containing words which begin with and end with . * \ b [ ] * \ b . * 37 | lines that include 3 letters ( . * [ ] . * ) { 3 } 38 | lines having words ending with . . * \ b [ ] * \ b . * 39 | lines that contain 5 letter words . . * \ b [ ] { 5 } \ b . * 40 | lines that have words ending in . . * \ b [ ] * \ b . * 41 | lines which contain 3 or more vowels . ( . * [ ] . * ) { 3 , } 42 | lines utilizing words starting with . . * \ b [ ] * \ b . * 43 | lines that contain a word containing 5 or more letters . . * \ b [ ] { 5 , } \ b . * 44 | lines using the word next to a number . ( . * \ b \ b . * ) & ( . * [ ] . * ) 45 | lines that contain words ending in . * \ b [ ] * \ b . * 46 | lines that contain words using in them . . * \ b [ ] * [ ] * \ b . * 47 | lines that have 6 words . ( . * \ b [ ] + \ b . * ) { 6 } 48 | lines ending with . * 49 | lines that use words ending with . . * \ b [ ] * \ b . * 50 | lines that start with a and end with an . . * 51 | lines using . * . * 52 | lines using words that begin with the letter . . * \ b [ ] * \ b . * 53 | lines that have . . * . * 54 | lines using the word followed by . . * \ b \ b . * . * 55 | lines that utilize the number . . * . * 56 | lines that contain a word in all uppercase . . * \ b [ ] + \ b . * 57 | lines with before . . * . * . * 58 | lines where there is at least 1 word in which follows . 
* ( ( \ b [ ] [ ] * \ b ) & ( . * . * . * ) ) . * 59 | lines which contain . * . * 60 | lines that have no instances of but at least 1 instance of . ( ~ ( . * . * ) ) & ( ( . * . * ) { 1 , } ) 61 | lines that show the letter and number . ( . * . * ) & ( . * . * ) 62 | lines that have . * . * 63 | lines with . * . * 64 | lines that contain at least 2 words starting with in them . ( . * \ b [ ] * \ b . * ) { 2 } 65 | lines that contain and contain . ( . * . * ) & ( . * . * ) 66 | lines which contain . . * . * 67 | lines using the word . . * \ b \ b . * 68 | lines that use words starting with . . * \ b [ ] * \ b . * 69 | lines using a word starting with a vowel and ending with . * \ b [ ] [ ] * \ b . * 70 | lines which have at least 7 numbers . . * ( . * [ ] . * ) { 7 } . * 71 | lines that start with . * 72 | lines that begin with a number [ ] . * 73 | lines containing the word . . * \ b \ b . * 74 | lines that have followed by the word . * . * \ b \ b . * 75 | lines containing the word . . * \ b \ b . * 76 | lines containing or . * ( | ) . * 77 | lines utilizing the word . . * . * \ b \ b . * 78 | lines that have words with . . * \ b [ ] * [ ] * \ b . * 79 | lines starting with . * 80 | lines using a 3 letter sequence starting with . * [ ] { 2 } . * 81 | lines that have 2 words using 4 letters ( . * \ b [ ] { 4 } \ b . * ) { 2 } 82 | lines having words ending with . . * \ b [ ] * \ b . * 83 | lines that use followed by words starting with . * . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 84 | lines using . * . * 85 | lines that end with . * 86 | lines using at least 3 ( . * . * ) { 3 , } 87 | lines that start with a vowel and ends in an . [ ] . * 88 | lines that have a capital , but not a . ( . * . * ) & ( ~ ( . * . * ) ) 89 | lines that contain the number at least twice . ( . * . * ) { 2 } 90 | lines that have 3 words and 2 numbers . ( . * \ b [ ] + \ b . * ) { 3 } & ( . * [ ] . * ) { 2 } 91 | lines using 4 instances of . * ( . * . * ) { 4 } . 
* 92 | lines with 3 numbers and the word . ( . * [ ] . * ) { 3 } & ( . * \ b \ b . * ) . * 93 | 3 letter lines starting with [ ] { 2 } 94 | lines containing a word using the letters . . * \ b [ ] * [ ] * \ b . * 95 | lines that utilize words starting with . . * \ b [ ] * \ b . * 96 | lines using at least 5 ( . * . * ) { 5 , } 97 | lines that have words ending with . . * \ b [ ] * \ b . * 98 | lines which start with . * 99 | lines which contain a word starting with the letter and a word starting with the letter . ( . * \ b [ ] * \ b . * ) & ( . * \ b [ ] * \ b . * ) 100 | lines that contain 2 words that have 2 letters . ( . * \ b [ ] { 2 } \ b . * ) { 2 } 101 | lines that begin with the word . \ b . * 102 | lines that contain the word . . * \ b \ b . * 103 | lines that feature or before words that start with capital letters . . * ( | ) . * \ b [ ] [ ] * \ b . * 104 | lines containing or before or . * ( | ) . * ( | ) . * 105 | lines that begin with a number and end with or . [ ] . * ( | ) 106 | lines that contain words starting with . . * \ b [ ] * \ b . * 107 | lines that end in . * 108 | lines which begin with the letter . . * 109 | lines containing or before or . * ( | ) . * ( | ) . * 110 | lines that contain and contain . ( . * . * ) & ( . * . * ) 111 | lines containing words that end with . . * \ b [ ] * \ b . * 112 | lines with words the word . . * \ b \ b . * 113 | lines using words starting with . . * \ b [ ] * \ b . * 114 | lines containing both and ( . * . * ) & ( . * . * ) 115 | lines that use words starting with . . * \ b [ ] * \ b . * 116 | lines that contain the word but not . ( . * \ b \ b . * ) & ( ~ ( . * . * ) ) 117 | lines having after and before . . * ( ( . * . * . * ) & ( . * . * . * ) ) . * 118 | lines that contain words using the letters . * \ b [ ] * [ ] * \ b . * 119 | lines utilizing the number . . * . * 120 | lines containing and containing the word . * . * & ( . * \ b \ b . * ) 121 | lines that are composed of 4 or more words . ( . 
* \ b [ ] + \ b . * ) { 4 , } 122 | lines that contain words starting with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 123 | lines that contain 5 or more letters . * ( . * [ ] . * ) { 5 , } . * 124 | lines using a word which contains at least 5 letters . . * \ b [ ] { 5 , } \ b . * 125 | lines that begin with the word . \ b \ b . * 126 | lines containing and also containing ( . * . * ) & ( . * . * ) 127 | lines that use the word . * \ b \ b . * 128 | lines utilizing the number . . * . * 129 | lines that contain 3 words . ( . * \ b [ ] + \ b . * ) { 3 } 130 | lines using words starting with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 131 | lines that include the string . * . * 132 | lines that contain the letter and the number . ( . * . * ) & ( . * . * ) 133 | lines containing only 4 words . ( ( ~ [ ] ) * \ b [ ] + \ b ( ~ [ ] ) * ) { 4 } 134 | lines that start with and end with . * 135 | lines using 5 ( . * . * ) { 5 } 136 | lines containing , or . * ( | ) . * 137 | lines using or . * ( | ) . * 138 | lines with instances of between and . * . * . * . * | . * . * . * . * 139 | lines that contain words and 4 numbers . ( . * \ b [ ] + \ b . * ) & ( . * [ ] . * ) { 4 } 140 | lines that contain but do not contain . ( . * . * ) & ( ~ ( . * . * ) ) 141 | lines that contain at least 3 5 letter words ( . * \ b [ ] { 5 } \ b . * ) { 3 } 142 | lines that contain 2 numbers and 3 words and contain the letter . . * ( . * [ ] . * ) { 2 } & ( . * \ b [ ] + \ b . * ) { 3 } & ( . * . * ) . * 143 | lines using a capital letter followed by a number . * [ ] . * [ ] . * 144 | lines having the letter . . * . * 145 | lines with words that end in . . * \ b [ ] * \ b . * 146 | lines that contain words that end in that do not begin with . * ( ( \ b [ ] + \ b ) & ( . * ) & ( ~ ( . * ) ) ) . * 147 | lines that contain with immediately after it . . * . * 148 | lines that begin with a number [ ] . * 149 | lines using words ending with . . * \ b [ ] * \ b . * 150 | lines containing a letter . * [ ] . 
* 151 | lines that have words ending in or . . * \ b [ ] * ( | ) \ b . * 152 | lines that contain words that have the letter occuring after the letter . . * \ b [ ] * [ ] * [ ] * \ b . * 153 | lines that have a in them . * . * 154 | lines that have within them words ending in . . * \ b [ ] * \ b . * 155 | lines containing words that end with . * ( ( \ b . * \ b ) & ( [ ] + ) ) . * . * 156 | lines having words ending with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 157 | lines that start with . * 158 | lines that contain or before . . * ( | ) . * . * 159 | lines that contain a word using the letters . * ( ( \ b [ ] + \ b ) & ( . * . * ) ) . * 160 | lines that have the number . . * . * 161 | lines that have a word ending with the letters . . * \ b [ ] * \ b . * 162 | lines with no vowels ~ ( . * [ ] . * ) 163 | lines that have words ending with . . * \ b [ ] * \ b . * 164 | lines that have the letter at the end of a word . . * \ b [ ] * \ b . * 165 | lines containing words that begin with and end with . * \ b [ ] * \ b . * 166 | lines which do not contain the letter . ~ ( . * . * ) 167 | lines that have any instance of . . * . * 168 | lines using or . * ( | ) . * 169 | lines that have at least 3 words beginning with a vowel . ( . * \ b [ ] [ ] * \ b . * ) { 3 } 170 | lines that use words ending with . . * \ b [ ] * \ b . * 171 | lines that ends with letter . * 172 | lines that have words with . . * \ b [ ] * [ ] * \ b . * 173 | lines containing the word . . * \ b \ b . * 174 | lines that use the word followed by words starting with . * \ b \ b . * \ b [ ] * \ b . * 175 | lines that contain words using in them . . * ( . * \ b [ ] * [ ] * \ b . * ) . * 176 | lines that contain digits . . * [ ] . * 177 | lines that have . * . * 178 | lines which contain a word using 2 or more letters . * \ b [ ] { 2 , } \ b . * 179 | lines that utilize words starting with followed by the word . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * \ b \ b . 
* 180 | lines that have 5 words that all end with the letter . ( . * ( ( \ b . * \ b ) & ( [ ] + ) ) . * ) { 5 } 181 | lines that have all of its letters capitalized . ~ ( . * [ ] . * ) 182 | lines with and ( . * . * ) & ( . * . * ) 183 | lines containing only a letter [ ] 184 | lines containing both letters and numbers , but no capitals . ( . * [ ] . * ) & ( . * [ ] . * ) & ( ~ ( . * [ ] . * ) ) 185 | lines that contain 4 or more ( . * . * ) { 4 } 186 | lines using the vowel combination . . * . * 187 | lines that use words ending in . * ( ( . * ) & ( \ b [ ] [ ] * \ b ) ) . * 188 | lines having words ending with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 189 | lines starting with . * 190 | lines which contain and . * . * . * | . * . * . * 191 | lines that contain the letters . . * . * 192 | lines having words starting with . . * \ b [ ] * \ b . * 193 | lines using , , or . * ( | | ) . * 194 | lines containing ( . * . * ) . * 195 | lines that have at least 2 words of 3 or more letters ( . * \ b [ ] { 3 , } \ b . * ) { 2 , } 196 | lines containing 3 or more ( . * . * ) { 3 , } 197 | lines utilizing . * . * 198 | lines that have but not ( . * . * ) & ( ~ ( . * . * ) ) 199 | lines using at least 6 characters . * . { 6 } . * 200 | lines that contain at least 2 vowels in a word . . * ( ( \ b [ ] + \ b ) & ( . * [ ] . * ) { 2 } ) . * 201 | lines which do not contain a number . ~ ( . * [ ] . * ) 202 | lines using the word . * \ b \ b . * 203 | lines containing words that start with and end with ( . * ) ? 204 | lines that use words that are only 4 letters long . . * \ b [ ] { 4 } \ b . * 205 | lines that contain 5 or more words . . * ( . * \ b [ ] [ ] * \ b . * ) { 5 } . * 206 | lines that begin with the word . \ b \ b . * 207 | -------------------------------------------------------------------------------- /data/KB13/val/data.txt: -------------------------------------------------------------------------------- 1 | lines containing words that start with . . 
* \ b [ ] * \ b . * 2 | lines that have at least 2 words that start with . ( . * \ b [ ] * \ b . * ) { 2 } 3 | lines that start with . . * 4 | lines containing , or . * ( | ) . * 5 | lines that contain words ending in . * \ b [ ] * \ b . * 6 | lines that contain at least 5 words . ( . * \ b [ ] [ ] * \ b . * ) { 5 , } 7 | lines using 3 instances of . * ( . * . * ) { 3 } . * 8 | lines that show in the beginning of the word and at the end of the word . . * ( ( \ b [ ] * \ b ) & ( \ b [ ] * \ b ) ) . * 9 | lines that have 2 vowels ( . * [ ] . * ) { 2 } 10 | lines that have the word . . * \ b \ b . * 11 | lines containing the text . * . * 12 | lines that end in a digit . * [ ] 13 | lines starting with followed by a word with . . * ( ( \ b [ ] + \ b ) & ( . * . * ) ) . * 14 | lines using 1 number and 2 letters . * ( . * [ ] . * ) & ( . * [ ] . * ) { 2 } . * 15 | lines that show and . . * ( . * . * ) & ( . * . * ) . * 16 | lines that contain at least 4 words . ( . * \ b [ ] + \ b . * ) { 4 , } 17 | lines which end with . * 18 | lines using before . * . * . * 19 | lines that do not contain the letter ~ ( . * . * ) 20 | lines ending in . * 21 | lines that have before and after . ( . * . * . * ) & ( . * . * . * ) 22 | lines containing and containing too ( . * . * ) & ( . * . * ) 23 | lines that contain the word . . * \ b \ b . * 24 | lines containing a 5 letter word beginning with . * \ b [ ] { 4 } \ b . * 25 | lines using more than 3 characters . * . { 3 , } . * 26 | lines with and not . ( . * . * ) & ( ~ ( . * . * ) ) 27 | lines that utilize the number . * . * 28 | lines that contain at least 1 vowel and at least 2 numbers ( ( . * [ ] . * ) & ( . * [ ] . * ) { 2 } ) 29 | lines containing and containing too ( . * . * ) & ( . * . * ) 30 | lines ending with . * 31 | lines containing , but not ( . * . * ) & ( ~ ( . * . * ) ) 32 | lines that are 20 characters or less . { 0 , 2 0 } 33 | lines with an that comes after a . . * . * . * 34 | lines using after or . . * ( | ) . * . 
* 35 | lines that contain words with . . * ( ( \ b [ ] + \ b ) & ( . * . * ) ) . * 36 | lines containing words which begin with and end with . * \ b [ ] * \ b . * 37 | lines that include 3 letters ( . * [ ] . * ) { 3 } 38 | lines having words ending with . . * \ b [ ] * \ b . * 39 | lines that contain 5 letter words . . * \ b [ ] { 5 } \ b . * 40 | lines that have words ending in . . * \ b [ ] * \ b . * 41 | lines which contain 3 or more vowels . ( . * [ ] . * ) { 3 , } 42 | lines utilizing words starting with . . * \ b [ ] * \ b . * 43 | lines that contain a word containing 5 or more letters . . * \ b [ ] { 5 , } \ b . * 44 | lines using the word next to a number . ( . * \ b \ b . * ) & ( . * [ ] . * ) 45 | lines that contain words ending in . * \ b [ ] * \ b . * 46 | lines that contain words using in them . . * \ b [ ] * [ ] * \ b . * 47 | lines that have 6 words . ( . * \ b [ ] + \ b . * ) { 6 } 48 | lines ending with . * 49 | lines that use words ending with . . * \ b [ ] * \ b . * 50 | lines that start with a and end with an . . * 51 | lines using . * . * 52 | lines using words that begin with the letter . . * \ b [ ] * \ b . * 53 | lines that have . . * . * 54 | lines using the word followed by . . * \ b \ b . * . * 55 | lines that utilize the number . . * . * 56 | lines that contain a word in all uppercase . . * \ b [ ] + \ b . * 57 | lines with before . . * . * . * 58 | lines where there is at least 1 word in which follows . * ( ( \ b [ ] [ ] * \ b ) & ( . * . * . * ) ) . * 59 | lines which contain . * . * 60 | lines that have no instances of but at least 1 instance of . ( ~ ( . * . * ) ) & ( ( . * . * ) { 1 , } ) 61 | lines that show the letter and number . ( . * . * ) & ( . * . * ) 62 | lines that have . * . * 63 | lines with . * . * 64 | lines that contain at least 2 words starting with in them . ( . * \ b [ ] * \ b . * ) { 2 } 65 | lines that contain and contain . ( . * . * ) & ( . * . * ) 66 | lines which contain . . * . * 67 | lines using the word . . 
* \ b \ b . * 68 | lines that use words starting with . . * \ b [ ] * \ b . * 69 | lines using a word starting with a vowel and ending with . * \ b [ ] [ ] * \ b . * 70 | lines which have at least 7 numbers . . * ( . * [ ] . * ) { 7 } . * 71 | lines that start with . * 72 | lines that begin with a number [ ] . * 73 | lines containing the word . . * \ b \ b . * 74 | lines that have followed by the word . * . * \ b \ b . * 75 | lines containing the word . . * \ b \ b . * 76 | lines containing or . * ( | ) . * 77 | lines utilizing the word . . * . * \ b \ b . * 78 | lines that have words with . . * \ b [ ] * [ ] * \ b . * 79 | lines starting with . * 80 | lines using a 3 letter sequence starting with . * [ ] { 2 } . * 81 | lines that have 2 words using 4 letters ( . * \ b [ ] { 4 } \ b . * ) { 2 } 82 | lines having words ending with . . * \ b [ ] * \ b . * 83 | lines that use followed by words starting with . * . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 84 | lines using . * . * 85 | lines that end with . * 86 | lines using at least 3 ( . * . * ) { 3 , } 87 | lines that start with a vowel and ends in an . [ ] . * 88 | lines that have a capital , but not a . ( . * . * ) & ( ~ ( . * . * ) ) 89 | lines that contain the number at least twice . ( . * . * ) { 2 } 90 | lines that have 3 words and 2 numbers . ( . * \ b [ ] + \ b . * ) { 3 } & ( . * [ ] . * ) { 2 } 91 | lines using 4 instances of . * ( . * . * ) { 4 } . * 92 | lines with 3 numbers and the word . ( . * [ ] . * ) { 3 } & ( . * \ b \ b . * ) . * 93 | 3 letter lines starting with [ ] { 2 } 94 | lines containing a word using the letters . . * \ b [ ] * [ ] * \ b . * 95 | lines that utilize words starting with . . * \ b [ ] * \ b . * 96 | lines using at least 5 ( . * . * ) { 5 , } 97 | lines that have words ending with . . * \ b [ ] * \ b . * 98 | lines which start with . * 99 | lines which contain a word starting with the letter and a word starting with the letter . ( . * \ b [ ] * \ b . * ) & ( . * \ b [ ] * \ b . 
* ) 100 | lines that contain 2 words that have 2 letters . ( . * \ b [ ] { 2 } \ b . * ) { 2 } 101 | lines that begin with the word . \ b . * 102 | lines that contain the word . . * \ b \ b . * 103 | lines that feature or before words that start with capital letters . . * ( | ) . * \ b [ ] [ ] * \ b . * 104 | lines containing or before or . * ( | ) . * ( | ) . * 105 | lines that begin with a number and end with or . [ ] . * ( | ) 106 | lines that contain words starting with . . * \ b [ ] * \ b . * 107 | lines that end in . * 108 | lines which begin with the letter . . * 109 | lines containing or before or . * ( | ) . * ( | ) . * 110 | lines that contain and contain . ( . * . * ) & ( . * . * ) 111 | lines containing words that end with . . * \ b [ ] * \ b . * 112 | lines with words the word . . * \ b \ b . * 113 | lines using words starting with . . * \ b [ ] * \ b . * 114 | lines containing both and ( . * . * ) & ( . * . * ) 115 | lines that use words starting with . . * \ b [ ] * \ b . * 116 | lines that contain the word but not . ( . * \ b \ b . * ) & ( ~ ( . * . * ) ) 117 | lines having after and before . . * ( ( . * . * . * ) & ( . * . * . * ) ) . * 118 | lines that contain words using the letters . * \ b [ ] * [ ] * \ b . * 119 | lines utilizing the number . . * . * 120 | lines containing and containing the word . * . * & ( . * \ b \ b . * ) 121 | lines that are composed of 4 or more words . ( . * \ b [ ] + \ b . * ) { 4 , } 122 | lines that contain words starting with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 123 | lines that contain 5 or more letters . * ( . * [ ] . * ) { 5 , } . * 124 | lines using a word which contains at least 5 letters . . * \ b [ ] { 5 , } \ b . * 125 | lines that begin with the word . \ b \ b . * 126 | lines containing and also containing ( . * . * ) & ( . * . * ) 127 | lines that use the word . * \ b \ b . * 128 | lines utilizing the number . . * . * 129 | lines that contain 3 words . ( . * \ b [ ] + \ b . 
* ) { 3 } 130 | lines using words starting with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 131 | lines that include the string . * . * 132 | lines that contain the letter and the number . ( . * . * ) & ( . * . * ) 133 | lines containing only 4 words . ( ( ~ [ ] ) * \ b [ ] + \ b ( ~ [ ] ) * ) { 4 } 134 | lines that start with and end with . * 135 | lines using 5 ( . * . * ) { 5 } 136 | lines containing , or . * ( | ) . * 137 | lines using or . * ( | ) . * 138 | lines with instances of between and . * . * . * . * | . * . * . * . * 139 | lines that contain words and 4 numbers . ( . * \ b [ ] + \ b . * ) & ( . * [ ] . * ) { 4 } 140 | lines that contain but do not contain . ( . * . * ) & ( ~ ( . * . * ) ) 141 | lines that contain at least 3 5 letter words ( . * \ b [ ] { 5 } \ b . * ) { 3 } 142 | lines that contain 2 numbers and 3 words and contain the letter . . * ( . * [ ] . * ) { 2 } & ( . * \ b [ ] + \ b . * ) { 3 } & ( . * . * ) . * 143 | lines using a capital letter followed by a number . * [ ] . * [ ] . * 144 | lines having the letter . . * . * 145 | lines with words that end in . . * \ b [ ] * \ b . * 146 | lines that contain words that end in that do not begin with . * ( ( \ b [ ] + \ b ) & ( . * ) & ( ~ ( . * ) ) ) . * 147 | lines that contain with immediately after it . . * . * 148 | lines that begin with a number [ ] . * 149 | lines using words ending with . . * \ b [ ] * \ b . * 150 | lines containing a letter . * [ ] . * 151 | lines that have words ending in or . . * \ b [ ] * ( | ) \ b . * 152 | lines that contain words that have the letter occuring after the letter . . * \ b [ ] * [ ] * [ ] * \ b . * 153 | lines that have a in them . * . * 154 | lines that have within them words ending in . . * \ b [ ] * \ b . * 155 | lines containing words that end with . * ( ( \ b . * \ b ) & ( [ ] + ) ) . * . * 156 | lines having words ending with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 157 | lines that start with . * 158 | lines that contain or before . . * ( | ) . 
* . * 159 | lines that contain a word using the letters . * ( ( \ b [ ] + \ b ) & ( . * . * ) ) . * 160 | lines that have the number . . * . * 161 | lines that have a word ending with the letters . . * \ b [ ] * \ b . * 162 | lines with no vowels ~ ( . * [ ] . * ) 163 | lines that have words ending with . . * \ b [ ] * \ b . * 164 | lines that have the letter at the end of a word . . * \ b [ ] * \ b . * 165 | lines containing words that begin with and end with . * \ b [ ] * \ b . * 166 | lines which do not contain the letter . ~ ( . * . * ) 167 | lines that have any instance of . . * . * 168 | lines using or . * ( | ) . * 169 | lines that have at least 3 words beginning with a vowel . ( . * \ b [ ] [ ] * \ b . * ) { 3 } 170 | lines that use words ending with . . * \ b [ ] * \ b . * 171 | lines that ends with letter . * 172 | lines that have words with . . * \ b [ ] * [ ] * \ b . * 173 | lines containing the word . . * \ b \ b . * 174 | lines that use the word followed by words starting with . * \ b \ b . * \ b [ ] * \ b . * 175 | lines that contain words using in them . . * ( . * \ b [ ] * [ ] * \ b . * ) . * 176 | lines that contain digits . . * [ ] . * 177 | lines that have . * . * 178 | lines which contain a word using 2 or more letters . * \ b [ ] { 2 , } \ b . * 179 | lines that utilize words starting with followed by the word . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * \ b \ b . * 180 | lines that have 5 words that all end with the letter . ( . * ( ( \ b . * \ b ) & ( [ ] + ) ) . * ) { 5 } 181 | lines that have all of its letters capitalized . ~ ( . * [ ] . * ) 182 | lines with and ( . * . * ) & ( . * . * ) 183 | lines containing only a letter [ ] 184 | lines containing both letters and numbers , but no capitals . ( . * [ ] . * ) & ( . * [ ] . * ) & ( ~ ( . * [ ] . * ) ) 185 | lines that contain 4 or more ( . * . * ) { 4 } 186 | lines using the vowel combination . . * . * 187 | lines that use words ending in . * ( ( . * ) & ( \ b [ ] [ ] * \ b ) ) . 
* 188 | lines having words ending with . . * ( ( \ b [ ] + \ b ) & ( . * ) ) . * 189 | lines starting with . * 190 | lines which contain and . * . * . * | . * . * . * 191 | lines that contain the letters . . * . * 192 | lines having words starting with . . * \ b [ ] * \ b . * 193 | lines using , , or . * ( | | ) . * 194 | lines containing ( . * . * ) . * 195 | lines that have at least 2 words of 3 or more letters ( . * \ b [ ] { 3 , } \ b . * ) { 2 , } 196 | lines containing 3 or more ( . * . * ) { 3 , } 197 | lines utilizing . * . * 198 | lines that have but not ( . * . * ) & ( ~ ( . * . * ) ) 199 | lines using at least 6 characters . * . { 6 } . * 200 | lines that contain at least 2 vowels in a word . . * ( ( \ b [ ] + \ b ) & ( . * [ ] . * ) { 2 } ) . * 201 | lines which do not contain a number . ~ ( . * [ ] . * ) 202 | lines using the word . * \ b \ b . * 203 | lines containing words that start with and end with ( . * ) ? 204 | lines that use words that are only 4 letters long . . * \ b [ ] { 4 } \ b . * 205 | lines that contain 5 or more words . . * ( . * \ b [ ] [ ] * \ b . * ) { 5 } . * 206 | lines that begin with the word . \ b \ b . 
* 207 | -------------------------------------------------------------------------------- /data/NL-RX-Synth/test/vocab.source: -------------------------------------------------------------------------------- 1 | that 2 | ending 3 | and 4 | 7 5 | either 6 | the 7 | having 8 | lines 9 | before 10 | a 11 | or 12 | zero 13 | by 14 | number 15 | 16 | 6 17 | 2 18 | 19 | 3 20 | , 21 | words 22 | vowel 23 | only 24 | don't 25 | contain 26 | containing 27 | starting 28 | 4 29 | letter 30 | at 31 | more 32 | have 33 | character 34 | 35 | string 36 | times 37 | not 38 | lower-case 39 | followed 40 | once 41 | least 42 | 43 | capital 44 | with 45 | 5 -------------------------------------------------------------------------------- /data/NL-RX-Synth/test/vocab.target: -------------------------------------------------------------------------------- 1 | { 2 | ] 3 | ( 4 | 7 5 | * 6 | & 7 | | 8 | 9 | 10 | ~ 11 | 6 12 | 2 13 | 14 | 15 | 3 16 | , 17 | [ 18 | \ 19 | + 20 | 21 | 4 22 | 23 | 24 | } 25 | ) 26 | 27 | b 28 | . 
29 | 30 | 5 -------------------------------------------------------------------------------- /data/NL-RX-Synth/train/vocab.source: -------------------------------------------------------------------------------- 1 | that 2 | ending 3 | and 4 | 7 5 | either 6 | the 7 | having 8 | lines 9 | a 10 | before 11 | or 12 | by 13 | zero 14 | number 15 | 16 | 6 17 | 2 18 | 19 | 3 20 | , 21 | words 22 | vowel 23 | only 24 | don't 25 | contain 26 | containing 27 | starting 28 | 4 29 | letter 30 | at 31 | more 32 | have 33 | character 34 | 35 | string 36 | not 37 | followed 38 | lower-case 39 | times 40 | once 41 | least 42 | 43 | capital 44 | with 45 | 5 -------------------------------------------------------------------------------- /data/NL-RX-Synth/train/vocab.target: -------------------------------------------------------------------------------- 1 | { 2 | ] 3 | ( 4 | 7 5 | * 6 | & 7 | | 8 | 9 | ~ 10 | 11 | 6 12 | 2 13 | 14 | 15 | 3 16 | , 17 | [ 18 | \ 19 | + 20 | 21 | 4 22 | 23 | 24 | } 25 | ) 26 | 27 | b 28 | . 
29 | 30 | 5 -------------------------------------------------------------------------------- /data/NL-RX-Synth/val/vocab.source: -------------------------------------------------------------------------------- 1 | that 2 | ending 3 | and 4 | 7 5 | either 6 | having 7 | the 8 | lines 9 | before 10 | a 11 | or 12 | by 13 | zero 14 | number 15 | 16 | 6 17 | 2 18 | 19 | 3 20 | , 21 | words 22 | vowel 23 | don't 24 | only 25 | contain 26 | containing 27 | starting 28 | 4 29 | letter 30 | at 31 | more 32 | have 33 | character 34 | 35 | string 36 | times 37 | followed 38 | lower-case 39 | not 40 | once 41 | least 42 | 43 | capital 44 | with 45 | 5 -------------------------------------------------------------------------------- /data/NL-RX-Synth/val/vocab.target: -------------------------------------------------------------------------------- 1 | { 2 | ] 3 | ( 4 | 7 5 | * 6 | & 7 | | 8 | 9 | ~ 10 | 11 | 6 12 | 2 13 | 14 | 15 | 3 16 | , 17 | [ 18 | \ 19 | + 20 | 21 | 4 22 | 23 | 24 | } 25 | ) 26 | 27 | b 28 | . 29 | 30 | 5 -------------------------------------------------------------------------------- /data/NL-RX-Turk/test/vocab.source: -------------------------------------------------------------------------------- 1 | have6or 2 | should 3 | 4 | last 5 | ahead 6 | lettter 7 | "truck' 8 | uppercase 9 | some 10 | for 11 | do 12 | use 13 | cased 14 | comprise 15 | are 16 | choice 17 | sting 18 | letterfollowed 19 | occurence 20 | strings 21 | does 22 | follows 23 | iin 24 | than 25 | just 26 | character 27 | 6 28 | another 29 | & 30 | occurring 31 | upper 32 | following: 33 | it 34 | minimum 35 | works 36 | at 37 | upper-case 38 | an 39 | before 40 | . 
41 | alphabet 42 | precedes 43 | succession: 44 | repeated 45 | on 46 | conclude 47 | vowel 48 | 49 | included 50 | least5times 51 | amount 52 | more 53 | contains 54 | follow 55 | these 56 | containing 57 | nor 58 | may 59 | subsequent 60 | finishing 61 | preceded 62 | showing 63 | dog 64 | things 65 | aren't 66 | occurrences 67 | culminating 68 | appearance 69 | ends 70 | occasions 71 | equal 72 | ; 73 | times 74 | mimnimum 75 | before5or 76 | other 77 | including 78 | uppe 79 | truck' 80 | followed 81 | , 82 | containg 83 | maybe 84 | be 85 | + 86 | small 87 | something 88 | required 89 | least6instances 90 | 'dog 91 | includes 92 | multiple 93 | that 94 | listed 95 | then 96 | words 97 | : 98 | directly 99 | 1 100 | up 101 | preceeding 102 | instance 103 | contaning 104 | the 105 | vowels 106 | alongwith 107 | either 108 | small-case 109 | letters 110 | terminating 111 | ines 112 | least5occurrences 113 | least5instances 114 | 'truck'[ 115 | ltter 116 | your 117 | ttthem 118 | being 119 | characters 120 | need 121 | don't 122 | precede 123 | seen 124 | as 125 | anything 126 | of 127 | number 128 | but 129 | had 130 | can 131 | above 132 | -a 133 | first 134 | starting 135 | preceding 136 | 137 | wih 138 | i 139 | to 140 | 'string 141 | a 142 | doesn't 143 | over 144 | linesthe 145 | appears 146 | numbers 147 | u 148 | digits 149 | possessing 150 | "dog 151 | plus 152 | feature 153 | begining 154 | 7 155 | total 156 | charactere 157 | front 158 | items 159 | bigis 160 | twice 161 | contain6or 162 | only 163 | following 164 | non 165 | ''dog' 166 | begin 167 | lined 168 | 'vowel 169 | lettere 170 | using 171 | time 172 | 'truck 173 | present 174 | beginning 175 | characcter 176 | liens 177 | captial 178 | capitol 179 | combinations 180 | also 181 | isn't 182 | string 183 | 3 184 | capitals 185 | without 186 | is 187 | once 188 | used 189 | empty 190 | lower 191 | nonletter 192 | that's 193 | might 194 | least6times 195 | smal 196 | 'dog" 197 | occurs 198 | or 199 
| thrice 200 | sentence: 201 | letter 202 | wherein 203 | non-number 204 | word 205 | 0 206 | 207 | no 208 | ending 209 | where 210 | witha 211 | instances 212 | end 213 | capital 214 | capitla 215 | occur 216 | occurrr 217 | 4 218 | 3+ 219 | with5or 220 | caee 221 | well 222 | beforehand 223 | additional 224 | comes 225 | all 226 | contained 227 | by 228 | and 229 | have 230 | lies 231 | have5or 232 | lines\ 233 | occuring 234 | cannot 235 | zero 236 | lowercase 237 | 5 238 | must 239 | appearing 240 | capitalized 241 | contain 242 | occurrence 243 | 6or 244 | single 245 | were 246 | both 247 | lines 248 | 2 249 | not 250 | combination 251 | which 252 | greater 253 | o 254 | there 255 | acapital 256 | vowel: 257 | lower-case 258 | letter5times 259 | - 260 | line 261 | 5x 262 | consisting 263 | after 264 | lacking 265 | having 266 | coming 267 | with 268 | 'ring 269 | any 270 | has 271 | class 272 | occasion 273 | containging 274 | long 275 | least 276 | truck 277 | case 278 | hwith 279 | so 280 | prior 281 | them 282 | starts 283 | start 284 | between 285 | come 286 | code 287 | stating 288 | in 289 | include 290 | e 291 | less 292 | like 293 | lease 294 | certain 295 | kind 296 | 'dog': 297 | featuring 298 | succeeds 299 | ore 300 | numeral -------------------------------------------------------------------------------- /data/NL-RX-Turk/test/vocab.target: -------------------------------------------------------------------------------- 1 | ~ 2 | 3 | 7 4 | 5 | ) 6 | 7 | | 8 | } 9 | , 10 | ] 11 | 5 12 | + 13 | 14 | 2 15 | \ 16 | & 17 | 6 18 | 3 19 | [ 20 | 21 | * 22 | . 
23 | 24 | 25 | 26 | 4 27 | 28 | { 29 | ( 30 | b -------------------------------------------------------------------------------- /data/NL-RX-Turk/train/vocab.source: -------------------------------------------------------------------------------- 1 | lettter 2 | staring 3 | woth 4 | do 5 | use 6 | sting 7 | aq 8 | terminate 9 | th 10 | wit 11 | rate 12 | find 13 | just 14 | lower'case 15 | minimum 16 | succeeded 17 | second 18 | along 19 | accompanied 20 | contains 21 | lettered 22 | nor 23 | aren't 24 | ends 25 | equal 26 | bark 27 | berfe 28 | small 29 | lives 30 | consecutive 31 | appear 32 | you 33 | will 34 | number-containing 35 | 1 36 | contating 37 | concluding 38 | worded 39 | either 40 | started 41 | letters 42 | letter-containing 43 | followsed 44 | finally 45 | ltter 46 | lease6times 47 | characters 48 | as 49 | of 50 | added 51 | starting 52 | ro 53 | doesn't 54 | a 55 | over 56 | order: 57 | numbers 58 | plus 59 | feature 60 | terminal 61 | lower-cased 62 | neither 63 | 'dog'-named 64 | front 65 | consist 66 | twice 67 | occuring6or 68 | contain6or 69 | with6or 70 | lettere 71 | iit 72 | 'truck 73 | latter 74 | beginning 75 | character/vowel 76 | also 77 | numerical 78 | 3 79 | without 80 | nonletter 81 | 'dog" 82 | occurs 83 | thrice 84 | sentence: 85 | contaiing 86 | ending 87 | rmore 88 | every 89 | thst 90 | 4 91 | lineswhich 92 | characater 93 | basis 94 | food 95 | containe 96 | found 97 | -string 98 | appearing 99 | capitalized 100 | placed 101 | there's 102 | greater 103 | shouldn't 104 | after 105 | any 106 | daily 107 | containging 108 | term 109 | long 110 | so 111 | start 112 | equals 113 | lowe-case 114 | stating 115 | include 116 | less 117 | daphne12 118 | numeral 119 | withthe 120 | words: 121 | uppercase 122 | some 123 | choice 124 | cased 125 | sets 126 | follows 127 | does 128 | another 129 | 6 130 | occurring 131 | proceed 132 | straing 133 | each 134 | at 135 | upper-case 136 | alphabet 137 | precedes 138 | took 139 | vowel 140 | 
encompassing 141 | amount 142 | wiht 143 | incidents 144 | iether 145 | finishing 146 | showing 147 | things: 148 | things 149 | going 150 | sweaters 151 | ; 152 | times 153 | they 154 | mroe 155 | uppe 156 | precedig 157 | characte 158 | a7 159 | that 160 | : 161 | capitalor 162 | vowels 163 | turn 164 | precedeing 165 | ines 166 | wihich 167 | follwed 168 | need 169 | don't 170 | seen 171 | anything 172 | order 173 | first 174 | sister's 175 | i 176 | to 177 | presence 178 | aare 179 | digit 180 | charactered 181 | occring 182 | lined 183 | singular 184 | capitals 185 | (or 186 | made 187 | here 188 | letter 189 | wherein 190 | where 191 | witha 192 | treats 193 | capital 194 | tmes 195 | ,5or 196 | uppercae 197 | preceeded 198 | with5or 199 | additional 200 | taht's 201 | by 202 | lies 203 | aqd 204 | 5 205 | lowercase 206 | must 207 | contain 208 | both 209 | there 210 | not 211 | tring 212 | >= 213 | numbre 214 | leter 215 | - 216 | friend 217 | containing5or 218 | having 219 | 'ring 220 | coming 221 | has 222 | got 223 | class 224 | truck 225 | come 226 | code 227 | lends 228 | named 229 | ring 230 | '"dog" 231 | should 232 | 233 | last 234 | nothing 235 | "dog' 236 | for 237 | 'dog'before 238 | are 239 | consonant 240 | iin 241 | dont 242 | capitial 243 | nes 244 | character 245 | upper 246 | even 247 | it 248 | chracter 249 | sell 250 | an 251 | . 
252 | repeated 253 | likes 254 | least5times 255 | more 256 | follow 257 | case-letter 258 | containing 259 | zer 260 | afterwards 261 | may 262 | tines 263 | followed 264 | be 265 | same 266 | something 267 | stores 268 | the 269 | username 270 | terminating 271 | alot 272 | search 273 | your 274 | being 275 | precede 276 | containinbg 277 | link 278 | 1219@e 279 | if 280 | determine 281 | positioned 282 | levtter 283 | walking 284 | lover-case 285 | specialty 286 | beds 287 | 7 288 | letteres 289 | my 290 | following 291 | pet 292 | using 293 | charater 294 | present 295 | these: 296 | lower 297 | best 298 | least6times 299 | instances 300 | end 301 | repeat 302 | well 303 | comes 304 | non-lower-case 305 | all 306 | have 307 | regular 308 | character` 309 | occuring 310 | zero 311 | occurrence 312 | lines 313 | 2 314 | o 315 | would 316 | vowel: 317 | blines 318 | lower-case 319 | ling 320 | consisting 321 | fowel 322 | hving 323 | non-capital 324 | case 325 | starts 326 | them 327 | between 328 | in 329 | ha 330 | e 331 | them: 332 | numbering 333 | kind 334 | lower-cast 335 | ore 336 | lone 337 | misty 338 | whiich 339 | lastly 340 | activation 341 | fun 342 | strings 343 | normal 344 | fo 345 | bfore 346 | tthem 347 | than 348 | finish 349 | immediately 350 | toys 351 | following: 352 | non-word 353 | beging 354 | works 355 | lineswith 356 | on 357 | before 358 | conclude 359 | final 360 | 361 | walk 362 | least6occurrences 363 | vowe 364 | identify 365 | these 366 | preceded 367 | dog 368 | occurrences 369 | type 370 | caoital 371 | else 372 | other 373 | such 374 | including 375 | truck' 376 | , 377 | containg 378 | t 379 | clase 380 | terminates 381 | 'capital') 382 | 'dog 383 | then 384 | words 385 | open 386 | shows 387 | up 388 | instance 389 | out 390 | sequence 391 | than5times 392 | 'dog'and 393 | number 394 | but 395 | he 396 | preceding 397 | 398 | appears 399 | avoid 400 | u 401 | within 402 | possessing 403 | whose 404 | occurance 405 | 
strings'dog' 406 | total 407 | character-less 408 | items 409 | only 410 | non 411 | man 412 | begin 413 | begins 414 | time 415 | liens 416 | captial 417 | combinations 418 | string 419 | once 420 | is 421 | that's 422 | smal 423 | its 424 | or 425 | word 426 | 0 427 | 428 | no 429 | numberand 430 | occur 431 | opening 432 | ot 433 | taht 434 | contained 435 | and 436 | have5or 437 | arecontaining 438 | elements 439 | yelling 440 | were 441 | non-letter 442 | combination 443 | which 444 | contain5or 445 | line 446 | string'dog' 447 | strong 448 | phrases 449 | with 450 | least 451 | low 452 | prior 453 | higher 454 | characer 455 | three: 456 | like 457 | acharacter 458 | multiple 459 | case-letters -------------------------------------------------------------------------------- /data/NL-RX-Turk/train/vocab.target: -------------------------------------------------------------------------------- 1 | ~ 2 | 3 | 7 4 | 5 | ) 6 | 7 | 5 8 | } 9 | , 10 | | 11 | ] 12 | + 13 | 14 | 2 15 | \ 16 | & 17 | 6 18 | 3 19 | [ 20 | 21 | * 22 | . 23 | 24 | 25 | 26 | 4 27 | 28 | { 29 | ( 30 | b -------------------------------------------------------------------------------- /data/NL-RX-Turk/val/vocab.source: -------------------------------------------------------------------------------- 1 | them' 2 | should 3 | 4 | uppercase 5 | do 6 | choice 7 | are 8 | item 9 | symbol 10 | whixh 11 | strings 12 | does 13 | iin 14 | than 15 | just 16 | character 17 | 6 18 | & 19 | another 20 | occurring 21 | upper 22 | it 23 | minimum 24 | each 25 | at 26 | introducing 27 | upper-case 28 | an 29 | before 30 | . 
31 | alphabet 32 | precedes 33 | conclude 34 | repeated 35 | vowel 36 | 37 | along 38 | person 39 | least5times 40 | amount 41 | more 42 | follow 43 | letter' 44 | none 45 | these 46 | containing 47 | nor 48 | may 49 | preceded 50 | dog 51 | occurrences 52 | a2 53 | ends 54 | they 55 | times 56 | ; 57 | show 58 | including 59 | followed 60 | , 61 | be 62 | small 63 | something 64 | 'dog 65 | alphanumeric 66 | that 67 | appear 68 | then 69 | words 70 | : 71 | 1 72 | off 73 | contating 74 | noon 75 | the 76 | vowels 77 | either 78 | letters 79 | terminating 80 | hat 81 | follwed 82 | caps 83 | ltter 84 | leading 85 | being 86 | characters 87 | don't 88 | precede 89 | as 90 | of 91 | series 92 | summary: 93 | number 94 | but 95 | order 96 | withe 97 | can 98 | starting 99 | preceding 100 | 101 | appears 102 | to 103 | i 104 | a 105 | numbers 106 | u 107 | plus 108 | 7 109 | neither 110 | items 111 | twice 112 | only 113 | following 114 | begin 115 | vowelo 116 | time 117 | 'truck 118 | present 119 | beginning 120 | liens 121 | also 122 | string 123 | string' 124 | 3 125 | without 126 | once 127 | is 128 | beffore 129 | lower 130 | that's 131 | least6times 132 | beginnig 133 | occurs 134 | or 135 | thrice 136 | letter 137 | wherein 138 | word 139 | 0 140 | 141 | no 142 | ending 143 | ' 144 | where 145 | end 146 | capital 147 | occur 148 | 4 149 | about 150 | well 151 | beforehand 152 | comes 153 | all 154 | different 155 | contained 156 | by 157 | and 158 | have 159 | have5or 160 | zero 161 | 5 162 | lowercase 163 | must 164 | contain 165 | both 166 | lines 167 | 2 168 | not 169 | there 170 | which 171 | greater 172 | o 173 | lower-case 174 | line 175 | - 176 | after 177 | having 178 | coming 179 | with 180 | any 181 | appearances 182 | long 183 | least 184 | truck 185 | case 186 | starts 187 | prior 188 | so 189 | them 190 | start 191 | come 192 | stating 193 | in 194 | include 195 | e 196 | less 197 | numbering 198 | ring 199 | lease 200 | case-letters 201 | numeral 
-------------------------------------------------------------------------------- /data/NL-RX-Turk/val/vocab.target: -------------------------------------------------------------------------------- 1 | ~ 2 | 3 | 7 4 | 5 | ) 6 | 7 | | 8 | } 9 | , 10 | ] 11 | 5 12 | + 13 | 14 | 2 15 | \ 16 | & 17 | 6 18 | 3 19 | [ 20 | 21 | * 22 | . 23 | 24 | 25 | 26 | 4 27 | 28 | { 29 | ( 30 | b -------------------------------------------------------------------------------- /random_generat.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacger2/softregex/34742a197cbfdc248ea01de4df76488ecd3b5ce4/random_generat.jar -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # SoftRegex: Generating Regex from Natural Language Descriptions using Softened Regex Equivalence 2 | 3 | This is implemantation of the paper [SoftRegex: Generating Regex from Natural Language Descriptions using Softened Regex Equivalence](https://www.aclweb.org/anthology/D19-1677/) [EMNLP 2019] 4 | 5 | -------------- 6 | 7 | ## Summary 8 | 9 | We continue the study of generating semantically correct regular expressions from natural language descriptions (NL). The current state-of-the-art model SemRegex produces regular expressions from NLs by rewarding the reinforced learning based on the semantic (rather than syntactic) equivalence between two regular expressions. Since the regular expression equivalence problem is PSPACE-complete, we introduce the EQ_Reg model for computing the similarity of two regular expressions using deep neural networks. Our EQ_Reg model essentially softens the equivalence of two regular expressions when used as a reward function. 
import sys
import argparse
import subprocess


def main(arguments):
    """CLI entry point.

    Parses --gold and --predicted (space-separated regex token strings),
    prints 1 if they are DFA-equivalent and 0 otherwise. Returns 0 on the
    trivial-equality fast path, None otherwise (exit status is unaffected).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--gold', help="gold",
                        type=str, default="None")
    parser.add_argument('--predicted', help="predicted",
                        type=str, default="None")
    args = parser.parse_args(arguments)
    # Fast path: syntactically identical strings are trivially equivalent.
    if args.gold == args.predicted:
        print(1)
        return 0
    gold = unprocess_regex(args.gold)
    predicted = unprocess_regex(args.predicted)
    if regex_equiv(gold, predicted):
        print(1)
    else:
        print(0)


def unprocess_regex(regex):
    """Expand dataset placeholder tokens and strip whitespace.

    The NL-RX datasets encode character classes and constant words as
    placeholder tokens (e.g. <VOW>, <M0>). Each token is expanded to its
    literal characters, then all spaces (token separators) are removed,
    yielding a plain regex string for the DFA-equivalence checker.

    NOTE(fix): the replacement *targets* were previously empty strings —
    the angle-bracket tokens were lost in a rendered copy of this file.
    ``str.replace("", x)`` inserts ``x`` between every character, which
    would corrupt the regex; the original token names are restored here.
    """
    replacements = {
        '<VOW>': 'AEIOUaeiou',  # vowels
        '<NUM>': '0-9',         # digits
        '<LET>': 'A-Za-z',      # letters
        '<CAP>': 'A-Z',         # capital letters
        '<LOW>': 'a-z',         # lower-case letters
        '<M0>': 'dog',          # constant-word slots
        '<M1>': 'truck',
        '<M2>': 'ring',
        '<M3>': 'lake',
    }
    for token, chars in replacements.items():
        # Space-join mirrors the dataset's one-token-per-character format;
        # the final replace(" ", "") collapses everything back.
        regex = regex.replace(token, " ".join(chars))
    return regex.replace(" ", "")


def regex_equiv(gold, predicted):
    """Return True iff the two regexes are semantically equivalent.

    Identical strings short-circuit to True; otherwise equivalence is
    delegated to the external DFA checker ``regex_dfa_equals.jar``.
    """
    if gold == predicted:
        return True
    try:
        out = subprocess.check_output(
            ['java', '-jar', 'regex_dfa_equals.jar',
             '{}'.format(gold), '{}'.format(predicted)])
        # The jar prints its verdict on its own line: "\n1" means equivalent.
        return '\n1' in out.decode('utf-8')
    except Exception:
        # Deliberate best-effort: a missing JVM/jar or a malformed regex is
        # treated as "not equivalent" rather than crashing the caller.
        return False


def regex_equiv_from_raw(gold, predicted):
    """Equivalence check on raw (still-tokenized) dataset regex strings."""
    gold = unprocess_regex(gold)
    predicted = unprocess_regex(predicted)
    return regex_equiv(gold, predicted)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
/regex_equal_model/compare_regex_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "from torch.nn import LSTM,Embedding,Linear\n", 12 | "from torch.nn import Module\n", 13 | "import torch.nn.functional as F\n", 14 | "from torch.autograd import Variable\n", 15 | "\n", 16 | "class compare_regex(torch.nn.Module):\n", 17 | " def __init__(self, vocab_size, embedding_dim, hidden_dim, target_size):\n", 18 | " super(compare_regex, self).__init__()\n", 19 | " self.hidden_dim = hidden_dim\n", 20 | " self.embedding_dim = embedding_dim\n", 21 | " self.embed = Embedding(vocab_size, embedding_dim, padding_idx=0)\n", 22 | " self.lstm1 = LSTM(embedding_dim ,hidden_dim, bidirectional=True, num_layers=1, batch_first=True)\n", 23 | " self.lstm2 = LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=1, batch_first=True)\n", 24 | " self.fc1 = Linear(hidden_dim*2*2, 60)\n", 25 | " self.fc2 = Linear(60, 20)\n", 26 | " self.fc3 = Linear(20, target_size)\n", 27 | "\n", 28 | " \n", 29 | " def init_hidden(self, bs):\n", 30 | " if torch.cuda.is_available():\n", 31 | " return (torch.zeros(2, bs, self.hidden_dim).cuda(),\n", 32 | " torch.zeros(2, bs, self.hidden_dim).cuda())\n", 33 | " else:\n", 34 | " return (torch.zeros(2, bs, self.hidden_dim),\n", 35 | " torch.zeros(2, bs, self.hidden_dim))\n", 36 | " \n", 37 | " def forward(self, bs, line1, line2, input1_lengths,input2_lengths):\n", 38 | " embeded1 = self.embed(line1)\n", 39 | " embeded2 = self.embed(line2)\n", 40 | "# packed1 = torch.nn.utils.rnn.pack_padded_sequence(embeded1, input1_lengths, batch_first=True)\n", 41 | "# packed2 = torch.nn.utils.rnn.pack_padded_sequence(embeded2, input2_lengths, batch_first=True)\n", 42 | " hidden1 = self.init_hidden(bs)\n", 43 | " lstm1_out, last_hidden1 = 
self.lstm1(embeded1,hidden1)\n", 44 | " hidden2 = self.init_hidden(bs)\n", 45 | " lstm2_out, last_hidden2 = self.lstm2(embeded2,hidden2)\n", 46 | "# unpack1, unpack1_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm1_out, batch_first=True)\n", 47 | "# unpack2, unpack2_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm2_out, batch_first=True)\n", 48 | "\n", 49 | "# lstm1_last_hidden = (torch.gather(lstm1_out,1,torch.tensor(input1_lengths).cuda().expand(self.hidden_dim, 1,-1).transpose(0,2)-1)).cuda()\n", 50 | "# lstm2_last_hidden = (torch.gather(lstm2_out,1,torch.tensor(input2_lengths).cuda().expand(self.hidden_dim, 1,-1).transpose(0,2)-1)).cuda()\n", 51 | "\n", 52 | "\n", 53 | " fc1_out = self.fc1(torch.cat((lstm1_out.mean(1), lstm2_out.mean(1)),1)) #encoder outputs 평균값 concat 97.8%\n", 54 | "# fc1_out = self.fc1(lstm1_out.mean(1) * lstm2_out.mean(1)) #encoder outputs 평균값 multiple\n", 55 | "# fc1_out = self.fc1(torch.cat((lstm1_last_hidden.squeeze(1),lstm2_last_hidden.squeeze(1)), 1)) #last hidden concat 97.1%\n", 56 | "# fc1_out = self.fc1(lstm1_last_hidden.squeeze(1) * lstm2_last_hidden.squeeze(1)) #last hidden multiple\n", 57 | "\n", 58 | " \n", 59 | " fc1_out = F.tanh(fc1_out)\n", 60 | " fc2_out = self.fc2(fc1_out)\n", 61 | " fc2_out = F.tanh(fc2_out)\n", 62 | " fc3_out = self.fc3(fc2_out)\n", 63 | " score = F.log_softmax(fc3_out,dim=1)\n", 64 | " return score" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "vocab = {}\n", 74 | "f = open('./compare_vocab.txt','r')\n", 75 | "for i in f.read().splitlines():\n", 76 | " splitted = i.split('\\t')\n", 77 | " vocab[splitted[0]] = int(splitted[1])\n", 78 | "vocab_size = len(vocab)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def make_input_seq(lines1, lines2, targets):\n", 88 | " max_len = 40\n", 89 | " lines1_seq2idx = 
list()\n", 90 | " lines2_seq2idx = list()\n", 91 | " targets_idx = list()\n", 92 | " lines1_seq = [s.split() for s in lines1]\n", 93 | " lines2_seq = [s.split() for s in lines2]\n", 94 | " for line_num in range(len(lines1_seq)):\n", 95 | " if len(lines1_seq[line_num]) > max_len or len(lines2_seq[line_num]) > max_len:\n", 96 | " continue\n", 97 | " lines1_padded = lines1_seq[line_num]+['']*(max_len-len(lines1_seq[line_num]))\n", 98 | " lines2_padded = lines2_seq[line_num]+['']*(max_len-len(lines2_seq[line_num]))\n", 99 | " lines1_seq2idx.append([vocab[i] for i in lines1_padded])\n", 100 | " lines2_seq2idx.append([vocab[i] for i in lines2_padded])\n", 101 | " \n", 102 | " if targets[line_num] == '0':\n", 103 | " targets_idx.append([1,0])\n", 104 | " else:\n", 105 | " targets_idx.append([0,1])\n", 106 | " if torch.cuda.is_available():\n", 107 | " return torch.LongTensor(lines1_seq2idx).cuda(), torch.LongTensor(lines2_seq2idx).cuda(), torch.LongTensor(targets_idx).cuda()\n", 108 | " else:\n", 109 | " return torch.LongTensor(lines1_seq2idx), torch.LongTensor(lines2_seq2idx), torch.LongTensor(targets_idx)\n", 110 | " \n", 111 | "def evaluate_test(model, test_input1, test_input2 , test_target):\n", 112 | " correct = 0\n", 113 | " print(len(test_target))\n", 114 | " tp=0\n", 115 | " tn=0\n", 116 | " fp=0\n", 117 | " fn=0\n", 118 | " count=0\n", 119 | " for i in range(len(test_input1)):\n", 120 | " count+=1\n", 121 | " test_input1_len = torch.tensor([torch.max(test_input1[i].data.nonzero()+1)])\n", 122 | " test_input2_len = torch.tensor([torch.max(test_input2[i].data.nonzero()+1)])\n", 123 | " score = model(1, test_input1[i].unsqueeze(0), test_input2[i].unsqueeze(0) , test_input1_len.tolist(), test_input2_len.tolist())\n", 124 | " if score.argmax().item() == 1 and test_target[i].argmax().item()==1:\n", 125 | " tp+=1\n", 126 | " elif score.argmax().item() == 0 and test_target[i].argmax().item()==0:\n", 127 | " tn+=1\n", 128 | " elif score.argmax().item() == 1 and 
test_target[i].argmax().item()==0:\n", 129 | " fp+=1\n", 130 | " elif score.argmax().item() == 0 and test_target[i].argmax().item()==1:\n", 131 | " fn+=1\n", 132 | " if score.argmax().item() == test_target[i].argmax().item():\n", 133 | " correct += 1\n", 134 | " precision = tp/(tp+fp)\n", 135 | " recall = tp/(tp+fn)\n", 136 | " f1_score = 2*((precision*recall)/(precision+recall))\n", 137 | " print('precision: {},recall: {},f1 score:{}'.format(precision,recall,f1_score))\n", 138 | " print('total: {}, correct: {}'.format(len(test_target), correct))\n", 139 | " return correct/len(test_target)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "compare_regex_model = torch.load('./compare_regex_model.pth')" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# f = open('../pair_data/test_data.txt','r')\n", 158 | "f = open('../pair_data/data_pairs_test(4_depth).txt','r')\n", 159 | "# f = open('../pair_data/data_pairs_test(5_depth).txt','r')\n", 160 | "\n", 161 | "\n", 162 | "total_set = set()\n", 163 | "lines1 = list()\n", 164 | "lines2 = list()\n", 165 | "targets = list()\n", 166 | "\n", 167 | "# count = 0\n", 168 | "# for line in f.read().splitlines():\n", 169 | "# count += 1\n", 170 | "# total_set.add(line)\n", 171 | "# splitted = line.split('\\t')\n", 172 | "# print(count)\n", 173 | "\n", 174 | "count = 0\n", 175 | "for line in f.read().splitlines():\n", 176 | " count += 1\n", 177 | " splitted = line.split('\\t')\n", 178 | " lines1.append(splitted[0])\n", 179 | " lines2.append(splitted[1])\n", 180 | " targets.append(splitted[2])\n", 181 | "\n", 182 | "test_input1, test_input2, test_targets = make_input_seq(lines1, lines2, targets)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": 
[ 191 | "with torch.no_grad():\n", 192 | " print('test acc: {}'.format(evaluate_test(compare_regex_model, test_input1, test_input2, test_targets)))" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [] 208 | } 209 | ], 210 | "metadata": { 211 | "kernelspec": { 212 | "display_name": "Python 3", 213 | "language": "python", 214 | "name": "python3" 215 | }, 216 | "language_info": { 217 | "codemirror_mode": { 218 | "name": "ipython", 219 | "version": 3 220 | }, 221 | "file_extension": ".py", 222 | "mimetype": "text/x-python", 223 | "name": "python", 224 | "nbconvert_exporter": "python", 225 | "pygments_lexer": "ipython3", 226 | "version": "3.5.2" 227 | } 228 | }, 229 | "nbformat": 4, 230 | "nbformat_minor": 2 231 | } 232 | -------------------------------------------------------------------------------- /regex_equal_model/compare_vocab.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | b 13 4 | [ 3 5 | ( 14 6 | , 4 7 | 6 8 | . 
import logging

import torchtext


class SourceField(torchtext.data.Field):
    """Wrapper of torchtext.data.Field that forces batch_first and include_lengths to True."""

    def __init__(self, **kwargs):
        logger = logging.getLogger(__name__)

        # Warn only when the caller explicitly asked for the unsupported value;
        # either way the option is forced on, as the trainer requires it.
        if kwargs.get('batch_first') is False:
            logger.warning("Option batch_first has to be set to use pytorch-seq2seq.  Changed to True.")
        kwargs['batch_first'] = True
        if kwargs.get('include_lengths') is False:
            logger.warning("Option include_lengths has to be set to use pytorch-seq2seq.  Changed to True.")
        kwargs['include_lengths'] = True

        super(SourceField, self).__init__(**kwargs)


class TargetField(torchtext.data.Field):
    """Wrapper of torchtext.data.Field that forces batch_first to be True and
    prepends/appends SOS/EOS markers to sequences in the preprocessing step.

    Attributes:
        sos_id: index of the start-of-sentence symbol (set by build_vocab)
        eos_id: index of the end-of-sentence symbol (set by build_vocab)
    """

    # FIX: the marker strings were empty in the rendered copy of this file
    # (the angle-bracket tokens were stripped). Empty markers would make
    # sos_id == eos_id, breaking decoding; the conventional tokens are restored.
    SYM_SOS = '<sos>'
    SYM_EOS = '<eos>'

    def __init__(self, **kwargs):
        logger = logging.getLogger(__name__)

        # Use `is False` for the explicit-False check, consistent with SourceField.
        if kwargs.get('batch_first') is False:
            logger.warning("Option batch_first has to be set to use pytorch-seq2seq.  Changed to True.")
        kwargs['batch_first'] = True
        # Wrap (or install) the preprocessing pipeline so every target sequence
        # is bracketed by the SOS/EOS markers.
        if kwargs.get('preprocessing') is None:
            kwargs['preprocessing'] = lambda seq: [self.SYM_SOS] + seq + [self.SYM_EOS]
        else:
            func = kwargs['preprocessing']
            kwargs['preprocessing'] = lambda seq: [self.SYM_SOS] + func(seq) + [self.SYM_EOS]

        # Resolved to real vocabulary indices only once build_vocab() runs.
        self.sos_id = None
        self.eos_id = None
        super(TargetField, self).__init__(**kwargs)

    def build_vocab(self, *args, **kwargs):
        """Build the vocabulary, then cache the SOS/EOS indices for decoding."""
        super(TargetField, self).build_vocab(*args, **kwargs)
        self.sos_id = self.vocab.stoi[self.SYM_SOS]
        self.eos_id = self.vocab.stoi[self.SYM_EOS]
class Evaluator(object):
    """ Class to evaluate models with given datasets.

    Args:
        loss (seq2seq.loss, optional): loss for evaluator (default: seq2seq.loss.NLLLoss)
        batch_size (int, optional): batch size for evaluator (default: 64)
        output_vocab (optional): target-side vocabulary, stored for callers that need it
    """

    def __init__(self, loss=None, batch_size=64, output_vocab=None):
        # Build a fresh NLLLoss per Evaluator. The previous default
        # `loss=NLLLoss()` was a single shared *stateful* instance (NLLLoss
        # accumulates acc_loss/norm_term), so two Evaluators constructed with
        # defaults would silently pollute each other's accumulated loss.
        self.loss = loss if loss is not None else NLLLoss()
        self.batch_size = batch_size
        self.output_vocab = output_vocab

    def evaluate(self, model, data):
        """ Evaluate a model on given dataset and return performance.

        Args:
            model (seq2seq.models): model to evaluate
            data (seq2seq.dataset.dataset.Dataset): dataset to evaluate against

        Returns:
            (loss, accuracy): average loss and token-level accuracy on `data`
        """
        model.eval()

        loss = self.loss
        loss.reset()
        match = 0
        total = 0

        # Legacy torchtext device convention: None = current GPU, -1 = CPU.
        device = None if torch.cuda.is_available() else -1
        batch_iterator = torchtext.data.BucketIterator(
            dataset=data, batch_size=self.batch_size,
            sort=True, sort_key=lambda x: len(x.src),
            device=device, train=False)
        tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
        pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]

        with torch.no_grad():
            for batch in batch_iterator:
                input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
                target_variables = getattr(batch, seq2seq.tgt_field_name)

                decoder_outputs, decoder_hidden, other = model(input_variables, input_lengths.tolist(), target_variables)

                # Decoder step t predicts target position t + 1 (position 0
                # is the start-of-sequence symbol).
                seqlist = other['sequence']
                for step, step_output in enumerate(decoder_outputs):
                    target = target_variables[:, step + 1]
                    loss.eval_batch(step_output.view(target_variables.size(0), -1), target)

                    # Count token matches only on non-padding positions.
                    non_padding = target.ne(pad)
                    correct = seqlist[step].view(-1).eq(target).masked_select(non_padding).sum().item()
                    match += correct
                    total += non_padding.sum().item()

        if total == 0:
            accuracy = float('nan')
        else:
            accuracy = match / total

        return loss.get_loss(), accuracy

    def evaluate_reward(self, model, data, vocab):
        """ Evaluate a model's reward-based (PositiveLoss) score on a dataset.

        Args:
            model (seq2seq.models): model to evaluate
            data (seq2seq.dataset.dataset.Dataset): dataset to evaluate against
            vocab: target vocabulary used to decode predictions into tokens

        Returns:
            loss (float): accumulated PositiveLoss of the model on the dataset
        """
        model.eval()

        loss = PositiveLoss()
        loss.reset()

        # Legacy torchtext device convention: None = current GPU, -1 = CPU.
        device = None if torch.cuda.is_available() else -1
        batch_iterator = torchtext.data.BucketIterator(
            dataset=data, batch_size=self.batch_size,
            sort=True, sort_key=lambda x: len(x.src),
            device=device, train=False)
        tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
        pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]

        with torch.no_grad():
            for batch in batch_iterator:
                input_variable, input_lengths = getattr(batch, seq2seq.src_field_name)
                target_variable = getattr(batch, seq2seq.tgt_field_name)

                decoder_output, decoder_hidden, other = model(input_variable, input_lengths.tolist(), target_variable)

                # Collect, per step, the max log-probability and the argmax
                # symbol; PositiveLoss scores the whole decoded sequence.
                seqlist = []
                tensorlist = []
                for step, step_output in enumerate(decoder_output):
                    tensorlist.append(torch.max(step_output, dim=1)[0])
                    seqlist.append(torch.max(step_output, dim=1)[1])

                log_tensor = torch.stack(tensorlist, dim=1)
                output_tensor = torch.stack(seqlist, dim=1)

                loss.eval_batch(log_tensor, output_tensor, target_variable, vocab)

        return loss.get_loss()
class Predictor(object):

    def __init__(self, model, src_vocab, tgt_vocab):
        """
        Predictor class to evaluate for a given model.
        Args:
            model (seq2seq.models): trained model. This can be loaded from a checkpoint
                using `seq2seq.util.checkpoint.load`
            src_vocab (seq2seq.dataset.vocabulary.Vocabulary): source sequence vocabulary
            tgt_vocab (seq2seq.dataset.vocabulary.Vocabulary): target sequence vocabulary
        """
        if torch.cuda.is_available():
            self.model = model.cuda()
        else:
            self.model = model.cpu()
        self.model.eval()
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def get_decoder_features(self, src_seq):
        """Run the model on one tokenised source sequence and return the
        decoder's auxiliary `ret_dict` (lengths / predicted symbols / top-k)."""
        src_id_seq = torch.LongTensor([self.src_vocab.stoi[tok] for tok in src_seq]).view(1, -1)
        if torch.cuda.is_available():
            src_id_seq = src_id_seq.cuda()

        with torch.no_grad():
            softmax_list, _, other = self.model(src_id_seq, [len(src_seq)])

        return other

    def predict(self, src_seq):
        """ Make prediction given `src_seq` as input.

        Args:
            src_seq (list): list of tokens in source language

        Returns:
            tgt_seq (list): list of tokens in target language as predicted
            by the pre-trained model
        """
        other = self.get_decoder_features(src_seq)

        length = other['length'][0]

        # .item() instead of the legacy `.data[0]`: indexing a tensor with
        # `.data[0]` raises on 0-dim tensors in modern PyTorch, while
        # .item() extracts the Python scalar on both old and new versions.
        tgt_id_seq = [other['sequence'][di][0].item() for di in range(length)]
        tgt_seq = [self.tgt_vocab.itos[tok] for tok in tgt_id_seq]
        return tgt_seq

    def predict_n(self, src_seq, n=1):
        """ Make 'n' predictions given `src_seq` as input.

        Args:
            src_seq (list): list of tokens in source language
            n (int): number of predicted seqs to return. If None,
                it will return just one seq.

        Returns:
            result (list of list): the top-n predicted token sequences in the
            target language
        """
        other = self.get_decoder_features(src_seq)

        result = []
        for x in range(0, int(n)):
            length = other['topk_length'][0][x]
            # `[0, x, 0]` yields a 0-dim tensor; `.data[0]` on it raises on
            # modern PyTorch, so extract the scalar with .item().
            tgt_id_seq = [other['topk_sequence'][di][0, x, 0].item() for di in range(length)]
            tgt_seq = [self.tgt_vocab.itos[tok] for tok in tgt_id_seq]
            result.append(tgt_seq)

        return result
class Loss(object):
    """Base class wrapping a PyTorch loss criterion for training/inference.

    Subclasses accumulate loss across batches via :meth:`eval_batch` and
    define the averaging policy in :meth:`get_loss`. See
    http://pytorch.org/docs/master/nn.html#loss-functions for the available
    criteria.

    Note:
        Do not use this class directly; use one of the subclasses.

    Args:
        name (str): name of the loss function, used in logging messages.
        criterion (torch.nn._Loss): one of PyTorch's loss functions.

    Attributes:
        name (str): name of the loss function, used in logging messages.
        criterion (torch.nn._Loss): the wrapped PyTorch criterion.
        acc_loss (int or torch.Tensor): accumulated loss (the int 0 before
            any batch has been evaluated).
        norm_term (float): normalization term used to average the loss over
            multiple batches; semantics depend on the subclass.
    """

    def __init__(self, name, criterion):
        self.name = name
        self.criterion = criterion
        if not issubclass(type(self.criterion), nn.modules.loss._Loss):
            raise ValueError("Criterion has to be a subclass of torch.nn._Loss")
        self.acc_loss = 0   # accumulated loss
        self.norm_term = 0  # normalization term

    def reset(self):
        """ Reset the accumulated loss. """
        self.acc_loss = 0
        self.norm_term = 0

    def get_loss(self):
        """Return the averaged loss; subclasses define the averaging.

        Returns:
            loss (float): value of the loss.
        """
        raise NotImplementedError

    def eval_batch(self, outputs, target):
        """Accumulate loss for one batch; subclasses define the accumulation.

        Args:
            outputs (torch.Tensor): outputs of a batch.
            target (torch.Tensor): expected output of a batch.
        """
        raise NotImplementedError

    def cuda(self):
        """Move the wrapped criterion to the GPU."""
        self.criterion.cuda()

    def backward(self):
        """Backpropagate through the accumulated loss."""
        if type(self.acc_loss) is int:
            raise ValueError("No loss to back propagate.")
        self.acc_loss.backward()
import subprocess


def score_by_example(gold, predicted, num_examples=1):
    """Reward `predicted` by sampling strings from `gold` and checking that
    `predicted` accepts them.

    Args:
        gold (str): tokenised gold regex.
        predicted (str): tokenised predicted regex.
        num_examples (int): number of positive examples to sample.

    Returns:
        float/int: fraction of sampled positive examples accepted by
        `predicted`; 1 for an identical pair; 0 on any failure.
    """
    gold = unprocess_regex(gold)
    predicted = unprocess_regex(predicted)

    if gold == predicted:
        return 1
    try:
        count = 0
        for i in range(num_examples):
            # NOTE: the generator jar is checked in as 'random_generat.jar'
            # (no trailing 'e'); the previous 'random_generate.jar' spelling
            # never matched a file, so this scorer always returned 0.
            example = subprocess.check_output(['java', '-jar', 'random_generat.jar', '{}'.format(gold)])
            result = subprocess.check_output(['java', '-jar', 'membership.jar', '{}'.format(predicted), '{}'.format(example[:-1].decode('utf-8'))])
            if result == b'true\n':
                count = count + 1

        return count / num_examples
    except Exception:
        # Any subprocess/regex failure counts as "no reward".
        return 0


def score_by_oracle(gold, predicted):
    """Exact-equivalence reward via the external DFA oracle.

    Returns 1 when the oracle confirms equivalence (or the strings are
    identical), 0 otherwise.
    """
    if gold == predicted:
        return 1
    try:
        score = int(subprocess.check_output(
            ['python2', 'regexDFAEquals.py', '--gold', '{}'.format(gold),
             '--predicted', '{}'.format(predicted)],
            timeout=0.01).decode('utf-8')[0])
    except (subprocess.TimeoutExpired, subprocess.CalledProcessError, OSError, ValueError):
        # Timeout, oracle crash, missing python2 interpreter, or unparsable
        # output all mean "cannot confirm equivalence" -> zero reward.
        # (Previously only TimeoutExpired was caught; any other failure
        # crashed the training loop.)
        return 0
    return score


def to_pad(a, max_len=40):
    """Right-pad each row of the 2-D index tensor `a` with 1s (the pad index)
    and truncate to `max_len` columns.

    NOTE(review): rows are first padded out to 60 columns, so inputs wider
    than 60 would fail -- presumably sequences are known to be shorter;
    confirm. CUDA-only (device='cuda' is hard-coded, matching the callers).
    """
    sh = a.shape
    pad_shape = torch.Size([sh[0], 60 - sh[1]])
    return torch.cat((a, torch.ones(pad_shape, device='cuda').type(torch.cuda.LongTensor)), dim=1)[:, :max_len]


def score_by_probability_batch(golds, predicts, prob_model=None, vocab=None):
    """Batch version of score_by_probability: score every (gold, predicted)
    pair with the EQ_Reg probability model.

    Returns:
        list[float]: per-pair rewards.
    """
    predicts = to_pad(predicts)
    golds = to_pad(golds)
    score = prob_model(len(golds), predicts.type(torch.LongTensor).cuda(), golds.type(torch.LongTensor).cuda(),
                       (predicts != 1).sum(1).tolist(), (golds != 1).sum(1).tolist())
    # Reward = P(equal) when the model's argmax class is "equal" (index 1),
    # else 0 (argmax index 0 zeroes the product).
    score_list = (torch.max(torch.exp(score), 1)[0].float() * torch.max(torch.exp(score), 1)[1].float()).tolist()
    return score_list


def score_by_probability(gold, predicted, prob_model=None, vocab=None):
    """Score a single (gold, predicted) pair with the EQ_Reg model.

    Returns 1 for identical strings, the model's probability of equivalence
    when it exceeds 0.5, and 0 otherwise (including unknown tokens or
    over-long sequences).
    """
    if gold == predicted:
        return 1
    gold = gold.split(' ')
    gold_len = len(gold)
    predicted = predicted.split(' ')
    predicted = [x for x in predicted if x]
    predicted_len = len(predicted)
    try:
        max_len = 40
        if len(gold) > max_len or len(predicted) > max_len:
            return 0
        # NOTE(review): the pad token here is the empty string; it looks
        # stripped from '<pad>' -- confirm against compare_vocab.txt.
        gold_padded = gold + [''] * (max_len - len(gold))
        gold_idx_input = [[vocab[i] for i in gold_padded]]

        predicted_padded = predicted + [''] * (max_len - len(predicted))
        predicted_idx_input = [[vocab[i] for i in predicted_padded]]

        score = prob_model(1, torch.LongTensor(predicted_idx_input).cuda(), torch.LongTensor(gold_idx_input).cuda(),
                           [predicted_len], [gold_len])
        score = math.exp(score[0][1])
    except KeyError:
        # A token is missing from the comparison vocabulary.
        score = 0
    # Threshold the soft reward at 0.5 (previously written as the no-op
    # `if score - 0.5 > 0: score = score`).
    if score <= 0.5:
        score = 0
    return score
b'false\n': 184 | return 0 185 | neg_example = subprocess.check_output(['java', '-jar', 'random_generate.jar', '{}'.format(neg_gold)]) 186 | if neg_example == b'\n': 187 | return 0 188 | neg_result = subprocess.check_output(['java', '-jar', 'membership.jar', '{}'.format(predicted), '{}'.format(neg_example[:-1].decode('utf-8'))]) 189 | # print(result, example[:-1], predicted) 190 | if neg_result == b'true\n': 191 | return 0 192 | return 1 193 | except Exception as e: 194 | return 0 195 | return 0 196 | 197 | def refine_outout(regex): 198 | par_list = [] 199 | word_list = regex.split() 200 | 201 | for idx, word in enumerate(word_list): 202 | if word == '(' or word == '[' or word == '{': 203 | par_list.append(word) 204 | 205 | if word == ')' or word == ']' or word == '}': 206 | if len(par_list) == 0: 207 | word_list[idx] = '' 208 | continue 209 | 210 | par_in_list = par_list.pop() 211 | if par_in_list == '(': 212 | word_list[idx] = ')' 213 | elif par_in_list == '[': 214 | word_list[idx] = ']' 215 | elif par_in_list == '{': 216 | word_list[idx] = '}' 217 | 218 | while len(par_list) != 0: 219 | par_in_list = par_list.pop() 220 | if par_in_list == '(': 221 | word_list.append(')') 222 | elif par_in_list == '[': 223 | word_list.append(']') 224 | elif par_in_list == '{': 225 | word_list.append('}') 226 | 227 | word_list = [word for word in word_list if word != ''] 228 | 229 | return ' '.join(word_list) 230 | 231 | def check_dfa_equality(gold, predicted): 232 | try: 233 | result = subprocess.check_output(['python2', 'regexDFAEquals.py', '--gold', '{}'.format(gold), '--predicted', '{}'.format(predicted)], timeout=5) 234 | except subprocess.TimeoutExpired as exc: 235 | print("Command timeout: {}".format(exc)) 236 | return 0 237 | else: 238 | if str(result)[2] == '1': 239 | return 0.5 240 | else: 241 | return -1.0 242 | 243 | def unprocess_regex(regex): 244 | # regex = regex.replace("", " ".join('[AEIOUaeiou]')) 245 | # regex = regex.replace("", " ".join('[0-9]')) 246 | # regex 
= regex.replace("", " ".join('[A-Za-z]')) 247 | # regex = regex.replace("", " ".join('[A-Z]')) 248 | # regex = regex.replace("", " ".join('[a-z]')) 249 | 250 | regex = regex.replace("", " ".join('AEIOUaeiou')) 251 | regex = regex.replace("", " ".join('0-9')) 252 | regex = regex.replace("", " ".join('A-Za-z')) 253 | regex = regex.replace("", " ".join('A-Z')) 254 | regex = regex.replace("", " ".join('a-z')) 255 | 256 | regex = regex.replace("", " ".join('dog')) 257 | regex = regex.replace("", " ".join('truck')) 258 | regex = regex.replace("", " ".join('ring')) 259 | regex = regex.replace("", " ".join('lake')) 260 | 261 | regex = regex.replace(" ", "") 262 | 263 | return regex 264 | 265 | 266 | class SelfCriticalLoss(object): 267 | """ Base class for encapsulation of the loss functions. 268 | 269 | This class defines interfaces that are commonly used with loss functions 270 | in training and inferencing. For information regarding individual loss 271 | functions, please refer to http://pytorch.org/docs/master/nn.html#loss-functions 272 | 273 | Note: 274 | Do not use this class directly, use one of the sub classes. 275 | 276 | Args: 277 | name (str): name of the loss function used by logging messages. 278 | criterion (torch.nn._Loss): one of PyTorch's loss function. Refer 279 | to http://pytorch.org/docs/master/nn.html#loss-functions for 280 | a list of them. 281 | 282 | Attributes: 283 | name (str): name of the loss function used by logging messages. 284 | criterion (torch.nn._Loss): one of PyTorch's loss function. Refer 285 | to http://pytorch.org/docs/master/nn.html#loss-functions for 286 | a list of them. Implementation depends on individual 287 | sub-classes. 288 | acc_loss (int or torcn.nn.Tensor): variable that stores accumulated loss. 289 | norm_term (float): normalization term that can be used to calculate 290 | the loss of multiple batches. Implementation depends on individual 291 | sub-classes. 
class SelfCriticalLoss(object):
    """Base class for reward-style (self-critical) loss functions.

    Mirrors the interface of :class:`Loss` but carries no torch criterion:
    subclasses accumulate a reward-weighted loss in ``acc_loss`` and a
    normalization count in ``norm_term``.

    Note:
        Do not use this class directly; use one of the subclasses.

    Args:
        name (str): name of the loss function, used in logging messages.

    Attributes:
        acc_loss (int or torch.Tensor): accumulated loss (the int 0 before
            any batch has been evaluated).
        norm_term (float): normalization term for averaging over batches.
    """

    def __init__(self, name):
        self.name = name
        self.acc_loss = 0   # accumulated loss
        self.norm_term = 0  # normalization term

    def reset(self):
        """ Reset the accumulated loss. """
        self.acc_loss = 0
        self.norm_term = 0

    def get_loss(self):
        """Return the averaged loss; subclasses define the averaging.

        Returns:
            loss (float): value of the loss.
        """
        raise NotImplementedError

    def eval_batch(self, outputs, target):
        """Accumulate loss for one batch; subclasses define the accumulation.

        Args:
            outputs (torch.Tensor): outputs of a batch.
            target (torch.Tensor): expected output of a batch.
        """
        raise NotImplementedError

    def backward(self):
        """Backpropagate through the accumulated loss."""
        if type(self.acc_loss) is int:
            raise ValueError("No loss to back propagate.")
        self.acc_loss.backward()


class PositiveLoss(SelfCriticalLoss):
    """Reward-weighted policy-gradient style loss for generated sequences.

    `mode` selects the reward function: 'prob' uses the EQ_Reg probability
    model, 'dis' uses positive/negative membership testing, anything else
    uses the DFA-equivalence oracle.
    """

    _NAME = "Random Positive Acceptance Reward"

    def __init__(self, mode=None, prob_model=None, loss_vocab=None):
        self.prob_model = prob_model
        self.loss_vocab = loss_vocab
        self.mode = mode
        super(PositiveLoss, self).__init__(self._NAME)

    def get_loss(self):
        # Before the first eval_batch, acc_loss is still the int 0.
        if isinstance(self.acc_loss, int):
            return 0
        return self.acc_loss.data.item()

    def decode_tensor(self, tensor):
        """Map a sequence of token ids back to a space-joined token string,
        skipping special symbols (all rendered as '' in this vocab --
        NOTE(review): looks like stripped '<sos>'/'<eos>'/'<pad>' markers)."""
        words = (self.vocab.itos[tok] for tok in tensor)
        return ' '.join(w for w in words if w != '')

    def eval_batch(self, logs, outputs, targets, vocab):
        """Accumulate -log_prob * reward over one batch of decoded sequences."""
        self.vocab = vocab

        # Pick the reward function for the configured mode.
        if self.mode == 'prob':
            scorer = lambda g, p: score_by_probability(g, p, prob_model=self.prob_model, vocab=self.loss_vocab)
        elif self.mode == 'dis':
            scorer = score_by_pos_and_neg_example
        else:
            scorer = score_by_oracle

        rewards = [scorer(self.decode_tensor(targets[i]), self.decode_tensor(outputs[i]))
                   for i in range(logs.shape[0])]

        # Broadcast each sequence-level reward over all of its time steps.
        reward_matrix = torch.from_numpy(
            np.repeat(np.array(rewards)[:, np.newaxis], outputs.shape[1], 1)).float().cuda()
        mask = (outputs > 1).float().cuda()  # ignore pad/unk positions
        step_logs = logs.cuda()

        self.acc_loss += torch.sum(-step_logs * reward_matrix * mask)
        self.norm_term += 1
reward_matrix * mask) 384 | self.norm_term += 1 385 | 386 | 387 | class NLLLoss(Loss): 388 | """ Batch averaged negative log-likelihood loss. 389 | 390 | Args: 391 | weight (torch.Tensor, optional): refer to http://pytorch.org/docs/master/nn.html#nllloss 392 | mask (int, optional): index of masked token, i.e. weight[mask] = 0. 393 | size_average (bool, optional): refer to http://pytorch.org/docs/master/nn.html#nllloss 394 | """ 395 | 396 | _NAME = "Avg NLLLoss" 397 | 398 | def __init__(self, weight=None, mask=None, size_average=True): 399 | self.mask = mask 400 | self.size_average = size_average 401 | if mask is not None: 402 | if weight is None: 403 | raise ValueError("Must provide weight with a mask.") 404 | weight[mask] = 0 405 | 406 | super(NLLLoss, self).__init__( 407 | self._NAME, 408 | nn.NLLLoss(weight=weight, size_average=size_average)) 409 | 410 | def get_loss(self): 411 | if isinstance(self.acc_loss, int): 412 | return 0 413 | # total loss for all batches 414 | loss = self.acc_loss.data.item() 415 | if self.size_average: 416 | # average loss per batch 417 | loss /= self.norm_term 418 | return loss 419 | 420 | def eval_batch(self, outputs, target): 421 | target=target.to('cuda') 422 | self.acc_loss += self.criterion(outputs, target) 423 | self.norm_term += 1 424 | 425 | class Perplexity(NLLLoss): 426 | """ Language model perplexity loss. 427 | 428 | Perplexity is the token averaged likelihood. When the averaging options are the 429 | same, it is the exponential of negative log-likelihood. 430 | 431 | Args: 432 | weight (torch.Tensor, optional): refer to http://pytorch.org/docs/master/nn.html#nllloss 433 | mask (int, optional): index of masked token, i.e. weight[mask] = 0. 
class DecoderRNN(BaseRNN):
    r"""
    Provides functionality for decoding in a seq2seq framework, with an option for attention.

    Args:
        vocab_size (int): size of the vocabulary
        max_len (int): a maximum allowed length for the sequence to be processed
        hidden_size (int): the number of features in the hidden state `h`
        sos_id (int): index of the start of sentence symbol
        eos_id (int): index of the end of sentence symbol
        n_layers (int, optional): number of recurrent layers (default: 1)
        rnn_cell (str, optional): type of RNN cell (default: gru)
        bidirectional (bool, optional): if the encoder is bidirectional (default False)
        input_dropout_p (float, optional): dropout probability for the input sequence (default: 0)
        dropout_p (float, optional): dropout probability for the output sequence (default: 0)
        use_attention (bool, optional): flag indicating whether to use the attention mechanism (default: False)
        word_embedding_size (int, optional): dimensionality of the token embeddings (default: 100)

    Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio
        - **inputs** (batch, seq_len): target token IDs; used for teacher forcing when provided (default `None`)
        - **encoder_hidden** (num_layers * num_directions, batch, hidden_size): encoder final hidden state,
          used as the initial hidden state of the decoder (default `None`)
        - **encoder_outputs** (batch, seq_len, hidden_size): encoder outputs, used by attention (default `None`)
        - **function** (torch.nn.Module): maps decoder logits to (log-)probabilities
          (default `torch.nn.functional.log_softmax`)
        - **teacher_forcing_ratio** (float): probability of decoding the whole sequence with teacher
          forcing; a single uniform sample per forward() call decides (default 0)

    Outputs: decoder_outputs, decoder_hidden, ret_dict
        - **decoder_outputs**: list of per-step (batch, vocab_size) output tensors
        - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): final decoder hidden state
        - **ret_dict**: *KEY_LENGTH* -> list of output lengths, *KEY_SEQUENCE* -> list of predicted
          token-ID tensors, and *KEY_ATTN_SCORE* -> attention weights when attention is enabled
    """

    # Keys of the auxiliary information returned in ret_dict by forward().
    KEY_ATTN_SCORE = 'attention_score'
    KEY_LENGTH = 'length'
    KEY_SEQUENCE = 'sequence'

    def __init__(self, vocab_size, max_len, hidden_size,
                sos_id, eos_id,
                n_layers=1, rnn_cell='gru', bidirectional=False,
                input_dropout_p=0, dropout_p=0, use_attention=False, word_embedding_size=100):
        super(DecoderRNN, self).__init__(vocab_size, max_len, hidden_size,
                input_dropout_p, dropout_p,
                n_layers, rnn_cell)

        self.word_embedding_size = word_embedding_size

        self.bidirectional_encoder = bidirectional
        # The RNN consumes embedded tokens (word_embedding_size features), batch-first.
        self.rnn = self.rnn_cell(self.word_embedding_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p)

        self.output_size = vocab_size
        self.max_length = max_len
        self.use_attention = use_attention
        self.eos_id = eos_id
        self.sos_id = sos_id

        self.init_input = None

        self.embedding = nn.Embedding(self.output_size, self.word_embedding_size)
        if use_attention:
            self.attention = Attention(self.hidden_size)

        # Projects RNN hidden states to vocabulary logits.
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward_step(self, input_var, hidden, encoder_outputs, function):
        """Run the RNN over `input_var` (one or more steps at once) and return
        per-step vocabulary (log-)probabilities, the new hidden state, and
        the attention weights (None when attention is disabled)."""
        batch_size = input_var.size(0)
        output_size = input_var.size(1)
        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        output, hidden = self.rnn(embedded, hidden)

        attn = None
        if self.use_attention:
            output, attn = self.attention(output, encoder_outputs)

        # Flatten (batch, steps, hidden) for the linear layer, then restore shape.
        predicted_softmax = function(self.out(output.contiguous().view(-1, self.hidden_size)), dim=1).view(batch_size, output_size, -1)
        return predicted_softmax, hidden, attn

    def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None,
                    function=F.log_softmax, teacher_forcing_ratio=0):
        ret_dict = dict()
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        inputs, batch_size, max_length = self._validate_args(inputs, encoder_hidden, encoder_outputs,
                                                             function, teacher_forcing_ratio)
        decoder_hidden = self._init_state(encoder_hidden)

        # One sample per call: the entire sequence is decoded either with or
        # without teacher forcing.
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        decoder_outputs = []
        sequence_symbols = []
        # Per-sequence output lengths; shortened when EOS is first emitted.
        lengths = np.array([max_length] * batch_size)

        def decode(step, step_output, step_attn):
            # Record one decoding step: store outputs, pick argmax symbols,
            # and clamp `lengths` for sequences that just emitted EOS.
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
            symbols = decoder_outputs[-1].topk(1)[1]
            sequence_symbols.append(symbols)

            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                # Only shorten sequences that had not already finished.
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols

        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio is True or False instead of a probability, the unrolling can be done in graph
        if use_teacher_forcing:
            # Feed the whole gold sequence (minus the final token) in one pass.
            decoder_input = inputs[:, :-1]
            decoder_output, decoder_hidden, attn = self.forward_step(decoder_input, decoder_hidden, encoder_outputs,
                                                                     function=function)

            for di in range(decoder_output.size(1)):
                step_output = decoder_output[:, di, :]
                if attn is not None:
                    step_attn = attn[:, di, :]
                else:
                    step_attn = None
                decode(di, step_output, step_attn)
        else:
            # Greedy decoding: each step consumes the previous step's argmax.
            decoder_input = inputs[:, 0].unsqueeze(1)
            for di in range(max_length):
                decoder_output, decoder_hidden, step_attn = self.forward_step(decoder_input, decoder_hidden, encoder_outputs,
                                                                              function=function)
                step_output = decoder_output.squeeze(1)
                symbols = decode(di, step_output, step_attn)
                decoder_input = symbols

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict

    def _init_state(self, encoder_hidden):
        """ Initialize the encoder hidden state. """
        if encoder_hidden is None:
            return None
        if isinstance(encoder_hidden, tuple):
            # LSTM: (h, c) pair; transform each component.
            encoder_hidden = tuple([self._cat_directions(h) for h in encoder_hidden])
        else:
            encoder_hidden = self._cat_directions(encoder_hidden)
        return encoder_hidden

    def _cat_directions(self, h):
        """ If the encoder is bidirectional, do the following transformation.
            (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
        """
        if self.bidirectional_encoder:
            # Interleaved forward/backward layers are concatenated feature-wise.
            h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
        return h

    def _validate_args(self, inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio):
        """Infer batch size and max decoding length; build a default <sos>
        input batch when `inputs` is absent.

        Raises:
            ValueError: if attention is enabled without encoder_outputs, or
                teacher forcing is requested without inputs.
        """
        if self.use_attention:
            if encoder_outputs is None:
                raise ValueError("Argument encoder_outputs cannot be None when attention is used.")

        # inference batch size
        if inputs is None and encoder_hidden is None:
            batch_size = 1
        else:
            if inputs is not None:
                batch_size = inputs.size(0)
            else:
                # Infer batch size from the hidden-state layout of the cell type.
                if self.rnn_cell is nn.LSTM:
                    batch_size = encoder_hidden[0].size(1)
                elif self.rnn_cell is nn.GRU:
                    batch_size = encoder_hidden.size(1)

        # set default input and max decoding length
        if inputs is None:
            if teacher_forcing_ratio > 0:
                raise ValueError("Teacher forcing has to be disabled (set 0) when no inputs is provided.")
            inputs = torch.LongTensor([self.sos_id] * batch_size).view(batch_size, 1)
            if torch.cuda.is_available():
                inputs = inputs.cuda()
            max_length = self.max_length
        else:
            max_length = inputs.size(1) - 1  # minus the start of sequence symbol

        return inputs, batch_size, max_length
| from torch.autograd import Variable 9 | import torch.nn.functional as F 10 | 11 | from .attention import Attention 12 | from .baseRNN import BaseRNN 13 | 14 | if torch.cuda.is_available(): 15 | import torch.cuda as device 16 | else: 17 | import torch as device 18 | 19 | 20 | class DecoderRNN(BaseRNN): 21 | r""" 22 | Provides functionality for decoding in a seq2seq framework, with an option for attention. 23 | 24 | Args: 25 | vocab_size (int): size of the vocabulary 26 | max_len (int): a maximum allowed length for the sequence to be processed 27 | hidden_size (int): the number of features in the hidden state `h` 28 | sos_id (int): index of the start of sentence symbol 29 | eos_id (int): index of the end of sentence symbol 30 | n_layers (int, optional): number of recurrent layers (default: 1) 31 | rnn_cell (str, optional): type of RNN cell (default: gru) 32 | bidirectional (bool, optional): if the encoder is bidirectional (default False) 33 | input_dropout_p (float, optional): dropout probability for the input sequence (default: 0) 34 | dropout_p (float, optional): dropout probability for the output sequence (default: 0) 35 | use_attention(bool, optional): flag indication whether to use attention mechanism or not (default: false) 36 | 37 | Attributes: 38 | KEY_ATTN_SCORE (str): key used to indicate attention weights in `ret_dict` 39 | KEY_LENGTH (str): key used to indicate a list representing lengths of output sequences in `ret_dict` 40 | KEY_SEQUENCE (str): key used to indicate a list of sequences in `ret_dict` 41 | 42 | Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio 43 | - **inputs** (batch, seq_len, input_size): list of sequences, whose length is the batch size and within which 44 | each sequence is a list of token IDs. It is used for teacher forcing when provided. 
(default `None`) 45 | - **encoder_hidden** (num_layers * num_directions, batch_size, hidden_size): tensor containing the features in the 46 | hidden state `h` of encoder. Used as the initial hidden state of the decoder. (default `None`) 47 | - **encoder_outputs** (batch, seq_len, hidden_size): tensor with containing the outputs of the encoder. 48 | Used for attention mechanism (default is `None`). 49 | - **function** (torch.nn.Module): A function used to generate symbols from RNN hidden state 50 | (default is `torch.nn.functional.log_softmax`). 51 | - **teacher_forcing_ratio** (float): The probability that teacher forcing will be used. A random number is 52 | drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given value, 53 | teacher forcing would be used (default is 0). 54 | 55 | Outputs: decoder_outputs, decoder_hidden, ret_dict 56 | - **decoder_outputs** (seq_len, batch, vocab_size): list of tensors with size (batch_size, vocab_size) containing 57 | the outputs of the decoding function. 58 | - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the last hidden 59 | state of the decoder. 60 | - **ret_dict**: dictionary containing additional information as follows {*KEY_LENGTH* : list of integers 61 | representing lengths of output sequences, *KEY_SEQUENCE* : list of sequences, where each sequence is a list of 62 | predicted token IDs }. 
63 | """ 64 | 65 | KEY_ATTN_SCORE = 'attention_score' 66 | KEY_LENGTH = 'length' 67 | KEY_SEQUENCE = 'sequence' 68 | 69 | def __init__(self, vocab_size, max_len, hidden_size, 70 | sos_id, eos_id, 71 | n_layers=1, rnn_cell='gru', bidirectional=False, 72 | input_dropout_p=0, dropout_p=0, use_attention=False, word_embedding_size=100): 73 | super(DecoderRNN, self).__init__(vocab_size, max_len, hidden_size, 74 | input_dropout_p, dropout_p, 75 | n_layers, rnn_cell) 76 | 77 | self.word_embedding_size = word_embedding_size 78 | 79 | 80 | self.bidirectional_encoder = bidirectional 81 | self.rnn = self.rnn_cell(self.word_embedding_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p) 82 | 83 | self.output_size = vocab_size 84 | self.max_length = max_len 85 | self.use_attention = use_attention 86 | self.eos_id = eos_id 87 | self.sos_id = sos_id 88 | 89 | self.init_input = None 90 | 91 | self.embedding = nn.Embedding(self.output_size, self.word_embedding_size) 92 | if use_attention: 93 | self.attention = Attention(self.hidden_size) 94 | 95 | self.out = nn.Linear(self.hidden_size, self.output_size) 96 | 97 | def forward_step(self, input_var, hidden, encoder_outputs, function): 98 | batch_size = input_var.size(0) 99 | output_size = input_var.size(1) 100 | embedded = self.embedding(input_var) 101 | embedded = self.input_dropout(embedded) 102 | 103 | output, hidden = self.rnn(embedded, hidden) 104 | 105 | attn = None 106 | if self.use_attention: 107 | output, attn = self.attention(output, encoder_outputs) 108 | 109 | predicted_softmax = function(self.out(output.contiguous().view(-1, self.hidden_size)), dim=1).view(batch_size, output_size, -1) 110 | return predicted_softmax, hidden, attn 111 | 112 | def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None, 113 | function=F.log_softmax, teacher_forcing_ratio=0): 114 | ret_dict = dict() 115 | 116 | if self.use_attention: 117 | ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list() 118 | 119 | inputs, batch_size, 
max_length = self._validate_args(inputs, encoder_hidden, encoder_outputs, 120 | function, teacher_forcing_ratio) 121 | decoder_hidden = self._init_state(encoder_hidden) 122 | 123 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 124 | 125 | decoder_outputs = [] 126 | sequence_symbols = [] 127 | input_lists = [] 128 | lengths = np.array([max_length] * batch_size) 129 | 130 | def decode(step, step_output, step_attn, test = None): 131 | decoder_outputs.append(step_output) 132 | print(decoder_outputs.size()) 133 | if self.use_attention: 134 | ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn) 135 | probability = decoder_outputs[-1].topk(1)[0] 136 | symbols = decoder_outputs[-1].topk(1)[1] 137 | if test and symbols == 3: 138 | input_lists.append(decoder_outputs[:-1]) 139 | 140 | sequence_symbols.append(symbols) 141 | 142 | eos_batches = symbols.data.eq(self.eos_id) 143 | if eos_batches.dim() > 0: 144 | eos_batches = eos_batches.cpu().view(-1).numpy() 145 | update_idx = ((lengths > step) & eos_batches) != 0 146 | lengths[update_idx] = len(sequence_symbols) 147 | if test: 148 | return symbols 149 | else: 150 | return symbols, input_lists 151 | 152 | # Manual unrolling is used to support random teacher forcing. 
153 | # If teacher_forcing_ratio is True or False instead of a probability, the unrolling can be done in graph 154 | if use_teacher_forcing: 155 | decoder_input = inputs[:, :-1] 156 | decoder_output, decoder_hidden, attn = self.forward_step(decoder_input, decoder_hidden, encoder_outputs, 157 | function=function) 158 | 159 | for di in range(decoder_output.size(1)): 160 | step_output = decoder_output[:, di, :] 161 | if attn is not None: 162 | step_attn = attn[:, di, :] 163 | else: 164 | step_attn = None 165 | decode(di, step_output, step_attn) 166 | else: 167 | decoder_input = inputs[:, 0].unsqueeze(1) 168 | for di in range(max_length): 169 | 170 | decoder_output, decoder_hidden, step_attn = self.forward_step(decoder_input, decoder_hidden, encoder_outputs, 171 | function=function) 172 | print(decoder_output.size()) 173 | step_output = decoder_output.squeeze(1) 174 | symbols, _ = decode(di, step_output, step_attn,use_teacher_forcing) 175 | decoder_input = symbols 176 | 177 | ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols 178 | ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist() 179 | return decoder_outputs, decoder_hidden, ret_dict 180 | 181 | def _init_state(self, encoder_hidden): 182 | """ Initialize the encoder hidden state. """ 183 | if encoder_hidden is None: 184 | return None 185 | if isinstance(encoder_hidden, tuple): 186 | encoder_hidden = tuple([self._cat_directions(h) for h in encoder_hidden]) 187 | else: 188 | encoder_hidden = self._cat_directions(encoder_hidden) 189 | return encoder_hidden 190 | 191 | def _cat_directions(self, h): 192 | """ If the encoder is bidirectional, do the following transformation. 
193 | (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size) 194 | """ 195 | if self.bidirectional_encoder: 196 | h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2) 197 | return h 198 | 199 | def _validate_args(self, inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio): 200 | if self.use_attention: 201 | if encoder_outputs is None: 202 | raise ValueError("Argument encoder_outputs cannot be None when attention is used.") 203 | 204 | # inference batch size 205 | if inputs is None and encoder_hidden is None: 206 | batch_size = 1 207 | else: 208 | if inputs is not None: 209 | batch_size = inputs.size(0) 210 | else: 211 | if self.rnn_cell is nn.LSTM: 212 | batch_size = encoder_hidden[0].size(1) 213 | elif self.rnn_cell is nn.GRU: 214 | batch_size = encoder_hidden.size(1) 215 | 216 | # set default input and max decoding length 217 | if inputs is None: 218 | if teacher_forcing_ratio > 0: 219 | raise ValueError("Teacher forcing has to be disabled (set 0) when no inputs is provided.") 220 | inputs = torch.LongTensor([self.sos_id] * batch_size).view(batch_size, 1) 221 | if torch.cuda.is_available(): 222 | inputs = inputs.cuda() 223 | max_length = self.max_length 224 | else: 225 | max_length = inputs.size(1) - 1 # minus the start of sequence symbol 226 | 227 | return inputs, batch_size, max_length 228 | -------------------------------------------------------------------------------- /seq2seq/models/EncoderRNN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .baseRNN import BaseRNN 4 | 5 | class EncoderRNN(BaseRNN): 6 | r""" 7 | Applies a multi-layer RNN to an input sequence. 
8 | 9 | Args: 10 | vocab_size (int): size of the vocabulary 11 | max_len (int): a maximum allowed length for the sequence to be processed 12 | hidden_size (int): the number of features in the hidden state `h` 13 | input_dropout_p (float, optional): dropout probability for the input sequence (default: 0) 14 | dropout_p (float, optional): dropout probability for the output sequence (default: 0) 15 | n_layers (int, optional): number of recurrent layers (default: 1) 16 | bidirectional (bool, optional): if True, becomes a bidirectional encodr (defulat False) 17 | rnn_cell (str, optional): type of RNN cell (default: gru) 18 | variable_lengths (bool, optional): if use variable length RNN (default: False) 19 | embedding (torch.Tensor, optional): Pre-trained embedding. The size of the tensor has to match 20 | the size of the embedding parameter: (vocab_size, hidden_size). The embedding layer would be initialized 21 | with the tensor if provided (default: None). 22 | update_embedding (bool, optional): If the embedding should be updated during training (default: False). 23 | 24 | Inputs: inputs, input_lengths 25 | - **inputs**: list of sequences, whose length is the batch size and within which each sequence is a list of token IDs. 
26 | - **input_lengths** (list of int, optional): list that contains the lengths of sequences 27 | in the mini-batch, it must be provided when using variable length RNN (default: `None`) 28 | 29 | Outputs: output, hidden 30 | - **output** (batch, seq_len, hidden_size): tensor containing the encoded features of the input sequence 31 | - **hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the features in the hidden state `h` 32 | 33 | Examples:: 34 | 35 | >>> encoder = EncoderRNN(input_vocab, max_seq_length, hidden_size) 36 | >>> output, hidden = encoder(input) 37 | 38 | """ 39 | 40 | def __init__(self, vocab_size, max_len, hidden_size, input_dropout_p=0, dropout_p=0, 41 | n_layers=1, bidirectional=False, rnn_cell='gru', variable_lengths=False, 42 | embedding=None, update_embedding=True, word_embedding_size = 100): 43 | super(EncoderRNN, self).__init__(vocab_size, max_len, hidden_size, 44 | input_dropout_p, dropout_p, n_layers, rnn_cell) 45 | self.word_embedding_size = word_embedding_size 46 | self.variable_lengths = variable_lengths 47 | self.embedding = nn.Embedding(vocab_size, self.word_embedding_size) 48 | if embedding is not None: 49 | self.embedding.weight = nn.Parameter(embedding) 50 | self.embedding.weight.requires_grad = update_embedding 51 | self.rnn = self.rnn_cell(self.word_embedding_size, hidden_size, n_layers, 52 | batch_first=True, bidirectional=bidirectional, dropout=dropout_p) 53 | 54 | def forward(self, input_var, input_lengths=None): 55 | """ 56 | Applies a multi-layer RNN to an input sequence. 57 | 58 | Args: 59 | input_var (batch, seq_len): tensor containing the features of the input sequence. 
60 | input_lengths (list of int, optional): A list that contains the lengths of sequences 61 | in the mini-batch 62 | 63 | Returns: output, hidden 64 | - **output** (batch, seq_len, hidden_size): variable containing the encoded features of the input sequence 65 | - **hidden** (num_layers * num_directions, batch, hidden_size): variable containing the features in the hidden state h 66 | """ 67 | embedded = self.embedding(input_var) 68 | embedded = self.input_dropout(embedded) 69 | if self.variable_lengths: 70 | embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True) 71 | output, hidden = self.rnn(embedded) 72 | if self.variable_lengths: 73 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) 74 | return output, hidden 75 | -------------------------------------------------------------------------------- /seq2seq/models/EncoderRNN.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacger2/softregex/34742a197cbfdc248ea01de4df76488ecd3b5ce4/seq2seq/models/EncoderRNN.pyc -------------------------------------------------------------------------------- /seq2seq/models/TopKDecoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | 5 | def _inflate(tensor, times, dim): 6 | """ 7 | Given a tensor, 'inflates' it along the given dimension by replicating each slice specified number of times (in-place) 8 | 9 | Args: 10 | tensor: A :class:`Tensor` to inflate 11 | times: number of repetitions 12 | dim: axis for inflation (default=0) 13 | 14 | Returns: 15 | A :class:`Tensor` 16 | 17 | Examples:: 18 | >> a = torch.LongTensor([[1, 2], [3, 4]]) 19 | >> a 20 | 1 2 21 | 3 4 22 | [torch.LongTensor of size 2x2] 23 | >> b = ._inflate(a, 2, dim=1) 24 | >> b 25 | 1 2 1 2 26 | 3 4 3 4 27 | [torch.LongTensor of size 2x4] 28 | >> c = _inflate(a, 2, 
dim=0) 29 | >> c 30 | 1 2 31 | 3 4 32 | 1 2 33 | 3 4 34 | [torch.LongTensor of size 4x2] 35 | 36 | """ 37 | repeat_dims = [1] * tensor.dim() 38 | repeat_dims[dim] = times 39 | return tensor.repeat(*repeat_dims) 40 | 41 | class TopKDecoder(torch.nn.Module): 42 | r""" 43 | Top-K decoding with beam search. 44 | 45 | Args: 46 | decoder_rnn (DecoderRNN): An object of DecoderRNN used for decoding. 47 | k (int): Size of the beam. 48 | 49 | Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio 50 | - **inputs** (seq_len, batch, input_size): list of sequences, whose length is the batch size and within which 51 | each sequence is a list of token IDs. It is used for teacher forcing when provided. (default is `None`) 52 | - **encoder_hidden** (num_layers * num_directions, batch_size, hidden_size): tensor containing the features 53 | in the hidden state `h` of encoder. Used as the initial hidden state of the decoder. 54 | - **encoder_outputs** (batch, seq_len, hidden_size): tensor with containing the outputs of the encoder. 55 | Used for attention mechanism (default is `None`). 56 | - **function** (torch.nn.Module): A function used to generate symbols from RNN hidden state 57 | (default is `torch.nn.functional.log_softmax`). 58 | - **teacher_forcing_ratio** (float): The probability that teacher forcing will be used. A random number is 59 | drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given value, 60 | teacher forcing would be used (default is 0). 61 | 62 | Outputs: decoder_outputs, decoder_hidden, ret_dict 63 | - **decoder_outputs** (batch): batch-length list of tensors with size (max_length, hidden_size) containing the 64 | outputs of the decoder. 65 | - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the last hidden 66 | state of the decoder. 
67 | - **ret_dict**: dictionary containing additional information as follows {*length* : list of integers 68 | representing lengths of output sequences, *topk_length*: list of integers representing lengths of beam search 69 | sequences, *sequence* : list of sequences, where each sequence is a list of predicted token IDs, 70 | *topk_sequence* : list of beam search sequences, each beam is a list of token IDs, *inputs* : target 71 | outputs if provided for decoding}. 72 | """ 73 | 74 | def __init__(self, decoder_rnn, k): 75 | super(TopKDecoder, self).__init__() 76 | self.rnn = decoder_rnn 77 | self.k = k 78 | self.hidden_size = self.rnn.hidden_size 79 | self.V = self.rnn.output_size 80 | self.SOS = self.rnn.sos_id 81 | self.EOS = self.rnn.eos_id 82 | 83 | def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None, function=F.log_softmax, 84 | teacher_forcing_ratio=0, retain_output_probs=True): 85 | """ 86 | Forward rnn for MAX_LENGTH steps. Look at :func:`seq2seq.models.DecoderRNN.DecoderRNN.forward_rnn` for details. 87 | """ 88 | 89 | inputs, batch_size, max_length = self.rnn._validate_args(inputs, encoder_hidden, encoder_outputs, 90 | function, teacher_forcing_ratio) 91 | 92 | self.pos_index = Variable(torch.LongTensor(range(batch_size)) * self.k).view(-1, 1) 93 | 94 | # Inflate the initial hidden states to be of size: b*k x h 95 | encoder_hidden = self.rnn._init_state(encoder_hidden) 96 | if encoder_hidden is None: 97 | hidden = None 98 | else: 99 | if isinstance(encoder_hidden, tuple): 100 | hidden = tuple([_inflate(h, self.k, 1) for h in encoder_hidden]) 101 | else: 102 | hidden = _inflate(encoder_hidden, self.k, 1) 103 | 104 | # ... 
same idea for encoder_outputs and decoder_outputs 105 | if self.rnn.use_attention: 106 | inflated_encoder_outputs = _inflate(encoder_outputs, self.k, 0) 107 | else: 108 | inflated_encoder_outputs = None 109 | 110 | # Initialize the scores; for the first step, 111 | # ignore the inflated copies to avoid duplicate entries in the top k 112 | sequence_scores = torch.Tensor(batch_size * self.k, 1) 113 | sequence_scores.fill_(-float('Inf')) 114 | sequence_scores.index_fill_(0, torch.LongTensor([i * self.k for i in range(0, batch_size)]), 0.0) 115 | sequence_scores = Variable(sequence_scores) 116 | 117 | # Initialize the input vector 118 | input_var = Variable(torch.transpose(torch.LongTensor([[self.SOS] * batch_size * self.k]), 0, 1)) 119 | 120 | # Store decisions for backtracking 121 | stored_outputs = list() 122 | stored_scores = list() 123 | stored_predecessors = list() 124 | stored_emitted_symbols = list() 125 | stored_hidden = list() 126 | 127 | for _ in range(0, max_length): 128 | 129 | # Run the RNN one step forward 130 | log_softmax_output, hidden, _ = self.rnn.forward_step(input_var, hidden, 131 | inflated_encoder_outputs, function=function) 132 | 133 | # If doing local backprop (e.g. 
supervised training), retain the output layer 134 | if retain_output_probs: 135 | stored_outputs.append(log_softmax_output) 136 | 137 | # To get the full sequence scores for the new candidates, add the local scores for t_i to the predecessor scores for t_(i-1) 138 | sequence_scores = _inflate(sequence_scores, self.V, 1) 139 | sequence_scores += log_softmax_output.squeeze(1) 140 | scores, candidates = sequence_scores.view(batch_size, -1).topk(self.k, dim=1) 141 | 142 | # Reshape input = (bk, 1) and sequence_scores = (bk, 1) 143 | input_var = (candidates % self.V).view(batch_size * self.k, 1) 144 | sequence_scores = scores.view(batch_size * self.k, 1) 145 | 146 | # Update fields for next timestep 147 | predecessors = (candidates / self.V + self.pos_index.expand_as(candidates)).view(batch_size * self.k, 1) 148 | if isinstance(hidden, tuple): 149 | hidden = tuple([h.index_select(1, predecessors.squeeze()) for h in hidden]) 150 | else: 151 | hidden = hidden.index_select(1, predecessors.squeeze()) 152 | 153 | # Update sequence scores and erase scores for end-of-sentence symbol so that they aren't expanded 154 | stored_scores.append(sequence_scores.clone()) 155 | eos_indices = input_var.data.eq(self.EOS) 156 | if eos_indices.nonzero().dim() > 0: 157 | sequence_scores.data.masked_fill_(eos_indices, -float('inf')) 158 | 159 | # Cache results for backtracking 160 | stored_predecessors.append(predecessors) 161 | stored_emitted_symbols.append(input_var) 162 | stored_hidden.append(hidden) 163 | 164 | # Do backtracking to return the optimal values 165 | output, h_t, h_n, s, l, p = self._backtrack(stored_outputs, stored_hidden, 166 | stored_predecessors, stored_emitted_symbols, 167 | stored_scores, batch_size, self.hidden_size) 168 | 169 | # Build return objects 170 | decoder_outputs = [step[:, 0, :] for step in output] 171 | if isinstance(h_n, tuple): 172 | decoder_hidden = tuple([h[:, :, 0, :] for h in h_n]) 173 | else: 174 | decoder_hidden = h_n[:, :, 0, :] 175 | metadata = {} 
176 | metadata['inputs'] = inputs 177 | metadata['output'] = output 178 | metadata['h_t'] = h_t 179 | metadata['score'] = s 180 | metadata['topk_length'] = l 181 | metadata['topk_sequence'] = p 182 | metadata['length'] = [seq_len[0] for seq_len in l] 183 | metadata['sequence'] = [seq[0] for seq in p] 184 | return decoder_outputs, decoder_hidden, metadata 185 | 186 | def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores, b, hidden_size): 187 | """Backtracks over batch to generate optimal k-sequences. 188 | 189 | Args: 190 | nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of outputs from network 191 | nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length: A Tensor of hidden states from network 192 | predecessors [(batch*k)] * sequence_length: A Tensor of predecessors 193 | symbols [(batch*k)] * sequence_length: A Tensor of predicted tokens 194 | scores [(batch*k)] * sequence_length: A Tensor containing sequence scores for every token t = [0, ... , seq_len - 1] 195 | b: Size of the batch 196 | hidden_size: Size of the hidden state 197 | 198 | Returns: 199 | output [(batch, k, vocab_size)] * sequence_length: A list of the output probabilities (p_n) 200 | from the last layer of the RNN, for every n = [0, ... , seq_len - 1] 201 | 202 | h_t [(batch, k, hidden_size)] * sequence_length: A list containing the output features (h_n) 203 | from the last layer of the RNN, for every n = [0, ... , seq_len - 1] 204 | 205 | h_n(batch, k, hidden_size): A Tensor containing the last hidden state for all top-k sequences. 
206 | 207 | score [batch, k]: A list containing the final scores for all top-k sequences 208 | 209 | length [batch, k]: A list specifying the length of each sequence in the top-k candidates 210 | 211 | p (batch, k, sequence_len): A Tensor containing predicted sequence 212 | """ 213 | 214 | lstm = isinstance(nw_hidden[0], tuple) 215 | 216 | # initialize return variables given different types 217 | output = list() 218 | h_t = list() 219 | p = list() 220 | # Placeholder for last hidden state of top-k sequences. 221 | # If a (top-k) sequence ends early in decoding, `h_n` contains 222 | # its hidden state when it sees EOS. Otherwise, `h_n` contains 223 | # the last hidden state of decoding. 224 | if lstm: 225 | state_size = nw_hidden[0][0].size() 226 | h_n = tuple([torch.zeros(state_size), torch.zeros(state_size)]) 227 | else: 228 | h_n = torch.zeros(nw_hidden[0].size()) 229 | l = [[self.rnn.max_length] * self.k for _ in range(b)] # Placeholder for lengths of top-k sequences 230 | # Similar to `h_n` 231 | 232 | # the last step output of the beams are not sorted 233 | # thus they are sorted here 234 | sorted_score, sorted_idx = scores[-1].view(b, self.k).topk(self.k) 235 | # initialize the sequence scores with the sorted last step beam scores 236 | s = sorted_score.clone() 237 | 238 | batch_eos_found = [0] * b # the number of EOS found 239 | # in the backward loop below for each batch 240 | 241 | t = self.rnn.max_length - 1 242 | # initialize the back pointer with the sorted order of the last step beams. 243 | # add self.pos_index for indexing variable with b*k as the first dimension. 
244 | t_predecessors = (sorted_idx + self.pos_index.expand_as(sorted_idx)).view(b * self.k) 245 | while t >= 0: 246 | # Re-order the variables with the back pointer 247 | current_output = nw_output[t].index_select(0, t_predecessors) 248 | if lstm: 249 | current_hidden = tuple([h.index_select(1, t_predecessors) for h in nw_hidden[t]]) 250 | else: 251 | current_hidden = nw_hidden[t].index_select(1, t_predecessors) 252 | current_symbol = symbols[t].index_select(0, t_predecessors) 253 | # Re-order the back pointer of the previous step with the back pointer of 254 | # the current step 255 | t_predecessors = predecessors[t].index_select(0, t_predecessors).squeeze() 256 | 257 | # This tricky block handles dropped sequences that see EOS earlier. 258 | # The basic idea is summarized below: 259 | # 260 | # Terms: 261 | # Ended sequences = sequences that see EOS early and dropped 262 | # Survived sequences = sequences in the last step of the beams 263 | # 264 | # Although the ended sequences are dropped during decoding, 265 | # their generated symbols and complete backtracking information are still 266 | # in the backtracking variables. 267 | # For each batch, everytime we see an EOS in the backtracking process, 268 | # 1. If there is survived sequences in the return variables, replace 269 | # the one with the lowest survived sequence score with the new ended 270 | # sequences 271 | # 2. 
Otherwise, replace the ended sequence with the lowest sequence 272 | # score with the new ended sequence 273 | # 274 | eos_indices = symbols[t].data.squeeze(1).eq(self.EOS).nonzero() 275 | if eos_indices.dim() > 0: 276 | for i in range(eos_indices.size(0)-1, -1, -1): 277 | # Indices of the EOS symbol for both variables 278 | # with b*k as the first dimension, and b, k for 279 | # the first two dimensions 280 | idx = eos_indices[i] 281 | b_idx = int(idx[0] / self.k) 282 | # The indices of the replacing position 283 | # according to the replacement strategy noted above 284 | res_k_idx = self.k - (batch_eos_found[b_idx] % self.k) - 1 285 | batch_eos_found[b_idx] += 1 286 | res_idx = b_idx * self.k + res_k_idx 287 | 288 | # Replace the old information in return variables 289 | # with the new ended sequence information 290 | t_predecessors[res_idx] = predecessors[t][idx[0]] 291 | current_output[res_idx, :] = nw_output[t][idx[0], :] 292 | if lstm: 293 | current_hidden[0][:, res_idx, :] = nw_hidden[t][0][:, idx[0], :] 294 | current_hidden[1][:, res_idx, :] = nw_hidden[t][1][:, idx[0], :] 295 | h_n[0][:, res_idx, :] = nw_hidden[t][0][:, idx[0], :].data 296 | h_n[1][:, res_idx, :] = nw_hidden[t][1][:, idx[0], :].data 297 | else: 298 | current_hidden[:, res_idx, :] = nw_hidden[t][:, idx[0], :] 299 | h_n[:, res_idx, :] = nw_hidden[t][:, idx[0], :].data 300 | current_symbol[res_idx, :] = symbols[t][idx[0]] 301 | s[b_idx, res_k_idx] = scores[t][idx[0]].data[0] 302 | l[b_idx][res_k_idx] = t + 1 303 | 304 | # record the back tracked results 305 | output.append(current_output) 306 | h_t.append(current_hidden) 307 | p.append(current_symbol) 308 | 309 | t -= 1 310 | 311 | # Sort and re-order again as the added ended sequences may change 312 | # the order (very unlikely) 313 | s, re_sorted_idx = s.topk(self.k) 314 | for b_idx in range(b): 315 | l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx,:]] 316 | 317 | re_sorted_idx = (re_sorted_idx + 
self.pos_index.expand_as(re_sorted_idx)).view(b * self.k) 318 | 319 | # Reverse the sequences and re-order at the same time 320 | # It is reversed because the backtracking happens in reverse time order 321 | output = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(output)] 322 | p = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(p)] 323 | if lstm: 324 | h_t = [tuple([h.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for h in step]) for step in reversed(h_t)] 325 | h_n = tuple([h.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size) for h in h_n]) 326 | else: 327 | h_t = [step.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for step in reversed(h_t)] 328 | h_n = h_n.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size) 329 | s = s.data 330 | 331 | return output, h_t, h_n, s, l, p 332 | 333 | def _mask_symbol_scores(self, score, idx, masking_score=-float('inf')): 334 | score[idx] = masking_score 335 | 336 | def _mask(self, tensor, idx, dim=0, masking_score=-float('inf')): 337 | if len(idx.size()) > 0: 338 | indices = idx[:, 0] 339 | tensor.index_fill_(dim, indices, masking_score) 340 | 341 | 342 | -------------------------------------------------------------------------------- /seq2seq/models/TopKDecoder.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacger2/softregex/34742a197cbfdc248ea01de4df76488ecd3b5ce4/seq2seq/models/TopKDecoder.pyc -------------------------------------------------------------------------------- /seq2seq/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .EncoderRNN import EncoderRNN 2 | from .DecoderRNN import DecoderRNN 3 | from .TopKDecoder import TopKDecoder 4 | from .seq2seq import Seq2seq 5 | -------------------------------------------------------------------------------- 
class Attention(nn.Module):
    r"""
    Applies an attention mechanism on the output features from the decoder.

    .. math::
        \begin{array}{ll}
        x = context*output \\
        attn = exp(x_i) / sum_j exp(x_j) \\
        output = \tanh(w * (attn * context) + b * output)
        \end{array}

    Args:
        dim(int): The number of expected features in the output

    Inputs: output, context
        - **output** (batch, output_len, dimensions): tensor containing the output features from the decoder.
        - **context** (batch, input_len, dimensions): tensor containing features of the encoded input sequence.

    Outputs: output, attn
        - **output** (batch, output_len, dimensions): tensor containing the attended output features from the decoder.
        - **attn** (batch, output_len, input_len): tensor containing attention weights.

    Attributes:
        linear_out (torch.nn.Linear): applies a linear transformation to the incoming data: :math:`y = Ax + b`.
        mask (torch.Tensor, optional): applies a :math:`-inf` to the indices specified in the `Tensor`.

    Examples::

        >>> attention = seq2seq.models.Attention(256)
        >>> context = Variable(torch.randn(5, 3, 256))
        >>> output = Variable(torch.randn(5, 5, 256))
        >>> output, attn = attention(output, context)

    """
    def __init__(self, dim):
        super(Attention, self).__init__()
        # Projects the concatenation [attended context ; decoder output]
        # (2*dim features) back down to dim features.
        self.linear_out = nn.Linear(dim * 2, dim)
        self.mask = None

    def set_mask(self, mask):
        """
        Sets indices to be masked

        Args:
            mask (torch.Tensor): tensor containing indices to be masked
        """
        self.mask = mask

    def forward(self, output, context):
        batch_size = output.size(0)
        hidden_size = output.size(2)
        input_size = context.size(1)
        # (batch, out_len, dim) * (batch, in_len, dim) -> (batch, out_len, in_len)
        attn = torch.bmm(output, context.transpose(1, 2))
        if self.mask is not None:
            # NOTE(review): mutating .data bypasses autograd tracking for the
            # masking; kept as-is to preserve the original training behavior.
            attn.data.masked_fill_(self.mask, -float('inf'))
        # Normalize scores over the input positions.
        attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

        # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
        mix = torch.bmm(attn, context)

        # concat -> (batch, out_len, 2*dim)
        combined = torch.cat((mix, output), dim=2)
        # output -> (batch, out_len, dim)
        # Fix: F.tanh is deprecated (and removed in recent PyTorch); use torch.tanh.
        output = torch.tanh(self.linear_out(combined.view(-1, 2 * hidden_size))).view(batch_size, -1, hidden_size)

        return output, attn
""" 2 | import torch.nn as nn 3 | 4 | 5 | class BaseRNN(nn.Module): 6 | r""" 7 | Applies a multi-layer RNN to an input sequence. 8 | Note: 9 | Do not use this class directly, use one of the sub classes. 10 | Args: 11 | vocab_size (int): size of the vocabulary 12 | max_len (int): maximum allowed length for the sequence to be processed 13 | hidden_size (int): number of features in the hidden state `h` 14 | input_dropout_p (float): dropout probability for the input sequence 15 | dropout_p (float): dropout probability for the output sequence 16 | n_layers (int): number of recurrent layers 17 | rnn_cell (str): type of RNN cell (Eg. 'LSTM' , 'GRU') 18 | 19 | Inputs: ``*args``, ``**kwargs`` 20 | - ``*args``: variable length argument list. 21 | - ``**kwargs``: arbitrary keyword arguments. 22 | 23 | Attributes: 24 | SYM_MASK: masking symbol 25 | SYM_EOS: end-of-sequence symbol 26 | """ 27 | SYM_MASK = "MASK" 28 | SYM_EOS = "EOS" 29 | 30 | def __init__(self, vocab_size, max_len, hidden_size, input_dropout_p, dropout_p, n_layers, rnn_cell): 31 | super(BaseRNN, self).__init__() 32 | self.vocab_size = vocab_size 33 | self.max_len = max_len 34 | self.hidden_size = hidden_size 35 | self.n_layers = n_layers 36 | self.input_dropout_p = input_dropout_p 37 | self.input_dropout = nn.Dropout(p=input_dropout_p) 38 | if rnn_cell.lower() == 'lstm': 39 | self.rnn_cell = nn.LSTM 40 | elif rnn_cell.lower() == 'gru': 41 | self.rnn_cell = nn.GRU 42 | else: 43 | raise ValueError("Unsupported RNN Cell: {0}".format(rnn_cell)) 44 | 45 | self.dropout_p = dropout_p 46 | 47 | def forward(self, *args, **kwargs): 48 | raise NotImplementedError() 49 | -------------------------------------------------------------------------------- /seq2seq/models/baseRNN.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacger2/softregex/34742a197cbfdc248ea01de4df76488ecd3b5ce4/seq2seq/models/baseRNN.pyc 
class Seq2seq(nn.Module):
    """ Standard sequence-to-sequence architecture with configurable encoder
    and decoder.

    Args:
        encoder (EncoderRNN): object of EncoderRNN
        decoder (DecoderRNN): object of DecoderRNN
        decode_function (func, optional): function to generate symbols from output hidden states (default: F.log_softmax)

    Inputs: input_variable, input_lengths, target_variable, teacher_forcing_ratio
        - **input_variable** (list, option): list of sequences, whose length is the batch size and within which
          each sequence is a list of token IDs. This information is forwarded to the encoder.
        - **input_lengths** (list of int, optional): A list that contains the lengths of sequences
          in the mini-batch, it must be provided when using variable length RNN (default: `None`)
        - **target_variable** (list, optional): list of sequences, whose length is the batch size and within which
          each sequence is a list of token IDs. This information is forwarded to the decoder.
        - **teacher_forcing_ratio** (int, optional): The probability that teacher forcing will be used. A random number
          is drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given value,
          teacher forcing would be used (default is 0)

    Outputs: decoder_outputs, decoder_hidden, ret_dict
        - **decoder_outputs** (batch): batch-length list of tensors with size (max_length, hidden_size) containing the
          outputs of the decoder.
        - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the last hidden
          state of the decoder.
        - **ret_dict**: dictionary containing additional information as follows {*KEY_LENGTH* : list of integers
          representing lengths of output sequences, *KEY_SEQUENCE* : list of sequences, where each sequence is a list of
          predicted token IDs, *KEY_INPUT* : target outputs if provided for decoding, *KEY_ATTN_SCORE* : list of
          sequences, where each list is of attention weights }.

    """

    def __init__(self, encoder, decoder, decode_function=F.log_softmax):
        super(Seq2seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.decode_function = decode_function

    def flatten_parameters(self):
        """Compact the RNN weights of both sub-modules for efficient CUDA use."""
        for module in (self.encoder, self.decoder):
            module.rnn.flatten_parameters()

    def forward(self, input_variable, input_lengths=None, target_variable=None,
                teacher_forcing_ratio=0):
        # Encode the source sequence, then hand everything to the decoder,
        # whose return value is passed through unchanged.
        encoder_outputs, encoder_hidden = self.encoder(input_variable, input_lengths)
        return self.decoder(inputs=target_variable,
                            encoder_hidden=encoder_hidden,
                            encoder_outputs=encoder_outputs,
                            function=self.decode_function,
                            teacher_forcing_ratio=teacher_forcing_ratio)
class Optimizer(object):
    """ The Optimizer class encapsulates torch.optim package and provides functionalities
    for learning rate scheduling and gradient norm clipping.

    Args:
        optim (torch.optim.Optimizer): optimizer object, the parameters to be optimized
            should be given when instantiating the object, e.g. torch.optim.SGD(params)
        max_grad_norm (float, optional): value used for gradient norm clipping,
            set 0 to disable (default 0)
    """

    _ARG_MAX_GRAD_NORM = 'max_grad_norm'

    def __init__(self, optim, max_grad_norm=0):
        self.optimizer = optim
        self.scheduler = None
        self.max_grad_norm = max_grad_norm

    def set_scheduler(self, scheduler):
        """ Set the learning rate scheduler.

        Args:
            scheduler (torch.optim.lr_scheduler.*): object of learning rate scheduler,
                e.g. torch.optim.lr_scheduler.StepLR
        """
        self.scheduler = scheduler

    def step(self):
        """ Performs a single optimization step, including gradient norm clipping if necessary. """
        if self.max_grad_norm > 0:
            # Gather every parameter across all param groups before clipping.
            all_params = [p for group in self.optimizer.param_groups for p in group['params']]
            torch.nn.utils.clip_grad_norm_(all_params, self.max_grad_norm)
        self.optimizer.step()

    def update(self, loss, epoch):
        """ Update the learning rate if the criteria of the scheduler are met.

        Args:
            loss (float): The current loss. It could be training loss or developing loss
                depending on the caller. By default the supervised trainer uses developing
                loss.
            epoch (int): The current epoch number.
        """
        scheduler = self.scheduler
        if scheduler is None:
            return
        # ReduceLROnPlateau needs the monitored metric; other schedulers do not.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss)
        else:
            scheduler.step()
def decode_tensor(tensor, vocab):
    """Convert a tensor of token ids into a space-joined string using ``vocab.itos``.

    Stops early when the stop token is seen; skips over the filtered tokens.
    NOTE(review): the string literals below are all empty in this copy of the
    source — they look like special tokens (e.g. '<eos>', '<sos>', '<pad>')
    whose angle-bracket text was stripped during extraction. Confirm against
    the original repository before relying on this function.
    """
    tensor = tensor.view(-1)
    words = []
    for i in tensor:
        word = vocab.itos[i.cpu().numpy()]
        if word == '':  # presumably the end-of-sequence token — TODO confirm
            return ' '.join(words)
        # presumably filters sos/eos/pad tokens — TODO confirm
        if word != '' and word != '' and word != '':
            words.append(word)
        #if word != '':
        #    words.append(word)
        #print('|' + word + '|')
    return ' '.join(words)

from regexDFAEquals import regex_equiv_from_raw, unprocess_regex, regex_equiv

def refine_outout(regex):
    """Balance the brackets of a space-tokenized regex string.

    Unmatched closing brackets are dropped; unmatched opening brackets get a
    matching closer appended at the end. Returns the repaired token string.
    (Name 'outout' is a historical typo kept for caller compatibility.)
    """
    par_list = []  # stack of currently-open brackets
    word_list = regex.split()

    for idx, word in enumerate(word_list):
        if word == '(' or word == '[' or word == '{':
            par_list.append(word)

        if word == ')' or word == ']' or word == '}':
            # A closer with no opener on the stack is erased.
            if len(par_list) == 0:
                word_list[idx] = ''
                continue

            # Rewrite the closer so it matches the most recent opener.
            par_in_list = par_list.pop()
            if par_in_list == '(':
                word_list[idx] = ')'
            elif par_in_list == '[':
                word_list[idx] = ']'
            elif par_in_list == '{':
                word_list[idx] = '}'

    # Close any brackets still left open, innermost first.
    while len(par_list) != 0:
        par_in_list = par_list.pop()
        if par_in_list == '(':
            word_list.append(')')
        elif par_in_list == '[':
            word_list.append(']')
        elif par_in_list == '{':
            word_list.append('}')

    word_list = [word for word in word_list if word != '']

    return ' '.join(word_list)

def eval_fa_equiv(model, data, input_vocab, output_vocab):
    """Evaluate ``model`` on ``data``, measuring DFA-equivalence accuracy.

    For each example the model's prediction is decoded, bracket-repaired, and
    compared with the gold regex both as an exact string and for semantic
    (DFA) equivalence via an external ``regexDFAEquals.py`` subprocess.
    Appends the resulting DFA accuracy to ./time_logs/log_score_time.txt.
    """
    loss = NLLLoss()
    batch_size = 1  # evaluate one example at a time

    model.eval()

    loss.reset()
    match = 0
    total = 0

    # NOTE(review): in legacy torchtext, device=None meant "current CUDA
    # device" and -1 meant CPU, so this line is correct for old torchtext.
    device = None if torch.cuda.is_available() else -1
    batch_iterator = torchtext.data.BucketIterator(
        dataset=data, batch_size=batch_size,
        sort=False, sort_key=lambda x: len(x.src),
        device=device, train=False)
    tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
    pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]

    predictor = Predictor(model, input_vocab, output_vocab)

    num_samples = 0
    perfect_samples = 0      # exact string matches
    dfa_perfect_samples = 0  # string matches + DFA-equivalent predictions

    match = 0
    total = 0

    with torch.no_grad():
        for batch in batch_iterator:
            num_samples = num_samples + 1

            input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)

            target_variables = getattr(batch, seq2seq.tgt_field_name)

            target_string = decode_tensor(target_variables, output_vocab)

            #target_string = target_string + " "

            input_string = decode_tensor(input_variables, input_vocab)

            # Drop the trailing predicted token and empty strings before joining.
            generated_string = ' '.join([x for x in predictor.predict(input_string.strip().split())[:-1] if x != ''])

            generated_string = refine_outout(generated_string)

            # Shell out to the python2 DFA-equivalence checker; its stdout's
            # first payload character is '1' when the regexes are equivalent.
            pos_example = subprocess.check_output(['python2', 'regexDFAEquals.py', '--gold', '{}'.format(target_string), '--predicted', '{}'.format(generated_string)])

            if target_string == generated_string:
                perfect_samples = perfect_samples + 1
                dfa_perfect_samples = dfa_perfect_samples + 1
            elif str(pos_example)[2] == '1':
                # str(b'1\n') == "b'1\\n'", so index 2 is the first output char.
                dfa_perfect_samples = dfa_perfect_samples + 1

            target_tokens = target_string.split()
            generated_tokens = generated_string.split()

            shorter_len = min(len(target_tokens), len(generated_tokens))

            # Token-level accuracy bookkeeping; positions beyond the gold
            # length are counted twice as a penalty for over-long outputs.
            for idx in range(len(generated_tokens)):
                total = total + 1

                if idx >= len(target_tokens):
                    total = total + 1
                elif target_tokens[idx] == generated_tokens[idx]:
                    match = match + 1

    if total == 0:
        accuracy = float('nan')
    else:
        accuracy = match / total

    string_accuracy = perfect_samples / num_samples
    dfa_accuracy = dfa_perfect_samples / num_samples

    # Append the semantic accuracy for this evaluation run to the log file.
    f = open('./time_logs/log_score_time.txt', 'a')
    f.write('{}\n'.format(dfa_accuracy))
    f.close()
class SelfCriticalTrainer(object):
    """ The SelfCriticalTrainer class sets up a training framework that rewards
    the model with a sequence-level loss (PositiveLoss) computed on its own
    greedy outputs, instead of per-token supervision.

    Args:
        expt_dir (optional, str): experiment Directory to store details of the experiment,
            by default it makes a folder in the current directory to store the details (default: `experiment_sc`).
        loss (seq2seq.loss.loss.Loss, optional): loss for training, (default: seq2seq.loss.PositiveLoss)
        batch_size (int, optional): batch size for experiment, (default: 64)
        random_seed (int, optional): seed for python and torch RNGs (default: None)
        checkpoint_every (int, optional): number of batches to checkpoint after, (default: 100)
        print_every (int, optional): number of batches between progress logs, (default: 100)
        output_vocab (Vocabulary, optional): target-side vocabulary forwarded to the loss
    """
    def __init__(self, expt_dir='experiment_sc', loss=PositiveLoss(), batch_size=64,
                 random_seed=None,
                 checkpoint_every=100, print_every=100, output_vocab=None):

        self._trainer = "Self Critical Trainer"
        self.random_seed = random_seed
        if random_seed is not None:
            random.seed(random_seed)
            torch.manual_seed(random_seed)
        self.loss = loss
        # Dev evaluation still uses the ordinary token-level NLL loss.
        self.evaluator = Evaluator(loss=NLLLoss(), batch_size=batch_size)
        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every

        self.output_vocab = output_vocab

        if not os.path.isabs(expt_dir):
            expt_dir = os.path.join(os.getcwd(), expt_dir)
        self.expt_dir = expt_dir
        if not os.path.exists(self.expt_dir):
            os.makedirs(self.expt_dir)
        self.batch_size = batch_size

        self.logger = logging.getLogger(__name__)

    def _train_batch(self, input_variable, input_lengths, target_variable, model, teacher_forcing_ratio):
        """Run one forward/backward pass on a batch and return the scalar loss."""
        loss = self.loss
        # Forward propagation
        decoder_outputs, decoder_hidden, other = model(input_variable, input_lengths, target_variable, teacher_forcing_ratio=teacher_forcing_ratio)
        # Get loss
        loss.reset()

        seqlist = []     # greedy token ids per step
        tensorlist = []  # log-probabilities of those greedy tokens

        for step, step_output in enumerate(decoder_outputs):
            # Greedy decoding: keep both the max log-prob and its index.
            tensorlist.append(torch.max(step_output, dim=1)[0])
            seqlist.append(torch.max(step_output, dim=1)[1])

        log_tensor = torch.stack(tensorlist, dim=1)
        output_tensor = torch.stack(seqlist, dim=1)

        # Sequence-level reward-style loss over the model's own samples.
        loss.eval_batch(log_tensor, output_tensor, target_variable, self.output_vocab)

        # Backward propagation
        model.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.get_loss()

    def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
                       dev_data=None, teacher_forcing_ratio=0):
        """Main training loop over epochs and batches, with logging and checkpoints."""
        log = self.logger

        print_loss_total = 0  # Reset every print_every
        epoch_loss_total = 0  # Reset every epoch

        # NOTE(review): legacy torchtext semantics — device=None selects the
        # current CUDA device, -1 selects the CPU.
        device = None if torch.cuda.is_available() else -1
        batch_iterator = torchtext.data.BucketIterator(
            dataset=data, batch_size=self.batch_size,
            sort=False, sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            device=device, repeat=False, shuffle=True)

        steps_per_epoch = len(batch_iterator)
        total_steps = steps_per_epoch * n_epochs

        step = start_step
        step_elapsed = 0
        start_time = time.time()
        for epoch in range(start_epoch, n_epochs + 1):
            # Elapsed wall-clock time in minutes, for the paper's timing study.
            print('epoch: {}, time: {}'.format(epoch, (time.time() - start_time) / 60))
            log.debug("Epoch: %d, Step: %d" % (epoch, step))

            batch_generator = batch_iterator.__iter__()
            # consuming seen batches from previous training
            for _ in range((epoch - 1) * steps_per_epoch, step):
                next(batch_generator)

            model.train(True)
            for batch in batch_generator:
                step += 1
                step_elapsed += 1

                input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
                target_variables = getattr(batch, seq2seq.tgt_field_name)

                loss = self._train_batch(input_variables, input_lengths.tolist(), target_variables, model, teacher_forcing_ratio)

                # Record average loss
                print_loss_total += loss
                epoch_loss_total += loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    print_loss_total = 0
                    log_msg = 'Progress: %d%%, Train %s: %.4f' % (
                        step / total_steps * 100,
                        self.loss.name,
                        print_loss_avg)
                    log.info(log_msg)
                    print(log_msg)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:
                    Checkpoint(model=model,
                               optimizer=self.optimizer,
                               epoch=epoch, step=step,
                               input_vocab=data.fields[seq2seq.src_field_name].vocab,
                               output_vocab=data.fields[seq2seq.tgt_field_name].vocab).save(self.expt_dir)

            if step_elapsed == 0: continue

            epoch_loss_avg = epoch_loss_total / min(steps_per_epoch, step - start_step)
            epoch_loss_total = 0
            log_msg = "Finished epoch %d: Train %s: %.4f" % (epoch, self.loss.name, epoch_loss_avg)
            if dev_data is not None:
                #dev_loss = self.evaluator.evaluate_reward(model, dev_data, self.output_vocab)
                dev_loss, accuracy = self.evaluator.evaluate(model, dev_data)
                # The scheduler (if any) is driven by the dev loss.
                self.optimizer.update(dev_loss, epoch)
                log_msg += ", Dev %s: %.4f, Accuracy: %.4f" % (self.loss.name, dev_loss, accuracy)
                model.train(mode=True)
            else:
                self.optimizer.update(epoch_loss_avg, epoch)

            log.info(log_msg)
            print(log_msg)

    def train(self, model, data, num_epochs=1,
              resume=False, dev_data=None,
              optimizer=None, teacher_forcing_ratio=0):
        """ Run training for a given model.

        Args:
            model (seq2seq.models): model to run training on, if `resume=True`, it would be
                overwritten by the model loaded from the latest checkpoint.
            data (seq2seq.dataset.dataset.Dataset): dataset object to train on
            num_epochs (int, optional): number of epochs to run (default 1)
            resume(bool, optional): resume training with the latest checkpoint, (default False)
            dev_data (seq2seq.dataset.dataset.Dataset, optional): dev Dataset (default None)
            optimizer (seq2seq.optim.Optimizer, optional): optimizer for training
                (default: Optimizer(pytorch.optim.Adam, max_grad_norm=5))
            teacher_forcing_ratio (float, optional): teaching forcing ratio (default 0)
        Returns:
            model (seq2seq.models): trained model.
        """
        # If training is set to resume
        if resume:
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer

            # A workaround to rebind the restored optimizer to the restored
            # model's parameters while keeping the saved hyperparameters.
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)

            start_epoch = resume_checkpoint.epoch
            step = resume_checkpoint.step
        else:
            start_epoch = 1
            step = 0
            if optimizer is None:
                optimizer = Optimizer(optim.Adam(model.parameters()), max_grad_norm=5)
            self.optimizer = optimizer

        self.logger.info("Optimizer: %s, Scheduler: %s" % (self.optimizer.optimizer, self.optimizer.scheduler))

        self._train_epoches(data, model, num_epochs,
                            start_epoch, step, dev_data=dev_data, teacher_forcing_ratio=teacher_forcing_ratio)
        return model
class SupervisedTrainer(object):
    """ The SupervisedTrainer class helps in setting up a training framework in a
    supervised setting.

    Args:
        expt_dir (optional, str): experiment Directory to store details of the experiment,
            by default it makes a folder in the current directory to store the details (default: `experiment`).
        loss (seq2seq.loss.loss.Loss, optional): loss for training, (default: seq2seq.loss.NLLLoss)
        batch_size (int, optional): batch size for experiment, (default: 64)
        random_seed (int, optional): seed for python and torch RNGs (default: None)
        checkpoint_every (int, optional): number of batches to checkpoint after, (default: 100)
        print_every (int, optional): number of batches between progress logs, (default: 100)
    """
    def __init__(self, expt_dir='experiment', loss=NLLLoss(), batch_size=64,
                 random_seed=None,
                 checkpoint_every=100, print_every=100):
        self._trainer = "Simple Trainer"
        self.random_seed = random_seed
        if random_seed is not None:
            random.seed(random_seed)
            torch.manual_seed(random_seed)
        self.loss = loss
        self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size)
        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every

        if not os.path.isabs(expt_dir):
            expt_dir = os.path.join(os.getcwd(), expt_dir)
        self.expt_dir = expt_dir
        if not os.path.exists(self.expt_dir):
            os.makedirs(self.expt_dir)
        self.batch_size = batch_size

        self.logger = logging.getLogger(__name__)

    def _train_batch(self, input_variable, input_lengths, target_variable, model, teacher_forcing_ratio):
        """Run one forward/backward pass with token-level NLL and return the scalar loss."""
        loss = self.loss
        # Forward propagation
        decoder_outputs, decoder_hidden, other = model(input_variable, input_lengths, target_variable,
                                                       teacher_forcing_ratio=teacher_forcing_ratio)
        # Get loss
        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = target_variable.size(0)
            # Compare step t's distribution against target position t+1
            # (position 0 is the start-of-sequence token).
            loss.eval_batch(step_output.contiguous().view(batch_size, -1), target_variable[:, step + 1])
        # Backward propagation
        model.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.get_loss()

    def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
                       dev_data=None, teacher_forcing_ratio=0):
        """Main training loop over epochs and batches, with logging and checkpoints."""
        log = self.logger

        print_loss_total = 0  # Reset every print_every
        epoch_loss_total = 0  # Reset every epoch

        # NOTE(review): legacy torchtext semantics — device=None selects the
        # current CUDA device, -1 selects the CPU.
        device = None if torch.cuda.is_available() else -1
        batch_iterator = torchtext.data.BucketIterator(
            dataset=data, batch_size=self.batch_size,
            sort=False, sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            device=device, repeat=False, shuffle=True)

        steps_per_epoch = len(batch_iterator)
        total_steps = steps_per_epoch * n_epochs

        step = start_step
        step_elapsed = 0
        for epoch in range(start_epoch, n_epochs + 1):
            log.debug("Epoch: %d, Step: %d" % (epoch, step))

            batch_generator = batch_iterator.__iter__()
            # consuming seen batches from previous training
            for _ in range((epoch - 1) * steps_per_epoch, step):
                next(batch_generator)

            model.train(True)
            for batch in batch_generator:
                step += 1
                step_elapsed += 1

                input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
                target_variables = getattr(batch, seq2seq.tgt_field_name)

                loss = self._train_batch(input_variables, input_lengths.tolist(), target_variables, model, teacher_forcing_ratio)

                # Record average loss
                print_loss_total += loss
                epoch_loss_total += loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    print_loss_total = 0
                    log_msg = 'Progress: %d%%, Train %s: %.4f' % (
                        step / total_steps * 100,
                        self.loss.name,
                        print_loss_avg)
                    log.info(log_msg)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:
                    Checkpoint(model=model,
                               optimizer=self.optimizer,
                               epoch=epoch, step=step,
                               input_vocab=data.fields[seq2seq.src_field_name].vocab,
                               output_vocab=data.fields[seq2seq.tgt_field_name].vocab).save(self.expt_dir)

            if step_elapsed == 0: continue

            epoch_loss_avg = epoch_loss_total / min(steps_per_epoch, step - start_step)
            epoch_loss_total = 0
            log_msg = "Finished epoch %d: Train %s: %.4f" % (epoch, self.loss.name, epoch_loss_avg)
            if dev_data is not None:
                dev_loss, accuracy = self.evaluator.evaluate(model, dev_data)
                # The scheduler (if any) is driven by the dev loss.
                self.optimizer.update(dev_loss, epoch)
                log_msg += ", Dev %s: %.4f, Accuracy: %.4f" % (self.loss.name, dev_loss, accuracy)
                model.train(mode=True)
            else:
                self.optimizer.update(epoch_loss_avg, epoch)

            log.info(log_msg)
            print(log_msg)

    def train(self, model, data, num_epochs=5,
              resume=False, dev_data=None,
              optimizer=None, teacher_forcing_ratio=0):
        """ Run training for a given model.

        Args:
            model (seq2seq.models): model to run training on, if `resume=True`, it would be
                overwritten by the model loaded from the latest checkpoint.
            data (seq2seq.dataset.dataset.Dataset): dataset object to train on
            num_epochs (int, optional): number of epochs to run (default 5)
            resume(bool, optional): resume training with the latest checkpoint, (default False)
            dev_data (seq2seq.dataset.dataset.Dataset, optional): dev Dataset (default None)
            optimizer (seq2seq.optim.Optimizer, optional): optimizer for training
                (default: Optimizer(pytorch.optim.Adam, max_grad_norm=5))
            teacher_forcing_ratio (float, optional): teaching forcing ratio (default 0)
        Returns:
            model (seq2seq.models): trained model.
        """
        # If training is set to resume
        if resume:
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer

            # A workaround to rebind the restored optimizer to the restored
            # model's parameters while keeping the saved hyperparameters.
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)

            start_epoch = resume_checkpoint.epoch
            step = resume_checkpoint.step
        else:
            start_epoch = 1
            step = 0
            if optimizer is None:
                optimizer = Optimizer(optim.Adam(model.parameters()), max_grad_norm=5)
            self.optimizer = optimizer

        self.logger.info("Optimizer: %s, Scheduler: %s" % (self.optimizer.optimizer, self.optimizer.scheduler))

        self._train_epoches(data, model, num_epochs,
                            start_epoch, step, dev_data=dev_data,
                            teacher_forcing_ratio=teacher_forcing_ratio)
        return model
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacger2/softregex/34742a197cbfdc248ea01de4df76488ecd3b5ce4/seq2seq/util/__init__.pyc -------------------------------------------------------------------------------- /seq2seq/util/checkpoint.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import time 4 | import shutil 5 | 6 | import torch 7 | import dill 8 | 9 | class Checkpoint(object): 10 | """ 11 | The Checkpoint class manages the saving and loading of a model during training. It allows training to be suspended 12 | and resumed at a later time (e.g. when running on a cluster using sequential jobs). 13 | 14 | To make a checkpoint, initialize a Checkpoint object with the following args; then call that object's save() method 15 | to write parameters to disk. 16 | 17 | Args: 18 | model (seq2seq): seq2seq model being trained 19 | optimizer (Optimizer): stores the state of the optimizer 20 | epoch (int): current epoch (an epoch is a loop through the full training data) 21 | step (int): number of examples seen within the current epoch 22 | input_vocab (Vocabulary): vocabulary for the input language 23 | output_vocab (Vocabulary): vocabulary for the output language 24 | 25 | Attributes: 26 | CHECKPOINT_DIR_NAME (str): name of the checkpoint directory 27 | TRAINER_STATE_NAME (str): name of the file storing trainer states 28 | MODEL_NAME (str): name of the file storing model 29 | INPUT_VOCAB_FILE (str): name of the input vocab file 30 | OUTPUT_VOCAB_FILE (str): name of the output vocab file 31 | """ 32 | 33 | CHECKPOINT_DIR_NAME = 'checkpoints' 34 | TRAINER_STATE_NAME = 'trainer_states.pt' 35 | MODEL_NAME = 'model.pt' 36 | INPUT_VOCAB_FILE = 'input_vocab.pt' 37 | OUTPUT_VOCAB_FILE = 'output_vocab.pt' 38 | 39 | def __init__(self, model, optimizer, epoch, step, input_vocab, output_vocab, path=None): 40 | self.model 
= model 41 | self.optimizer = optimizer 42 | self.input_vocab = input_vocab 43 | self.output_vocab = output_vocab 44 | self.epoch = epoch 45 | self.step = step 46 | self._path = path 47 | 48 | @property 49 | def path(self): 50 | if self._path is None: 51 | raise LookupError("The checkpoint has not been saved.") 52 | return self._path 53 | 54 | def save(self, experiment_dir): 55 | """ 56 | Saves the current model and related training parameters into a subdirectory of the checkpoint directory. 57 | The name of the subdirectory is the current local time in Y_M_D_H_M_S format. 58 | Args: 59 | experiment_dir (str): path to the experiment root directory 60 | Returns: 61 | str: path to the saved checkpoint subdirectory 62 | """ 63 | date_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()) 64 | 65 | self._path = os.path.join(experiment_dir, self.CHECKPOINT_DIR_NAME, date_time) 66 | path = self._path 67 | 68 | if os.path.exists(path): 69 | shutil.rmtree(path) 70 | os.makedirs(path) 71 | torch.save({'epoch': self.epoch, 72 | 'step': self.step, 73 | 'optimizer': self.optimizer 74 | }, 75 | os.path.join(path, self.TRAINER_STATE_NAME)) 76 | torch.save(self.model, os.path.join(path, self.MODEL_NAME)) 77 | 78 | with open(os.path.join(path, self.INPUT_VOCAB_FILE), 'wb') as fout: 79 | dill.dump(self.input_vocab, fout) 80 | with open(os.path.join(path, self.OUTPUT_VOCAB_FILE), 'wb') as fout: 81 | dill.dump(self.output_vocab, fout) 82 | 83 | return path 84 | 85 | @classmethod 86 | def load(cls, path): 87 | """ 88 | Loads a Checkpoint object that was previously saved to disk. 
89 | Args: 90 | path (str): path to the checkpoint subdirectory 91 | Returns: 92 | checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk 93 | """ 94 | if torch.cuda.is_available(): 95 | resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME)) 96 | model = torch.load(os.path.join(path, cls.MODEL_NAME)) 97 | else: 98 | resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME), map_location=lambda storage, loc: storage) 99 | model = torch.load(os.path.join(path, cls.MODEL_NAME), map_location=lambda storage, loc: storage) 100 | 101 | #model.flatten_parameters() # make RNN parameters contiguous 102 | with open(os.path.join(path, cls.INPUT_VOCAB_FILE), 'rb') as fin: 103 | input_vocab = dill.load(fin) 104 | with open(os.path.join(path, cls.OUTPUT_VOCAB_FILE), 'rb') as fin: 105 | output_vocab = dill.load(fin) 106 | optimizer = resume_checkpoint['optimizer'] 107 | return Checkpoint(model=model, input_vocab=input_vocab, 108 | output_vocab=output_vocab, 109 | optimizer=optimizer, 110 | epoch=resume_checkpoint['epoch'], 111 | step=resume_checkpoint['step'], 112 | path=path) 113 | 114 | @classmethod 115 | def get_latest_checkpoint(cls, experiment_path): 116 | """ 117 | Given the path to an experiment directory, returns the path to the last saved checkpoint's subdirectory. 118 | 119 | Precondition: at least one checkpoint has been made (i.e., latest checkpoint subdirectory exists). 
120 | Args: 121 | experiment_path (str): path to the experiment directory 122 | Returns: 123 | str: path to the last saved checkpoint's subdirectory 124 | """ 125 | checkpoints_path = os.path.join(experiment_path, cls.CHECKPOINT_DIR_NAME) 126 | all_times = sorted(os.listdir(checkpoints_path), reverse=True) 127 | return os.path.join(checkpoints_path, all_times[0]) 128 | -------------------------------------------------------------------------------- /seq2seq/util/checkpoint.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacger2/softregex/34742a197cbfdc248ea01de4df76488ecd3b5ce4/seq2seq/util/checkpoint.pyc -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | # To use a consistent encoding 3 | from codecs import open 4 | from os import path 5 | 6 | here = path.abspath(path.dirname(__file__)) 7 | 8 | # Get the long description from the README file 9 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 10 | long_description = f.read() 11 | 12 | setup( 13 | name='seq2seq', 14 | 15 | # Versions should comply with PEP440. For a discussion on single-sourcing 16 | # the version across setup.py and the project code, see 17 | # https://packaging.python.org/en/latest/single_source_version.html 18 | version='0.1.6', 19 | 20 | description='A framework for sequence-to-sequence (seq2seq) models implemented in PyTorch.', 21 | long_description=long_description, 22 | 23 | # The project's main homepage. 24 | url='https://github.com/IBM/pytorch-seq2seq', 25 | 26 | # Choose your license 27 | license='Apache License 2.0', 28 | 29 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 30 | classifiers=[ 31 | # How mature is this project? 
Common values are 32 | # 3 - Alpha 33 | # 4 - Beta 34 | # 5 - Production/Stable 35 | 'Development Status :: 3 - Alpha', 36 | 37 | # Indicate who your project is intended for 38 | 'Intended Audience :: Research', 39 | 'Topic :: Software Development', 40 | 41 | # Pick your license as you wish (should match "license" above) 42 | 'License :: Apache License 2.0', 43 | 44 | # Specify the Python versions you support here. In particular, ensure 45 | # that you indicate whether you support Python 2, Python 3 or both. 46 | 'Programming Language :: Python :: 2.7', 47 | 'Programming Language :: Python :: 3.6' 48 | ], 49 | 50 | # What does your project relate to? 51 | keywords='seq2seq py-torch development', 52 | 53 | # You can just specify the packages manually here if your project is 54 | # simple. Or you can use find_packages(). 55 | packages=find_packages(exclude=['contrib', 'docs', 'tests']), 56 | 57 | # Alternatively, if you want to distribute just a my_module.py, uncomment 58 | # this: 59 | # py_modules=["my_module"], 60 | 61 | # List run-time dependencies here. These will be installed by pip when 62 | # your project is installed. For an analysis of "install_requires" vs pip's 63 | # requirements files see: 64 | # https://packaging.python.org/en/latest/requirements.html 65 | install_requires=['numpy', 'torch', 'torchtext'], 66 | 67 | # List additional groups of dependencies here (e.g. development 68 | # dependencies). 
You can install these using the following syntax, 69 | # for example: 70 | # $ pip install -e .[dev,test] 71 | extras_require={ 72 | 'dev': ['check-manifest'], 73 | 'test': ['coverage'], 74 | } 75 | ) 76 | -------------------------------------------------------------------------------- /softregex-eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Evaluation 5 | 6 | # In[14]: 7 | 8 | 9 | import os 10 | import argparse 11 | import logging 12 | 13 | import torch 14 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau 15 | import torchtext 16 | 17 | import seq2seq 18 | 19 | from seq2seq.trainer import SupervisedTrainer, SelfCriticalTrainer 20 | from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq, TopKDecoder 21 | from seq2seq.loss import Perplexity, NLLLoss, PositiveLoss 22 | from seq2seq.optim import Optimizer 23 | from seq2seq.dataset import SourceField, TargetField 24 | from seq2seq.evaluator import Predictor, Evaluator 25 | from seq2seq.util.checkpoint import Checkpoint 26 | import torch.nn.functional as F 27 | 28 | import subprocess 29 | import sys 30 | 31 | import warnings 32 | warnings.filterwarnings('ignore') 33 | 34 | try: 35 | raw_input # Python 2 36 | except NameError: 37 | raw_input = input # Python 3 38 | 39 | # Prepare dataset 40 | src = SourceField() 41 | tgt = TargetField() 42 | 43 | # data/kb/train/data.txt 44 | #data/NL-RX-Synth/train/data.txt 45 | #data/NL-RX-Turk/train/data.txt 46 | 47 | 48 | dataset = 'kb13' 49 | 50 | if len(sys.argv) < 1: 51 | sys.exit(-1) 52 | 53 | dataset = sys.argv[1] 54 | 55 | 56 | datasets = { 57 | 'kb13': ('KB13', 30, 60), 58 | 'NL-RX-Synth': ('NL-RX-Synth', 10, 40), 59 | 'NL-RX-Turk': ('NL-RX-Turk', 10, 40) 60 | } 61 | 62 | 63 | data_tuple = datasets[dataset] 64 | 65 | # max_len = 60 66 | max_len = data_tuple[2] 67 | def len_filter(example): 68 | return len(example.src) <= max_len and len(example.tgt) <= 
max_len 69 | train = torchtext.data.TabularDataset( 70 | path='data/' + data_tuple[0] + '/train/data.txt', format='tsv', 71 | fields=[('src', src), ('tgt', tgt)], 72 | filter_pred=len_filter 73 | ) 74 | dev = torchtext.data.TabularDataset( 75 | path='data/' + data_tuple[0] + '/val/data.txt', format='tsv', 76 | fields=[('src', src), ('tgt', tgt)], 77 | filter_pred=len_filter 78 | ) 79 | test = torchtext.data.TabularDataset( 80 | path='data/' + data_tuple[0] + '/test/data.txt', format='tsv', 81 | fields=[('src', src), ('tgt', tgt)], 82 | filter_pred=len_filter 83 | ) 84 | src.build_vocab(train, max_size=500) 85 | tgt.build_vocab(train, max_size=500) 86 | input_vocab = src.vocab 87 | output_vocab = tgt.vocab 88 | 89 | # Prepare loss 90 | weight = torch.ones(len(tgt.vocab)) 91 | pad = tgt.vocab.stoi[tgt.pad_token] 92 | 93 | loss = NLLLoss(weight, pad) 94 | 95 | if torch.cuda.is_available(): 96 | loss.cuda() 97 | 98 | seq2seq_model = None 99 | optimizer = None 100 | 101 | 102 | # In[15]: 103 | 104 | 105 | def decode_tensor(tensor, vocab): 106 | tensor = tensor.view(-1) 107 | words = [] 108 | for i in tensor: 109 | word = vocab.itos[i.cpu().numpy()] 110 | if word == '': 111 | return ' '.join(words) 112 | if word != '' and word != '' and word != '': 113 | words.append(word) 114 | #if word != '': 115 | # words.append(word) 116 | #print('|' + word + '|') 117 | return ' '.join(words) 118 | 119 | 120 | # In[16]: 121 | 122 | 123 | from regexDFAEquals import regex_equiv_from_raw, unprocess_regex, regex_equiv 124 | 125 | 126 | # In[17]: 127 | 128 | 129 | batch_size = 1 130 | 131 | 132 | # In[18]: 133 | 134 | 135 | hidden_size = 256 136 | word_embedding_size = 128 137 | 138 | bidirectional = True 139 | 140 | encoder = EncoderRNN(len(src.vocab), max_len, hidden_size, dropout_p=0.1,rnn_cell='lstm', 141 | bidirectional=bidirectional, n_layers=2, variable_lengths=True) 142 | decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else 
hidden_size,rnn_cell='lstm', 143 | dropout_p=0.25, use_attention=True, bidirectional=bidirectional, n_layers=2, 144 | eos_id=tgt.eos_id, sos_id=tgt.sos_id) 145 | 146 | seq2seq_model = Seq2seq(encoder, decoder) 147 | if torch.cuda.is_available(): 148 | seq2seq_model.cuda() 149 | 150 | for param in seq2seq_model.parameters(): 151 | param.data.uniform_(-0.1, 0.1) 152 | 153 | 154 | optimizer = Optimizer(torch.optim.Adam(seq2seq_model.parameters()), max_grad_norm=5) 155 | 156 | 157 | t = SupervisedTrainer(loss=loss, batch_size=8, 158 | checkpoint_every=100, 159 | print_every=10000, expt_dir='./lstm_model/'+data_tuple[0]+'/Deepregex') 160 | 161 | seq2seq_model = torch.nn.DataParallel(seq2seq_model) 162 | 163 | seq2seq_model = t.train(seq2seq_model, train, 164 | num_epochs=1, dev_data=dev, 165 | optimizer=optimizer, 166 | teacher_forcing_ratio=0.5, 167 | resume=True) 168 | 169 | optimizer_new = Optimizer(torch.optim.Adadelta(seq2seq_model.parameters(), lr=0.05)) 170 | 171 | 172 | sc_t = SelfCriticalTrainer(loss=PositiveLoss(mode='prob', prob_model=None, loss_vocab=None), batch_size=32, 173 | checkpoint_every=200000, print_every=100, expt_dir='./lstm_model/'+data_tuple[0]+'/SoftRegex', output_vocab=output_vocab) 174 | 175 | 176 | 177 | seq2seq_model = sc_t.train(seq2seq_model, train, 178 | num_epochs=1, dev_data=dev, 179 | optimizer=optimizer_new, teacher_forcing_ratio=0.5, 180 | resume=True) 181 | 182 | 183 | data = test 184 | 185 | 186 | # In[19]: 187 | 188 | 189 | seq2seq_model.eval() 190 | 191 | loss.reset() 192 | match = 0 193 | total = 0 194 | 195 | device = None if torch.cuda.is_available() else -1 196 | batch_iterator = torchtext.data.BucketIterator( 197 | dataset=data, batch_size=batch_size, 198 | sort=False, sort_key=lambda x: len(x.src), 199 | device=device, train=False) 200 | tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab 201 | pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token] 202 | 203 | 204 | # In[20]: 205 | 206 | 207 | def 
refine_outout(regex): 208 | par_list = [] 209 | word_list = regex.split() 210 | 211 | for idx, word in enumerate(word_list): 212 | if word == '(' or word == '[' or word == '{': 213 | par_list.append(word) 214 | 215 | if word == ')' or word == ']' or word == '}': 216 | if len(par_list) == 0: 217 | word_list[idx] = '' 218 | continue 219 | 220 | par_in_list = par_list.pop() 221 | if par_in_list == '(': 222 | word_list[idx] = ')' 223 | elif par_in_list == '[': 224 | word_list[idx] = ']' 225 | elif par_in_list == '{': 226 | word_list[idx] = '}' 227 | 228 | while len(par_list) != 0: 229 | par_in_list = par_list.pop() 230 | if par_in_list == '(': 231 | word_list.append(')') 232 | elif par_in_list == '[': 233 | word_list.append(']') 234 | elif par_in_list == '{': 235 | word_list.append('}') 236 | 237 | word_list = [word for word in word_list if word != ''] 238 | 239 | return ' '.join(word_list) 240 | 241 | 242 | # In[22]: 243 | 244 | 245 | predictor = Predictor(seq2seq_model, input_vocab, output_vocab) 246 | 247 | num_samples = 0 248 | perfect_samples = 0 249 | dfa_perfect_samples = 0 250 | 251 | match = 0 252 | total = 0 253 | 254 | 255 | model_correct = 0 256 | model_wrong = 0 257 | 258 | with torch.no_grad(): 259 | for batch in batch_iterator: 260 | num_samples = num_samples + 1 261 | 262 | input_variables, input_lengths = getattr(batch, seq2seq.src_field_name) 263 | 264 | target_variables = getattr(batch, seq2seq.tgt_field_name) 265 | 266 | 267 | 268 | target_string = decode_tensor(target_variables, output_vocab) 269 | 270 | 271 | input_string = decode_tensor(input_variables, input_vocab) 272 | 273 | generated_string = ' '.join([x for x in predictor.predict(input_string.strip().split())[:-1] if x != '']) 274 | 275 | print("Input string: ", input_string) 276 | print("Targ : ", target_string) 277 | print("Pred : ", refine_outout(generated_string)) 278 | 279 | 280 | generated_string = refine_outout(generated_string) 281 | 282 | 283 | pos_example = 
subprocess.check_output(['python2', 'regexDFAEquals.py', '--gold', '{}'.format(target_string), '--predicted', '{}'.format(generated_string)]) 284 | 285 | if target_string == generated_string: 286 | perfect_samples = perfect_samples + 1 287 | dfa_perfect_samples = dfa_perfect_samples + 1 288 | print('String Equivalent') 289 | elif str(pos_example)[2] == '1': 290 | print("DFA Equivalent") 291 | dfa_perfect_samples = dfa_perfect_samples + 1 292 | 293 | else: 294 | print("DFA not Equivalent") 295 | 296 | 297 | target_tokens = target_string.split() 298 | generated_tokens = generated_string.split() 299 | 300 | shorter_len = min(len(target_tokens), len(generated_tokens)) 301 | 302 | for idx in range(len(generated_tokens)): 303 | total = total + 1 304 | 305 | if idx >= len(target_tokens): 306 | total = total + 1 307 | elif target_tokens[idx] == generated_tokens[idx]: 308 | match = match + 1 309 | 310 | 311 | if total == 0: 312 | accuracy = float('nan') 313 | else: 314 | accuracy = match / total 315 | 316 | string_accuracy = perfect_samples / num_samples 317 | dfa_accuracy = dfa_perfect_samples /num_samples 318 | 319 | print("Iterations ",num_samples, " : ", "{0:.3f}".format(dfa_accuracy) + '\n') 320 | 321 | 322 | # In[ ]: 323 | 324 | 325 | 326 | 327 | 328 | # In[ ]: 329 | 330 | 331 | 332 | 333 | -------------------------------------------------------------------------------- /softregex-train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import os 8 | import argparse 9 | import logging 10 | 11 | import torch 12 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau 13 | import torchtext 14 | 15 | import seq2seq 16 | 17 | from seq2seq.trainer import SupervisedTrainer, SelfCriticalTrainer 18 | from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq, TopKDecoder 19 | from seq2seq.loss import Perplexity, NLLLoss, PositiveLoss 20 | from seq2seq.optim 
import Optimizer 21 | from seq2seq.dataset import SourceField, TargetField 22 | from seq2seq.evaluator import Predictor, Evaluator 23 | from seq2seq.util.checkpoint import Checkpoint 24 | import torch.nn.functional as F 25 | import sys 26 | 27 | 28 | # In[13]: 29 | 30 | dataset = 'kb13' 31 | 32 | if len(sys.argv) < 1: 33 | sys.exit(-1) 34 | 35 | dataset = sys.argv[1] 36 | 37 | import warnings 38 | warnings.filterwarnings('ignore') 39 | 40 | 41 | # In[14]: 42 | 43 | 44 | try: 45 | raw_input # Python 2 46 | except NameError: 47 | raw_input = input # Python 3 48 | 49 | 50 | # In[17]: 51 | 52 | 53 | # Prepare dataset 54 | src = SourceField() 55 | tgt = TargetField() 56 | 57 | # data/kb/train/data.txt 58 | #data/NL-RX-Synth/train/data.txt 59 | #data/NL-RX-Turk/train/data.txt 60 | 61 | datasets = { 62 | 'kb13': ('KB13', 35, 60), 63 | 'NL-RX-Synth': ('NL-RX-Synth', 10, 40), 64 | 'NL-RX-Turk': ('NL-RX-Turk', 10, 40) 65 | } 66 | 67 | data_tuple = datasets[dataset] 68 | 69 | # max_len = 60 70 | max_len = data_tuple[2] 71 | def len_filter(example): 72 | return len(example.src) <= max_len and len(example.tgt) <= max_len 73 | train = torchtext.data.TabularDataset( 74 | path='data/' + data_tuple[0] + '/train/data.txt', format='tsv', 75 | fields=[('src', src), ('tgt', tgt)], 76 | filter_pred=len_filter 77 | ) 78 | dev = torchtext.data.TabularDataset( 79 | path='data/' + data_tuple[0] + '/val/data.txt', format='tsv', 80 | fields=[('src', src), ('tgt', tgt)], 81 | filter_pred=len_filter 82 | ) 83 | test = torchtext.data.TabularDataset( 84 | path='data/' + data_tuple[0] + '/test/data.txt', format='tsv', 85 | fields=[('src', src), ('tgt', tgt)], 86 | filter_pred=len_filter 87 | ) 88 | src.build_vocab(train, max_size=500) 89 | tgt.build_vocab(train, max_size=500) 90 | input_vocab = src.vocab 91 | output_vocab = tgt.vocab 92 | 93 | 94 | # In[18]: 95 | 96 | 97 | # Prepare loss 98 | weight = torch.ones(len(tgt.vocab)) 99 | pad = tgt.vocab.stoi[tgt.pad_token] 100 | 101 | loss = 
NLLLoss(weight, pad) 102 | 103 | if torch.cuda.is_available(): 104 | loss.cuda() 105 | 106 | seq2seq_model = None 107 | optimizer = None 108 | 109 | 110 | # In[19]: 111 | 112 | 113 | hidden_size = 256 114 | word_embedding_size = 128 115 | 116 | bidirectional = True 117 | 118 | encoder = EncoderRNN(len(src.vocab), max_len, hidden_size, dropout_p=0.1,rnn_cell='lstm', 119 | bidirectional=bidirectional, n_layers=2, variable_lengths=True) 120 | decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,rnn_cell='lstm', 121 | dropout_p=0.25, use_attention=True, bidirectional=bidirectional, n_layers=2, 122 | eos_id=tgt.eos_id, sos_id=tgt.sos_id) 123 | 124 | seq2seq_model = Seq2seq(encoder, decoder) 125 | if torch.cuda.is_available(): 126 | seq2seq_model.cuda() 127 | 128 | for param in seq2seq_model.parameters(): 129 | param.data.uniform_(-0.1, 0.1) 130 | 131 | 132 | optimizer = Optimizer(torch.optim.Adam(seq2seq_model.parameters()), max_grad_norm=5) 133 | 134 | 135 | # In[20]: 136 | 137 | 138 | seq2seq_model = torch.nn.DataParallel(seq2seq_model) 139 | 140 | 141 | # In[21]: 142 | 143 | 144 | # train 145 | 146 | t = SupervisedTrainer(loss=loss, batch_size=8, 147 | checkpoint_every=200, 148 | print_every=10000, expt_dir='./lstm_model/'+data_tuple[0]+'/Deepregex') 149 | 150 | 151 | # In[22]: 152 | 153 | 154 | seq2seq_model = t.train(seq2seq_model, train, 155 | num_epochs=data_tuple[1], dev_data=dev, 156 | optimizer=optimizer, 157 | teacher_forcing_ratio=0.5, 158 | resume=False) 159 | 160 | 161 | # ### Self Critical Training 162 | 163 | # In[23]: 164 | 165 | 166 | class compare_regex(torch.nn.Module): 167 | def __init__(self, vocab_size, embedding_dim, hidden_dim, target_size): 168 | super(compare_regex, self).__init__() 169 | self.hidden_dim = hidden_dim 170 | self.embedding_dim = embedding_dim 171 | self.embed = Embedding(vocab_size, embedding_dim, padding_idx=0) 172 | self.lstm1 = LSTM(embedding_dim ,hidden_dim, bidirectional=True, 
num_layers=1, batch_first=True) 173 | self.lstm2 = LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=1, batch_first=True) 174 | self.fc1 = Linear(hidden_dim*2*2, 60) 175 | self.fc2 = Linear(60, 20) 176 | self.fc3 = Linear(20, target_size) 177 | 178 | 179 | def init_hidden(self, bs): 180 | if torch.cuda.is_available(): 181 | return (torch.zeros(2, bs, self.hidden_dim).cuda(), 182 | torch.zeros(2, bs, self.hidden_dim).cuda()) 183 | else: 184 | return (torch.zeros(2, bs, self.hidden_dim), 185 | torch.zeros(2, bs, self.hidden_dim)) 186 | 187 | def forward(self, bs, line1, line2, input1_lengths,input2_lengths): 188 | embeded1 = self.embed(line1) 189 | embeded2 = self.embed(line2) 190 | 191 | hidden1 = self.init_hidden(bs) 192 | lstm1_out, last_hidden1 = self.lstm1(embeded1,hidden1) 193 | hidden2 = self.init_hidden(bs) 194 | lstm2_out, last_hidden2 = self.lstm2(embeded2,hidden2) 195 | 196 | 197 | fc1_out = self.fc1(torch.cat((lstm1_out.mean(1), lstm2_out.mean(1)),1)) #encoder outputs 평균값 concat 97.8% 198 | 199 | 200 | fc1_out = F.tanh(fc1_out) 201 | fc2_out = self.fc2(fc1_out) 202 | fc2_out = F.tanh(fc2_out) 203 | fc3_out = self.fc3(fc2_out) 204 | score = F.log_softmax(fc3_out,dim=1) 205 | return score 206 | 207 | 208 | # In[24]: 209 | 210 | 211 | f = open('./regex_equal_model/compare_vocab.txt') 212 | sc_loss_vocab = dict() 213 | for line in f.read().splitlines(): 214 | line = line.split('\t') 215 | sc_loss_vocab[line[0]] = int(line[1]) 216 | f.close() 217 | compare_regex_model = torch.load('./regex_equal_model/compare_regex_model.pth') 218 | compare_regex_model.eval() 219 | 220 | 221 | # In[25]: 222 | 223 | 224 | optimizer_new = Optimizer(torch.optim.Adadelta(seq2seq_model.parameters(), lr=0.05)) 225 | 226 | #if you want to train by oracle, put mode to None 227 | sc_t = SelfCriticalTrainer(loss=PositiveLoss(mode='prob', prob_model=compare_regex_model, loss_vocab=sc_loss_vocab), batch_size=32, 228 | checkpoint_every=100, print_every=100, 
expt_dir='./lstm_model/'+data_tuple[0]+'/SoftRegex', output_vocab=output_vocab) 229 | 230 | 231 | 232 | seq2seq_model = sc_t.train(seq2seq_model, train, 233 | num_epochs=30, dev_data=dev, 234 | optimizer=optimizer_new, teacher_forcing_ratio=0.5, 235 | resume=False) 236 | 237 | 238 | # In[26]: 239 | 240 | 241 | evaluator = Evaluator() 242 | 243 | 244 | # In[27]: 245 | 246 | 247 | evaluator.evaluate(seq2seq_model, dev) # (5.799417234628771, 0.6468332123976366) --------------------------------------------------------------------------------